diff --git a/internal/internal_public.go b/internal/internal_public.go index a990d5dd1..03490c465 100644 --- a/internal/internal_public.go +++ b/internal/internal_public.go @@ -78,6 +78,8 @@ type ( // WorkflowTaskHandler represents workflow task handlers. WorkflowTaskHandler interface { + WorkflowContextManager + // Processes the workflow task // The response could be: // - RespondWorkflowTaskCompletedRequest @@ -85,8 +87,16 @@ type ( // - RespondQueryTaskCompletedRequest ProcessWorkflowTask( task *workflowTask, + ctx *workflowExecutionContextImpl, f workflowTaskHeartbeatFunc, - ) (response interface{}, resetter EventLevelResetter, err error) + ) (response interface{}, err error) + } + + WorkflowContextManager interface { + GetOrCreateWorkflowContext( + task *workflowservice.PollWorkflowTaskQueueResponse, + historyIterator HistoryIterator, + ) (*workflowExecutionContextImpl, error) } // ActivityTaskHandler represents activity task handlers. diff --git a/internal/internal_task_handlers.go b/internal/internal_task_handlers.go index 6c9be3b63..719c8b146 100644 --- a/internal/internal_task_handlers.go +++ b/internal/internal_task_handlers.go @@ -474,11 +474,16 @@ func newWorkflowExecutionContext( return workflowContext } +// Lock acquires the lock on this context object, use Unlock to release the +// lock. func (w *workflowExecutionContextImpl) Lock() { w.mutex.Lock() } -func (w *workflowExecutionContextImpl) Unlock(err error) { +// ErrorCleanup clears this context's state and removes it from cache as +// necessary if the supplied error is not nil or if this context's internal +// error field has been set or if the workflow run has completed. +func (w *workflowExecutionContextImpl) ErrorCleanup(err error) { if err != nil || w.err != nil || w.isWorkflowCompleted || (w.wth.cache.MaxWorkflowCacheSize() <= 0 && !w.hasPendingLocalActivityWork()) { // TODO: in case of closed, it asumes the close command always succeed. need server side change to return @@ -496,7 +501,14 @@ func (w *workflowExecutionContextImpl) Unlock(err error) { // exited w.clearState() } +} +// Unlock performs any necessary error cleanup that might be needed due to +// workflow completion or context errors and then releases the lock on this +// context. This implementation is _not_ idempotent so it should only be called +// once and only if the context lock is held by the calling goroutine. +func (w *workflowExecutionContextImpl) Unlock() { + w.ErrorCleanup(nil) w.mutex.Unlock() } @@ -631,7 +643,11 @@ func (wth *workflowTaskHandlerImpl) createWorkflowContext(task *workflowservice. return newWorkflowExecutionContext(workflowInfo, wth), nil } -func (wth *workflowTaskHandlerImpl) getOrCreateWorkflowContext( +// GetOrCreateWorkflowContext finds an existing cached context object for this +// run ID or creates a new object, adds it to cache, and returns it. In all +// non-error cases the returned context object is in a locked state (i.e. +// WorkflowContext.Lock() has been called). +func (wth *workflowTaskHandlerImpl) GetOrCreateWorkflowContext( task *workflowservice.PollWorkflowTaskQueueResponse, historyIterator HistoryIterator, ) (workflowContext *workflowExecutionContextImpl, err error) { @@ -681,10 +697,12 @@ func (wth *workflowTaskHandlerImpl) getOrCreateWorkflowContext( wth.cache.removeWorkflowContext(runID) workflowContext.clearState() } - workflowContext.Unlock(err) + workflowContext.ErrorCleanup(err) + workflowContext.Unlock() workflowContext = nil } } + // If the workflow was not cached or the cache was stale. if workflowContext == nil { if !isFullHistory { @@ -711,7 +729,8 @@ func (wth *workflowTaskHandlerImpl) getOrCreateWorkflowContext( err = workflowContext.resetStateIfDestroyed(task, historyIterator) if err != nil { - workflowContext.Unlock(err) + workflowContext.ErrorCleanup(err) + workflowContext.Unlock() } return @@ -756,10 +775,11 @@ func (w *workflowExecutionContextImpl) resetStateIfDestroyed(task *workflowservi // ProcessWorkflowTask processes all the events of the workflow task. func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask( workflowTask *workflowTask, + workflowContext *workflowExecutionContextImpl, heartbeatFunc workflowTaskHeartbeatFunc, -) (completeRequest interface{}, resetter EventLevelResetter, errRet error) { +) (completeRequest interface{}, errRet error) { if workflowTask == nil || workflowTask.task == nil { - return nil, nil, errors.New("nil workflow task provided") + return nil, errors.New("nil workflow task provided") } task := workflowTask.task if task.History == nil || len(task.History.Events) == 0 { @@ -768,11 +788,11 @@ func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask( } } if task.Query == nil && len(task.History.Events) == 0 { - return nil, nil, errors.New("nil or empty history") + return nil, errors.New("nil or empty history") } if task.Query != nil && len(task.Queries) != 0 { - return nil, nil, errors.New("invalid query workflow task") + return nil, errors.New("invalid query workflow task") } runID := task.WorkflowExecution.GetRunId() @@ -786,19 +806,14 @@ func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask( tagPreviousStartedEventID, task.GetPreviousStartedEventId()) }) - workflowContext, err := wth.getOrCreateWorkflowContext(task, workflowTask.historyIterator) - if err != nil { - return nil, nil, err - } - - defer func() { - workflowContext.Unlock(errRet) - }() - - var response interface{} + var ( + response interface{} + err error + heartbeatTimer *time.Timer + ) - var heartbeatTimer *time.Timer defer func() { + workflowContext.ErrorCleanup(errRet) if heartbeatTimer != nil { heartbeatTimer.Stop() } @@ -882,7 +897,6 @@ processWorkflowLoop: } errRet = err completeRequest = response - resetter = workflowContext.SetPreviousStartedEventID return } @@ -1250,8 +1264,6 @@ func (w *workflowExecutionContextImpl) SetCurrentTask(task *workflowservice.Poll } func (w *workflowExecutionContextImpl) SetPreviousStartedEventID(eventID int64) { - w.mutex.Lock() // This call can race against the cache eviction thread - see clearState - defer w.mutex.Unlock() w.previousStartedEventID = eventID } diff --git a/internal/internal_task_handlers_interfaces_test.go b/internal/internal_task_handlers_interfaces_test.go index e43578d9a..be9d9c723 100644 --- a/internal/internal_task_handlers_interfaces_test.go +++ b/internal/internal_task_handlers_interfaces_test.go @@ -53,11 +53,19 @@ type sampleWorkflowTaskHandler struct{} func (wth sampleWorkflowTaskHandler) ProcessWorkflowTask( workflowTask *workflowTask, + _ *workflowExecutionContextImpl, _ workflowTaskHeartbeatFunc, -) (interface{}, EventLevelResetter, error) { +) (interface{}, error) { return &workflowservice.RespondWorkflowTaskCompletedRequest{ TaskToken: workflowTask.task.TaskToken, - }, nil, nil + }, nil +} + +func (wth sampleWorkflowTaskHandler) GetOrCreateWorkflowContext( + task *workflowservice.PollWorkflowTaskQueueResponse, + historyIterator HistoryIterator, +) (*workflowExecutionContextImpl, error) { + return nil, nil } func newSampleWorkflowTaskHandler() *sampleWorkflowTaskHandler { @@ -115,7 +123,7 @@ func (s *PollLayerInterfacesTestSuite) TestProcessWorkflowTaskInterface() { // Process task and respond to the service. taskHandler := newSampleWorkflowTaskHandler() - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: response}, nil) + request, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: response}, nil, nil) completionRequest := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) s.NoError(err) diff --git a/internal/internal_task_handlers_test.go b/internal/internal_task_handlers_test.go index 9436ec06d..5b671a782 100644 --- a/internal/internal_task_handlers_test.go +++ b/internal/internal_task_handlers_test.go @@ -143,6 +143,12 @@ func (t *TaskHandlersTestSuite) SetupSuite() { t.namespace = "default" } +func (t *TaskHandlersTestSuite) TearDownTest() { + if cache := *sharedWorkerCachePtr.workflowCache; cache != nil { + cache.Clear() + } +} + func TestTaskHandlersTestSuite(t *testing.T) { suite.Run(t, &TaskHandlersTestSuite{ registry: newRegistry(), @@ -514,13 +520,25 @@ func (t *TaskHandlersTestSuite) getTestWorkerExecutionParams() workerExecutionPa } } +func (t *TaskHandlersTestSuite) mustWorkflowContextImpl( + task *workflowTask, + cm WorkflowContextManager, +) *workflowExecutionContextImpl { + wfctx, err := cm.GetOrCreateWorkflowContext(task.task, task.historyIterator) + t.Require().NoError(err) + return wfctx +} + func (t *TaskHandlersTestSuite) testWorkflowTaskWorkflowExecutionStartedHelper(params workerExecutionParameters) { testEvents := []*historypb.HistoryEvent{ createTestEventWorkflowExecutionStarted(1, &historypb.WorkflowExecutionStartedEventAttributes{TaskQueue: &taskqueuepb.TaskQueue{Name: testWorkflowTaskTaskqueue}}), } task := createWorkflowTask(testEvents, 0, "HelloWorld_Workflow") taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -561,7 +579,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_BinaryChecksum() { task := createWorkflowTask(testEvents, 8, "BinaryChecksumWorkflow") params := t.getTestWorkerExecutionParams() taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) @@ -589,7 +610,10 @@ func (t *TaskHandlersTestSuite) TestRespondsToWFTWithWorkerBinaryID() { params := t.getTestWorkerExecutionParams() params.WorkerBuildID = workerBuildID taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -618,7 +642,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_ActivityTaskScheduled() { task := createWorkflowTask(testEvents[0:3], 0, "HelloWorld_Workflow") params := t.getTestWorkerExecutionParams() taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) @@ -629,7 +656,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_ActivityTaskScheduled() { // Schedule an activity and see if we complete workflow, Having only one last command. task = createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") - request, _, err = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response = request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -665,7 +695,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_QueryWorkflow_Sticky() { task := createWorkflowTask(testEvents[0:1], 0, "HelloWorld_Workflow") task.StartedEventId = 1 task.WorkflowExecution = execution - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -676,7 +709,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_QueryWorkflow_Sticky() { // then check the current state using query task task = createQueryTask([]*historypb.HistoryEvent{}, 6, "HelloWorld_Workflow", queryType) task.WorkflowExecution = execution - queryResp, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + queryResp, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NoError(err) t.verifyQueryResult(queryResp, "waiting-activity-result") } @@ -704,30 +740,45 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_QueryWorkflow_NonSticky() { // query after first workflow task (notice the previousStartEventID is always the last eventID for query task) task := createQueryTask(testEvents[0:3], 3, "HelloWorld_Workflow", queryType) taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.verifyQueryResult(response, "waiting-activity-result") // query after activity task complete but before second workflow task started task = createQueryTask(testEvents[0:7], 7, "HelloWorld_Workflow", queryType) taskHandler = newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.verifyQueryResult(response, "waiting-activity-result") // query after second workflow task task = createQueryTask(testEvents[0:8], 8, "HelloWorld_Workflow", queryType) taskHandler = newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.verifyQueryResult(response, "done") // query after second workflow task with extra events task = createQueryTask(testEvents[0:9], 9, "HelloWorld_Workflow", queryType) taskHandler = newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.verifyQueryResult(response, "done") task = createQueryTask(testEvents[0:9], 9, "HelloWorld_Workflow", "invalid-query-type") taskHandler = newWorkflowTaskHandler(params, nil, t.registry) - response, _, _ = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + response, _ = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NotNil(response) queryResp, ok := response.(*workflowservice.RespondQueryTaskCompletedRequest) t.True(ok) @@ -769,8 +820,11 @@ func (t *TaskHandlersTestSuite) TestCacheEvictionWhenErrorOccurs() { task := createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() // will fail as it can't find laTunnel in newWorkerCache(). - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, make(chan struct{}), nil) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, make(chan struct{}), nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Error(err) t.Nil(request) @@ -791,21 +845,27 @@ func (t *TaskHandlersTestSuite) TestWithMissingHistoryEvents() { } params := t.getTestWorkerExecutionParams() params.WorkflowPanicPolicy = BlockWorkflow + t.Require().Equal(0, params.cache.getWorkflowCache().Size(), + "Suite teardown should have reset cache state") for _, startEventID := range []int64{0, 3} { - taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - task := createWorkflowTask(testEvents, startEventID, "HelloWorld_Workflow") - // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() - // will fail as it can't find laTunnel in newWorkerCache(). - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, make(chan struct{}), nil) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) - - t.Error(err) - t.Nil(request) - t.Contains(err.Error(), "missing history events") + t.Run(fmt.Sprintf("startEventID=%v", startEventID), func() { + taskHandler := newWorkflowTaskHandler(params, nil, t.registry) + task := createWorkflowTask(testEvents, startEventID, "HelloWorld_Workflow") + // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() + // will fail as it can't find laTunnel in newWorkerCache(). + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, make(chan struct{}), nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() + + t.Error(err) + t.Nil(request) + t.Contains(err.Error(), "missing history events") - // There should be nothing in the cache. - t.EqualValues(params.cache.getWorkflowCache().Size(), 0) + t.Equal(0, params.cache.getWorkflowCache().Size(), "cache should be empty") + }) } } @@ -848,8 +908,11 @@ func (t *TaskHandlersTestSuite) TestWithTruncatedHistory() { task.StartedEventId = tc.startedEventID // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() // will fail as it can't find laTunnel in newWorkerCache(). - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, make(chan struct{}), nil) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, make(chan struct{}), nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() if tc.isResultErr { t.Error(err, "testcase %v failed", i) @@ -908,7 +971,10 @@ func (t *TaskHandlersTestSuite) testSideEffectDeferHelper(cacheSize int) { taskHandler := newWorkflowTaskHandler(params, nil, t.registry) task := createWorkflowTask(testEvents, 0, workflowName) - _, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + _, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Nil(err) // Make sure the workflow coroutine has exited. @@ -940,7 +1006,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_NondeterministicDetection() { params.WorkerStopChannel = stopC taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) // there should be no error as the history events matched the commands. t.NoError(err) @@ -951,8 +1020,11 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_NondeterministicDetection() { task = createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") // newWorkflowTaskWorkerInternal will set the laTunnel in taskHandler, without it, ProcessWorkflowTask() // will fail as it can't find laTunnel in newWorkerCache(). - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, stopC, nil) - request, _, err = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, stopC, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Error(err) t.Nil(request) t.Contains(err.Error(), "nondeterministic") @@ -962,7 +1034,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_NondeterministicDetection() { params.WorkflowPanicPolicy = FailWorkflow failOnNondeterminismTaskHandler := newWorkflowTaskHandler(params, nil, t.registry) task = createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") - request, _, err = failOnNondeterminismTaskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, failOnNondeterminismTaskHandler) + request, err = failOnNondeterminismTaskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() // When FailWorkflow policy is set, task handler does not return an error, // because it will indicate non determinism in the request. t.NoError(err) @@ -980,7 +1055,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_NondeterministicDetection() { // now with different package name to activity type testEvents[4].GetActivityTaskScheduledEventAttributes().ActivityType.Name = "new-package.Greeter_Activity" task = createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") - request, _, err = taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask = workflowTask{task: task} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NoError(err) t.NotNil(request) } @@ -997,7 +1075,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_WorkflowReturnsPanicError() { params.WorkflowPanicPolicy = BlockWorkflow taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NoError(err) t.NotNil(request) r, ok := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) @@ -1019,7 +1100,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_WorkflowPanics() { params.WorkflowPanicPolicy = BlockWorkflow taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - _, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + _, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Error(err) _, ok := err.(*workflowPanicError) t.True(ok) @@ -1063,7 +1147,10 @@ func (t *TaskHandlersTestSuite) TestGetWorkflowInfo() { params.WorkflowPanicPolicy = BlockWorkflow taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.NoError(err) t.NotNil(request) r, ok := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) @@ -1097,9 +1184,12 @@ func (t *TaskHandlersTestSuite) TestConsistentQuery_InvalidQueryTask() { task := createWorkflowTask(testEvents, 3, "HelloWorld_Workflow") task.Query = &querypb.WorkflowQuery{} task.Queries = map[string]*querypb.WorkflowQuery{"query_id": {}} - newWorkflowTaskWorkerInternal(taskHandler, t.service, params, make(chan struct{}), nil) + newWorkflowTaskWorkerInternal(taskHandler, taskHandler, t.service, params, make(chan struct{}), nil) // query and queries are both specified so this is an invalid task - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() t.Error(err) t.Nil(request) @@ -1139,7 +1229,10 @@ func (t *TaskHandlersTestSuite) TestConsistentQuery_Success() { params := t.getTestWorkerExecutionParams() taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1159,7 +1252,10 @@ func (t *TaskHandlersTestSuite) TestConsistentQuery_Success() { secondTask := createWorkflowTaskWithQueries(testEvents, 3, "QuerySignalWorkflow", queries, false) secondTask.WorkflowExecution.RunId = task.WorkflowExecution.RunId - request, _, err = taskHandler.ProcessWorkflowTask(&workflowTask{task: secondTask}, nil) + wftask = workflowTask{task: secondTask} + wfctx = t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err = taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response = request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1182,10 +1278,12 @@ func (t *TaskHandlersTestSuite) TestConsistentQuery_Success() { } func (t *TaskHandlersTestSuite) assertQueryResultsEqual(expected map[string]*querypb.WorkflowQueryResult, actual map[string]*querypb.WorkflowQueryResult) { + t.T().Helper() t.Equal(len(expected), len(actual)) for expectedID, expectedResult := range expected { t.Contains(actual, expectedID) - t.True(proto.Equal(expectedResult, actual[expectedID])) + t.True(proto.Equal(expectedResult, actual[expectedID]), + "expected %v = %v", expectedResult, actual[expectedID]) } } @@ -1200,7 +1298,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_CancelActivityBeforeSent() { params := t.getTestWorkerExecutionParams() taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wftask := workflowTask{task: task} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1232,7 +1333,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_PageToken() { }, } taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task, historyIterator: historyIterator}, nil) + wftask := workflowTask{task: task, historyIterator: historyIterator} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1346,9 +1450,10 @@ func (t *TaskHandlersTestSuite) TestWorkflowTask_Messages() { }, } taskHandler := newWorkflowTaskHandler(params, nil, t.registry) - request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{ - task: task, historyIterator: historyIterator, - }, nil) + wftask := workflowTask{task: task, historyIterator: historyIterator} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + request, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock() response := request.(*workflowservice.RespondWorkflowTaskCompletedRequest) t.NoError(err) t.NotNil(response) @@ -1426,13 +1531,9 @@ func (t *TaskHandlersTestSuite) TestLocalActivityRetry_Workflow() { laResultCh := make(chan *localActivityResult) laRetryCh := make(chan *localActivityTask) - response, _, err := taskHandler.ProcessWorkflowTask( - &workflowTask{ - task: task, - laResultCh: laResultCh, - laRetryCh: laRetryCh, - }, - nil) + wftask := workflowTask{task: task, laResultCh: laResultCh, laRetryCh: laRetryCh} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + response, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) t.NotNil(response) t.NoError(err) asWFTComplete := response.(*workflowservice.RespondWorkflowTaskCompletedRequest) @@ -1511,14 +1612,15 @@ func (t *TaskHandlersTestSuite) TestLocalActivityRetry_WorkflowTaskHeartbeatFail }() laResultCh := make(chan *localActivityResult) - response, _, err := taskHandler.ProcessWorkflowTask( - &workflowTask{ - task: task, - laResultCh: laResultCh, - }, + wftask := workflowTask{task: task, laResultCh: laResultCh} + wfctx := t.mustWorkflowContextImpl(&wftask, taskHandler) + response, err := taskHandler.ProcessWorkflowTask( + &wftask, + wfctx, func(response interface{}, startTime time.Time) (*workflowTask, error) { return nil, serviceerror.NewNotFound("Intentional wft heartbeat error") }) + wfctx.Unlock() t.Nil(response) t.Error(err) @@ -2242,7 +2344,7 @@ func TestResetIfDestroyedTaskPrep(t *testing.T) { require.EqualValues(t, 0, cache.Size()) // cache is empty so this should miss and build a new context with a // full history - _, err := weci.wth.getOrCreateWorkflowContext(task, histIter) + _, err := weci.wth.GetOrCreateWorkflowContext(task, histIter) require.NoError(t, err) require.Len(t, task.History.Events, len(fullHist.Events), diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go index fb77ab1a3..6302296f4 100644 --- a/internal/internal_task_pollers.go +++ b/internal/internal_task_pollers.go @@ -94,6 +94,7 @@ type ( identity string service workflowservice.WorkflowServiceClient taskHandler WorkflowTaskHandler + contextManager WorkflowContextManager logger log.Logger dataConverter converter.DataConverter failureConverter converter.FailureConverter @@ -265,6 +266,7 @@ func (bp *basePoller) getCapabilities() *workflowservice.GetSystemInfoResponse_C // newWorkflowTaskPoller creates a new workflow task poller which must have a one to one relationship to workflow worker func newWorkflowTaskPoller( taskHandler WorkflowTaskHandler, + contextManager WorkflowContextManager, service workflowservice.WorkflowServiceClient, params workerExecutionParameters, ) *workflowTaskPoller { @@ -281,6 +283,7 @@ func newWorkflowTaskPoller( taskQueueName: params.TaskQueue, identity: params.Identity, taskHandler: taskHandler, + contextManager: contextManager, logger: params.Logger, dataConverter: params.DataConverter, failureConverter: params.FailureConverter, @@ -335,43 +338,52 @@ func (wtp *workflowTaskPoller) processWorkflowTask(task *workflowTask) error { // close doneCh so local activity worker won't get blocked forever when trying to send back result to laResultCh. defer close(doneCh) + wfctx, err := wtp.contextManager.GetOrCreateWorkflowContext(task.task, task.historyIterator) + if err != nil { + return err + } + defer wfctx.Unlock() + for { - var response *workflowservice.RespondWorkflowTaskCompletedResponse startTime := time.Now() - task.doneCh = doneCh - task.laResultCh = laResultCh - task.laRetryCh = laRetryCh - completedRequest, resetter, err := wtp.taskHandler.ProcessWorkflowTask( - task, - func(response interface{}, startTime time.Time) (*workflowTask, error) { - wtp.logger.Debug("Force RespondWorkflowTaskCompleted.", "TaskStartedEventID", task.task.GetStartedEventId()) - heartbeatResponse, err := wtp.RespondTaskCompletedWithMetrics(response, nil, task.task, startTime) - if err != nil { - return nil, err - } - if heartbeatResponse == nil || heartbeatResponse.WorkflowTask == nil { - return nil, nil - } - task := wtp.toWorkflowTask(heartbeatResponse.WorkflowTask) - task.doneCh = doneCh - task.laResultCh = laResultCh - task.laRetryCh = laRetryCh - return task, nil - }, - ) + completedRequest, err := func() (_ interface{}, retErr error) { + defer func() { wfctx.ErrorCleanup(retErr) }() + task.doneCh = doneCh + task.laResultCh = laResultCh + task.laRetryCh = laRetryCh + return wtp.taskHandler.ProcessWorkflowTask( + task, + wfctx, + func(response interface{}, startTime time.Time) (*workflowTask, error) { + wtp.logger.Debug("Force RespondWorkflowTaskCompleted.", "TaskStartedEventID", task.task.GetStartedEventId()) + heartbeatResponse, err := wtp.RespondTaskCompletedWithMetrics(response, nil, task.task, startTime) + if err != nil { + return nil, err + } + if heartbeatResponse == nil || heartbeatResponse.WorkflowTask == nil { + return nil, nil + } + task := wtp.toWorkflowTask(heartbeatResponse.WorkflowTask) + task.doneCh = doneCh + task.laResultCh = laResultCh + task.laRetryCh = laRetryCh + return task, nil + }, + ) + }() if completedRequest == nil && err == nil { return nil } if _, ok := err.(workflowTaskHeartbeatError); ok { return err } - response, err = wtp.RespondTaskCompletedWithMetrics(completedRequest, err, task.task, startTime) + response, err := wtp.RespondTaskCompletedWithMetrics(completedRequest, err, task.task, startTime) if err != nil { return err } if eventLevel := response.GetResetHistoryEventId(); eventLevel != 0 { - resetter(eventLevel) + wfctx.SetPreviousStartedEventID(eventLevel) } if response == nil || response.WorkflowTask == nil { diff --git a/internal/internal_worker.go b/internal/internal_worker.go index e77b65203..bb47760b1 100644 --- a/internal/internal_worker.go +++ b/internal/internal_worker.go @@ -290,18 +290,19 @@ func newWorkflowWorkerInternal(service workflowservice.WorkflowServiceClient, pa } else { taskHandler = newWorkflowTaskHandler(params, ppMgr, registry) } - return newWorkflowTaskWorkerInternal(taskHandler, service, params, workerStopChannel, registry.interceptors) + return newWorkflowTaskWorkerInternal(taskHandler, taskHandler, service, params, workerStopChannel, registry.interceptors) } func newWorkflowTaskWorkerInternal( taskHandler WorkflowTaskHandler, + contextManager WorkflowContextManager, service workflowservice.WorkflowServiceClient, params workerExecutionParameters, stopC chan struct{}, interceptors []WorkerInterceptor, ) *workflowWorker { ensureRequiredParams(¶ms) - poller := newWorkflowTaskPoller(taskHandler, service, params) + poller := newWorkflowTaskPoller(taskHandler, contextManager, service, params) worker := newBaseWorker(baseWorkerOptions{ pollerCount: params.MaxConcurrentWorkflowTaskQueuePollers, pollerRate: defaultPollerRate, @@ -1367,7 +1368,12 @@ func (aw *WorkflowReplayer) replayWorkflowHistory(logger log.Logger, service wor }, } taskHandler := newWorkflowTaskHandler(params, nil, aw.registry) - resp, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task, historyIterator: iterator}, nil) + wfctx, err := taskHandler.GetOrCreateWorkflowContext(task, iterator) + defer wfctx.Unlock() + if err != nil { + return err + } + resp, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: task, historyIterator: iterator}, wfctx, nil) if err != nil { return err } diff --git a/internal/internal_worker_test.go b/internal/internal_worker_test.go index 215d3d709..b8e868049 100644 --- a/internal/internal_worker_test.go +++ b/internal/internal_worker_test.go @@ -1623,7 +1623,10 @@ func (s *internalWorkerTestSuite) testWorkflowTaskHandlerHelper(params workerExe } r := newWorkflowTaskHandler(params, nil, s.registry) - _, _, err := r.ProcessWorkflowTask(&workflowTask{task: task}, nil) + wfctx, err := r.GetOrCreateWorkflowContext(task, nil) + s.NoError(err) + _, err = r.ProcessWorkflowTask(&workflowTask{task: task}, wfctx, nil) + wfctx.Unlock() s.NoError(err) }