From 0b4db3f0241527040fc5378fb2644d8141a97751 Mon Sep 17 00:00:00 2001 From: Matt McShane Date: Thu, 27 Jul 2023 15:59:25 -0400 Subject: [PATCH] Expand WF context locking to cover WFT responses Not holding this lock while responding to a workflow task allows in-flight tasks to race past each other which has led to history corruption in the case where responses are not deduplicated. Furthermore the correctness of resetting the event level while not holding this lock is unclear at best. --- internal/internal_public.go | 12 +- internal/internal_task_handlers.go | 56 +++-- .../internal_task_handlers_interfaces_test.go | 14 +- internal/internal_task_handlers_test.go | 222 +++++++++++++----- internal/internal_task_pollers.go | 60 +++-- internal/internal_worker.go | 12 +- internal/internal_worker_test.go | 5 +- 7 files changed, 267 insertions(+), 114 deletions(-) diff --git a/internal/internal_public.go b/internal/internal_public.go index a990d5dd1..03490c465 100644 --- a/internal/internal_public.go +++ b/internal/internal_public.go @@ -78,6 +78,8 @@ type ( // WorkflowTaskHandler represents workflow task handlers. WorkflowTaskHandler interface { + WorkflowContextManager + // Processes the workflow task // The response could be: // - RespondWorkflowTaskCompletedRequest @@ -85,8 +87,16 @@ type ( // - RespondQueryTaskCompletedRequest ProcessWorkflowTask( task *workflowTask, + ctx *workflowExecutionContextImpl, f workflowTaskHeartbeatFunc, - ) (response interface{}, resetter EventLevelResetter, err error) + ) (response interface{}, err error) + } + + WorkflowContextManager interface { + GetOrCreateWorkflowContext( + task *workflowservice.PollWorkflowTaskQueueResponse, + historyIterator HistoryIterator, + ) (*workflowExecutionContextImpl, error) } // ActivityTaskHandler represents activity task handlers. diff --git a/internal/internal_task_handlers.go b/internal/internal_task_handlers.go index 6c9be3b63..719c8b146 100644 --- a/internal/internal_task_handlers.go +++ b/internal/internal_task_handlers.go @@ -474,11 +474,16 @@ func newWorkflowExecutionContext( return workflowContext } +// Lock acquires the lock on this context object, use Unlock to release the +// lock. func (w *workflowExecutionContextImpl) Lock() { w.mutex.Lock() } -func (w *workflowExecutionContextImpl) Unlock(err error) { +// ErrorCleanup clears this context's state and removes it from cache as +// necessary if the supplied error is not nil or if this context's internal +// error field has been set or if the workflow run has completed. +func (w *workflowExecutionContextImpl) ErrorCleanup(err error) { if err != nil || w.err != nil || w.isWorkflowCompleted || (w.wth.cache.MaxWorkflowCacheSize() <= 0 && !w.hasPendingLocalActivityWork()) { // TODO: in case of closed, it asumes the close command always succeed. need server side change to return @@ -496,7 +501,14 @@ func (w *workflowExecutionContextImpl) Unlock(err error) { // exited w.clearState() } +} +// Unlock performs any necessary error cleanup that might be needed due to +// workflow completion or context errors and then releases the lock on this +// context. This implementation is _not_ idempotent so it should only be called +// once and only if the context lock is held by the calling goroutine. +func (w *workflowExecutionContextImpl) Unlock() { + w.ErrorCleanup(nil) w.mutex.Unlock() } @@ -631,7 +643,11 @@ func (wth *workflowTaskHandlerImpl) createWorkflowContext(task *workflowservice. return newWorkflowExecutionContext(workflowInfo, wth), nil } -func (wth *workflowTaskHandlerImpl) getOrCreateWorkflowContext( +// GetOrCreateWorkflowContext finds an existing cached context object for this +// run ID or creates a new object, adds it to cache, and returns it. In all +// non-error cases the returned context object is in a locked state (i.e. +// WorkflowContext.Lock() has been called). +func (wth *workflowTaskHandlerImpl) GetOrCreateWorkflowContext( task *workflowservice.PollWorkflowTaskQueueResponse, historyIterator HistoryIterator, ) (workflowContext *workflowExecutionContextImpl, err error) { @@ -681,10 +697,12 @@ func (wth *workflowTaskHandlerImpl) getOrCreateWorkflowContext( wth.cache.removeWorkflowContext(runID) workflowContext.clearState() } - workflowContext.Unlock(err) + workflowContext.ErrorCleanup(err) + workflowContext.Unlock() workflowContext = nil } } + // If the workflow was not cached or the cache was stale. if workflowContext == nil { if !isFullHistory { @@ -711,7 +729,8 @@ func (wth *workflowTaskHandlerImpl) getOrCreateWorkflowContext( err = workflowContext.resetStateIfDestroyed(task, historyIterator) if err != nil { - workflowContext.Unlock(err) + workflowContext.ErrorCleanup(err) + workflowContext.Unlock() } return @@ -756,10 +775,11 @@ func (w *workflowExecutionContextImpl) resetStateIfDestroyed(task *workflowservi // ProcessWorkflowTask processes all the events of the workflow task. func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask( workflowTask *workflowTask, + workflowContext *workflowExecutionContextImpl, heartbeatFunc workflowTaskHeartbeatFunc, -) (completeRequest interface{}, resetter EventLevelResetter, errRet error) { +) (completeRequest interface{}, errRet error) { if workflowTask == nil || workflowTask.task == nil { - return nil, nil, errors.New("nil workflow task provided") + return nil, errors.New("nil workflow task provided") } task := workflowTask.task if task.History == nil || len(task.History.Events) == 0 { @@ -768,11 +788,11 @@ func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask( } } if task.Query == nil && len(task.History.Events) == 0 { - return nil, nil, errors.New("nil or empty history") + return nil, errors.New("nil or empty history") } if task.Query != nil && len(task.Queries) != 0 { - return nil, nil, errors.New("invalid query workflow task") + return nil, errors.New("invalid query workflow task") } runID := task.WorkflowExecution.GetRunId() @@ -786,19 +806,14 @@ func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask( tagPreviousStartedEventID, task.GetPreviousStartedEventId()) }) - workflowContext, err := wth.getOrCreateWorkflowContext(task, workflowTask.historyIterator) - if err != nil { - return nil, nil, err - } - - defer func() { - workflowContext.Unlock(errRet) - }() - - var response interface{} + var ( + response interface{} + err error + heartbeatTimer *time.Timer + ) - var heartbeatTimer *time.Timer defer func() { + workflowContext.ErrorCleanup(errRet) if heartbeatTimer != nil { heartbeatTimer.Stop() } @@ -882,7 +897,6 @@ processWorkflowLoop: } errRet = err completeRequest = response - resetter = workflowContext.SetPreviousStartedEventID return } @@ -1250,8 +1264,6 @@ func (w *workflowExecutionContextImpl) SetCurrentTask(task *workflowservice.Poll } func (w *workflowExecutionContextImpl) SetPreviousStartedEventID(eventID int64) { - w.mutex.Lock() // This call can race against the cache eviction thread - see clearState - defer w.mutex.Unlock() w.previousStartedEventID = eventID } diff --git a/internal/internal_task_handlers_interfaces_test.go b/internal/internal_task_handlers_interfaces_test.go index e43578d9a..be9d9c723 100644 --- a/internal/internal_task_handlers_interfaces_test.go +++ b/internal/internal_task_handlers_interfaces_test.go @@ -53,11 +53,19 @@ type sampleWorkflowTaskHandler struct{} func (wth sampleWorkflowTaskHandler) ProcessWorkflowTask( workflowTask *workflowTask, + _ *workflowExecutionContextImpl, _ workflowTaskHeartbeatFunc, -) (interface{}, EventLevelResetter, error) { +) (interface{}, error) { return &workflowservice.RespondWorkflowTaskCompletedRequest{ TaskToken: workflowTask.task.TaskToken, - }, nil, nil + }, nil +} + +func (wth sampleWorkflowTaskHandler) GetOrCreateWorkflowContext( + task *workflowservice.PollWorkflowTaskQueueResponse, + historyIterator HistoryIterator, +) (*workflowExecutionContextImpl, error) { + return nil, nil } func newSampleWorkflowTaskHandler() *sampleWorkflowTaskHandler { @@ -115,7 +123,7 @@ func (s *PollLayerInterfacesTestSuite) TestProcessWorkflowTaskInterface() { // Process task and respond to the service. taskHandler := newSampleWorkflowTaskHandler() - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: response}, nil) + request, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: response}, nil, nil) completionRequest := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) s.NoError(err) diff --git a/internal/internal_task_handlers_test.go b/internal/internal_task_handlers_test.go index 9436ec06d..5b671a782 100644 --- a/internal/internal_task_handlers_test.go +++ b/internal/internal_task_handlers_test.go @@ -143,6 +143,12 @@ func (t *TaskHandlersTestSuite) SetupSuite() { t.namespace = "default" } +func (t *TaskHandlersTestSuite) TearDownTest() { + if cache := *sharedWorkerCachePtr.workflowCache; cache != nil { + cache.Clear() + } +} + func TestTaskHandlersTestSuite(t *testing.T) { suite.Run(t, &TaskHandlersTestSuite{ registry: newRegistry(), @@ -514,13 +520,25 @@ func (t *TaskHandlersTestSuite) getTestWorkerExecutionParams() workerExecutionPa } } +func (t *TaskHandlersTestSuite) mustWorkflowContextImpl( + task *workflowTask, + cm WorkflowContextManager, +) *workflowExecutionContextImpl { + wfctx, err := cm.GetOrCreateWorkflowContext(task.task, task.historyIterator) + t.Require().NoError(err) + return wfctx +} + func (t *TaskHandlersTestSuite) testWorkflowTaskWorkflowExecutionStartedHelper(params workerExecutionParameters) { testEvents := []*historypb.HistoryEvent{ createTestEventWorkflowExecutionStarted(1, &historypb.WorkflowExecutionStartedEventAttributes{TaskQueue: &taskqueuepb.TaskQueue{Name: testWorkflowTaskTaskqueue}}), } task := createWorkflowTask(testEvents, 0, "HelloWorld_Workflow") taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -561,7 +579,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_BinaryChecksum() { task := createWorkflowTask(testEvents, 8, "BinaryChecksumWorkflow") params := t.getTestWorkerExecutionParams() taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) @@ -589,7 +610,10 @@ func (t *TaskHandlersTestSuite) TestRespondsToWFTWithWorkerBinaryID() { params := t.getTestWorkerExecutionParams() params.WorkerBuildID = workerBuildID taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -618,7 +642,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_ActivityTaskScheduled() { task := createWorkflowTask(testEvents[0:3], 0, "HelloWorld_Workflow") params := t.getTestWorkerExecutionParams() taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) @@ -629,7 +656,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_ActivityTaskScheduled() { // Schedule an activity and see if we complete workflow, Having only one last command. task = createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") - request, _, err = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response = request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -665,7 +695,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_QueryWorkflow_Sticky() { task := createWorkflowTask(testEvents[0:1], 0, "HelloWorld_Workflow") task.StartedEventId = 1 task.WorkflowExecution = execution - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -676,7 +709,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_QueryWorkflow_Sticky() { // then check the current state using query task task = createQueryTask([]*historypb.HistoryEvent{}, 6, "HelloWorld_Workflow", queryType) task.WorkflowExecution = execution - queryResp, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + queryResp, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NoError(err) t.verifyQueryResult(queryResp, "waiting-activity-result") } @@ -704,30 +740,45 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_QueryWorkflow_NonSticky() { // query after first workflow task (notice the previousStartEventID is always the last eventID for query task) task := createQueryTask(testEvents[0:3], 3, "HelloWorld_Workflow", queryType) taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.verifyQueryResult(response, "waiting-activity-result") // query after activity task complete but before second workflow task started task = createQueryTask(testEvents[0:7], 7, "HelloWorld_Workflow", queryType) taskHandler = newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.verifyQueryResult(response, "waiting-activity-result") // query after second workflow task task = createQueryTask(testEvents[0:8], 8, "HelloWorld_Workflow", queryType) taskHandler = newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.verifyQueryResult(response, "done") // query after second workflow task with extra events task = createQueryTask(testEvents[0:9], 9, "HelloWorld_Workflow", queryType) taskHandler = newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.verifyQueryResult(response, "done") task = createQueryTask(testEvents[0:9], 9, "HelloWorld_Workflow", "invalid-query-type") taskHandler = newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NotNil(response) queryResp, ok := response.(*workflowservice.RespondQueryTaskCompletedRequest) t.True(ok) @@ -769,8 +820,11 @@ func (t *TaskHandlersTestSuite) TestCacheEvictionWhenErrorOccurs() { task := createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() // will fail as it can't find laTunnel in newWorkerCache(). - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, make(chan struct{}), nil) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, make(chan struct{}), nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Error(err) t.Nil(request) @@ -791,21 +845,27 @@ func (t *TaskHandlersTestSuite) TestWithMissingHistoryEvents() { } params := t.getTestWorkerExecutionParams() params.WorkflowPanicPolicy = BlockWorkflow + t.Require().Equal(0, params.cache.getWorkflowCache().Size(), + "Suite teardown should have reset cache state") for _, startEventID := range []int64{0, 3} { - taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - task := createWorkflowTask(testEvents, startEventID, "HelloWorld_Workflow") - // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() - // will fail as it can't find laTunnel in newWorkerCache(). - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, make(chan struct{}), nil) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) - - t.Error(err) - t.Nil(request) - t.Contains(err.Error(), "missing history events") + t.Run(fmt.Sprintf("startEventID=%v", startEventID), func() { + taskHandler := newWorkflowTaskHandler(params, nil, t.registry) + task := createWorkflowTask(testEvents, startEventID, "HelloWorld_Workflow") + // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() + // will fail as it can't find laTunnel in newWorkerCache(). + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, make(chan struct{}), nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() + + t.Error(err) + t.Nil(request) + t.Contains(err.Error(), "missing history events") - // There should be nothing in the cache. - t.EqualValues(params.cache.getWorkflowCache().Size(), 0) + t.Equal(0, params.cache.getWorkflowCache().Size(), "cache should be empty") + }) } } @@ -848,8 +908,11 @@ func (t *TaskHandlersTestSuite) TestWithTruncatedHistory() { task.StartedEventId = tc.startedEventID // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() // will fail as it can't find laTunnel in newWorkerCache(). - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, make(chan struct{}), nil) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, make(chan struct{}), nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() if tc.isResultErr { t.Error(err, "testcase %v failed", i) @@ -908,7 +971,10 @@ func (t *TaskHandlersTestSuite) testSideEffectDeferHelper(cacheSize int) { taskHandler := newWorkflowTaskHandler(params, nil, t.registry) task := createWorkflowTask(testEvents, 0, workflowName) - _, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + _, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Nil(err) // Make sure the workflow coroutine has exited. @@ -940,7 +1006,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_NondeterministicDetection() { params.WorkerStopChannel = stopC taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) // there should be no error as the history events matched the commands. t.NoError(err) @@ -951,8 +1020,11 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_NondeterministicDetection() { task = createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() // will fail as it can't find laTunnel in newWorkerCache(). - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, stopC, nil) - request, _, err = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, stopC, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Error(err) t.Nil(request) t.Contains(err.Error(), "nondeterministic") @@ -962,7 +1034,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_NondeterministicDetection() { params.WorkflowPanicPolicy = FailWorkflow failOnNondeterminismTaskHandler := newWorkflowTaskHandler(params, nil, t.registry) task = createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") - request, _, err = failOnNondeterminismTaskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, failOnNondeterminismTaskHandler) + request, err = failOnNondeterminismTaskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() // When FailWorkflow policy is set, task handler does not return an error, // because it will indicate non determinism in the request. t.NoError(err) @@ -980,7 +1055,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_NondeterministicDetection() { // now with different package name to activity type testEvents[4].GetActivityTaskScheduledEventAttributes().ActivityType.Name = "new-package.Greeter_Activity" task = createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") - request, _, err = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NoError(err) t.NotNil(request) } @@ -997,7 +1075,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_WorkflowReturnsPanicError() { params.WorkflowPanicPolicy = BlockWorkflow taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NoError(err) t.NotNil(request) r, ok := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) @@ -1019,7 +1100,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_WorkflowPanics() { params.WorkflowPanicPolicy = BlockWorkflow taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - _, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + _, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Error(err) _, ok := err.(*workflowPanicError) t.True(ok) @@ -1063,7 +1147,10 @@ func (t *TaskHandlersTestSuite) TestGetWorkflowInfo() { params.WorkflowPanicPolicy = BlockWorkflow taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NoError(err) t.NotNil(request) r, ok := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) @@ -1097,9 +1184,12 @@ func (t *TaskHandlersTestSuite) TestConsistentQuery_InvalidQueryTask() { task := createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") task.Query = &querypb.WorkflowQuery{} task.Queries = map[string]*querypb.WorkflowQuery{"query_id": {}} - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, make(chan struct{}), nil) + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, make(chan struct{}), nil) // query and queries are both specified so this is an invalid task - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Error(err) t.Nil(request) @@ -1139,7 +1229,10 @@ func (t *TaskHandlersTestSuite) TestConsistentQuery_Success() { params := t.getTestWorkerExecutionParams() taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1159,7 +1252,10 @@ func (t *TaskHandlersTestSuite) TestConsistentQuery_Success() { secondTask := createWorkflowTaskWithQueries(testEvents, 3, "QuerySignalWorkflow", queries, false) secondTask.WorkflowExecution.RunId = task.WorkflowExecution.RunId - request, _, err = taskHandler.ProcessWorkflowTask(&workflowTask{task: secondTask}, nil) + wftask = workflowTask{task: secondTask} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response = request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1182,10 +1278,12 @@ func (t *TaskHandlersTestSuite) TestConsistentQuery_Success() { } func (t *TaskHandlersTestSuite) assertQueryResultsEqual(expected map[string]*querypb.WorkflowQueryResult, actual map[string]*querypb.WorkflowQueryResult) { + t.T().Helper() t.Equal(len(expected), len(actual)) for expectedID, expectedResult := range expected { t.Contains(actual, expectedID) - t.True(proto.Equal(expectedResult, actual[expectedID])) + t.True(proto.Equal(expectedResult, actual[expectedID]), + "expected %v = %v", expectedResult, actual[expectedID]) } } @@ -1200,7 +1298,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_CancelActivityBeforeSent() { params := t.getTestWorkerExecutionParams() taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1232,7 +1333,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_PageToken() { }, } taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task, historyIterator: historyIterator}, nil) + wftask := workflowTask{task: task, historyIterator: historyIterator} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1346,9 +1450,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_Messages() { }, } taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{ - task: task, historyIterator: historyIterator, - }, nil) + wftask := workflowTask{task: task, historyIterator: historyIterator} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1426,13 +1531,9 @@ func (t *TaskHandlersTestSuite) TestLocalActivityRetry_Workflow() { laResultCh := make(chan *localActivityResult) laRetryCh := make(chan *localActivityTask) - response, _, err := taskHandler.ProcessWorkflowTask( - &workflowTask{ - task: task, - laResultCh: laResultCh, - laRetryCh: laRetryCh, - }, - nil) + wftask := workflowTask{task: task, laResultCh: laResultCh, laRetryCh: laRetryCh} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + response, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) t.NotNil(response) t.NoError(err) asWFTComplete := response.(*workflowservice.RespondWorkflowTaskCompletedRequest) @@ -1511,14 +1612,15 @@ func (t *TaskHandlersTestSuite) TestLocalActivityRetry_WorkflowTaskHeartbeatFail }() laResultCh := make(chan *localActivityResult) - response, _, err := taskHandler.ProcessWorkflowTask( - &workflowTask{ - task: task, - laResultCh: laResultCh, - }, + wftask := workflowTask{task: task, laResultCh: laResultCh} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + response, err := taskHandler.ProcessWorkflowTask( + &wftask, + wfctx, func(response interface{}, startTime time.Time) (*workflowTask, error) { return nil, serviceerror.NewNotFound("Intentional wft heartbeat error") }) + wfctx.Unlock() t.Nil(response) t.Error(err) @@ -2242,7 +2344,7 @@ func TestResetIfDestroyedTaskPrep(t *testing.T) { require.EqualValues(t, 0, cache.Size()) // cache is empty so this should miss and build a new context with a // full history - _, err := weci.wth.getOrCreateWorkflowContext(task, histIter) + _, err := weci.wth.GetOrCreateWorkflowContext(task, histIter) require.NoError(t, err) require.Len(t, task.History.Events, len(fullHist.Events), diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go index fb77ab1a3..6302296f4 100644 --- a/internal/internal_task_pollers.go +++ b/internal/internal_task_pollers.go @@ -94,6 +94,7 @@ type ( identity string service workflowservice.WorkflowServiceClient taskHandler WorkflowTaskHandler + contextManager WorkflowContextManager logger log.Logger dataConverter converter.DataConverter failureConverter converter.FailureConverter @@ -265,6 +266,7 @@ func (bp *basePoller) getCapabilities() *workflowservice.GetSystemInfoResponse_C // newWorkflowTaskPoller creates a new workflow task poller which must have a one to one relationship to workflow worker func newWorkflowTaskPoller( taskHandler WorkflowTaskHandler, + contextManager WorkflowContextManager, service workflowservice.WorkflowServiceClient, params workerExecutionParameters, ) *workflowTaskPoller { @@ -281,6 +283,7 @@ func newWorkflowTaskPoller( taskQueueName: params.TaskQueue, identity: params.Identity, taskHandler: taskHandler, + contextManager: contextManager, logger: params.Logger, dataConverter: params.DataConverter, failureConverter: params.FailureConverter, @@ -335,43 +338,52 @@ func (wtp *workflowTaskPoller) processWorkflowTask(task *workflowTask) error { // close doneCh so local activity worker won't get blocked forever when trying to send back result to laResultCh. defer close(doneCh) + wfctx, err := wtp.contextManager.GetOrCreateWorkflowContext(task.task, task.historyIterator) + if err != nil { + return err + } + defer wfctx.Unlock() + for { - var response *workflowservice.RespondWorkflowTaskCompletedResponse startTime := time.Now() - task.doneCh = doneCh - task.laResultCh = laResultCh - task.laRetryCh = laRetryCh - completedRequest, resetter, err := wtp.taskHandler.ProcessWorkflowTask( - task, - func(response interface{}, startTime time.Time) (*workflowTask, error) { - wtp.logger.Debug("Force RespondWorkflowTaskCompleted.", "TaskStartedEventID", task.task.GetStartedEventId()) - heartbeatResponse, err := wtp.RespondTaskCompletedWithMetrics(response, nil, task.task, startTime) - if err != nil { - return nil, err - } - if heartbeatResponse == nil || heartbeatResponse.WorkflowTask == nil { - return nil, nil - } - task := wtp.toWorkflowTask(heartbeatResponse.WorkflowTask) - task.doneCh = doneCh - task.laResultCh = laResultCh - task.laRetryCh = laRetryCh - return task, nil - }, - ) + completedRequest, err := func() (_ interface{}, retErr error) { + defer func() { wfctx.ErrorCleanup(retErr) }() + task.doneCh = doneCh + task.laResultCh = laResultCh + task.laRetryCh = laRetryCh + return wtp.taskHandler.ProcessWorkflowTask( + task, + wfctx, + func(response interface{}, startTime time.Time) (*workflowTask, error) { + wtp.logger.Debug("Force RespondWorkflowTaskCompleted.", "TaskStartedEventID", task.task.GetStartedEventId()) + heartbeatResponse, err := wtp.RespondTaskCompletedWithMetrics(response, nil, task.task, startTime) + if err != nil { + return nil, err + } + if heartbeatResponse == nil || heartbeatResponse.WorkflowTask == nil { + return nil, nil + } + task := wtp.toWorkflowTask(heartbeatResponse.WorkflowTask) + task.doneCh = doneCh + task.laResultCh = laResultCh + task.laRetryCh = laRetryCh + return task, nil + }, + ) + }() if completedRequest == nil && err == nil { return nil } if _, ok := err.(workflowTaskHeartbeatError); ok { return err } - response, err = wtp.RespondTaskCompletedWithMetrics(completedRequest, err, task.task, startTime) + response, err := wtp.RespondTaskCompletedWithMetrics(completedRequest, err, task.task, startTime) if err != nil { return err } if eventLevel := response.GetResetHistoryEventId(); eventLevel != 0 { - resetter(eventLevel) + wfctx.SetPreviousStartedEventID(eventLevel) } if response == nil || response.WorkflowTask == nil { diff --git a/internal/internal_worker.go b/internal/internal_worker.go index e77b65203..bb47760b1 100644 --- a/internal/internal_worker.go +++ b/internal/internal_worker.go @@ -290,18 +290,19 @@ func newWorkflowWorkerInternal(service workflowservice.WorkflowServiceClient, pa } else { taskHandler = newWorkflowTaskHandler(params, ppMgr, registry) } - return newWorkflowTaskWorkerInternal(taskHandler, service, params, workerStopChannel, registry.interceptors) + return newWorkflowTaskWorkerInternal(taskHandler, taskHandler, service, params, workerStopChannel, registry.interceptors) } func newWorkflowTaskWorkerInternal( taskHandler WorkflowTaskHandler, + contextManager WorkflowContextManager, service workflowservice.WorkflowServiceClient, params workerExecutionParameters, stopC chan struct{}, interceptors []WorkerInterceptor, ) *workflowWorker { ensureRequiredParams(¶ms) - poller := newWorkflowTaskPoller(taskHandler, service, params) + poller := newWorkflowTaskPoller(taskHandler, contextManager, service, params) worker := newBaseWorker(baseWorkerOptions{ pollerCount: params.MaxConcurrentWorkflowTaskQueuePollers, pollerRate: defaultPollerRate, @@ -1367,7 +1368,12 @@ func (aw *WorkflowReplayer) replayWorkflowHistory(logger log.Logger, service wor }, } taskHandler := newWorkflowTaskHandler(params, nil, aw.registry) - resp, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task, historyIterator: iterator}, nil) + wfctx, err := taskHandler.GetOrCreateWorkflowContext(task, iterator) + defer wfctx.Unlock() + if err != nil { + return err + } + resp, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task, historyIterator: iterator}, wfctx, nil) if err != nil { return err } diff --git a/internal/internal_worker_test.go b/internal/internal_worker_test.go index 215d3d709..b8e868049 100644 --- a/internal/internal_worker_test.go +++ b/internal/internal_worker_test.go @@ -1623,7 +1623,10 @@ func (s *internalWorkerTestSuite) testWorkflowTaskHandlerHelper(params workerExe } r := newWorkflowTaskHandler(params, nil, s.registry) - _, _, err := r.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wfctx, err := r.GetOrCreateWorkflowContext(task, nil) + s.NoError(err) + _, err = r.ProcessWorkflowTask(&workflowTask{task: task}, wfctx, nil) + wfctx.Unlock() s.NoError(err) }