From a649812c641b7833a2e7bef4619046ed1871beae Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Mon, 24 Jul 2023 07:58:58 -0700 Subject: [PATCH] Make QueryWorkflowWithOptions go through interceptor (#1171) --- internal/interceptor.go | 10 +-- internal/internal_workflow_client.go | 95 ++++++++++++++++------------ test/integration_test.go | 11 ++++ 3 files changed, 73 insertions(+), 43 deletions(-) diff --git a/internal/interceptor.go b/internal/interceptor.go index e48dc0af6..b0b008fa0 100644 --- a/internal/interceptor.go +++ b/internal/interceptor.go @@ -27,6 +27,7 @@ import ( "time" commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" updatepb "go.temporal.io/api/update/v1" "go.temporal.io/sdk/converter" "go.temporal.io/sdk/internal/common/metrics" @@ -405,8 +406,9 @@ type ClientTerminateWorkflowInput struct { // ClientQueryWorkflowInput is the input to // ClientOutboundInterceptor.QueryWorkflow. type ClientQueryWorkflowInput struct { - WorkflowID string - RunID string - QueryType string - Args []interface{} + WorkflowID string + RunID string + QueryType string + Args []interface{} + QueryRejectCondition enumspb.QueryRejectCondition } diff --git a/internal/internal_workflow_client.go b/internal/internal_workflow_client.go index 8affbbf23..776d0b23a 100644 --- a/internal/internal_workflow_client.go +++ b/internal/internal_workflow_client.go @@ -194,6 +194,11 @@ type ( // func which use a next token to get next page of history events paginate func(nexttoken []byte) (*workflowservice.GetWorkflowExecutionHistoryResponse, error) } + + // queryRejectedError is a wrapper for QueryRejected + queryRejectedError struct { + queryRejected *querypb.QueryRejected + } ) // ExecuteWorkflow starts a workflow execution and returns a WorkflowRun that will allow you to wait until this workflow @@ -879,43 +884,28 @@ func (wc *WorkflowClient) QueryWorkflowWithOptions(ctx context.Context, request return nil, err } - var input *commonpb.Payloads - if len(request.Args) > 0 { - var err error - if input, err = encodeArgs(wc.dataConverter, request.Args); err != nil { - return nil, err - } - } - req := &workflowservice.QueryWorkflowRequest{ - Namespace: wc.namespace, - Execution: &commonpb.WorkflowExecution{ - WorkflowId: request.WorkflowID, - RunId: request.RunID, - }, - Query: &querypb.WorkflowQuery{ - QueryType: request.QueryType, - QueryArgs: input, - Header: request.Header, - }, - QueryRejectCondition: request.QueryRejectCondition, - } - - grpcCtx, cancel := newGRPCContext(ctx, defaultGrpcRetryParameters(ctx)) - defer cancel() - resp, err := wc.workflowService.QueryWorkflow(grpcCtx, req) + // Set header before interceptor run + ctx, err := contextWithHeaderPropagated(ctx, request.Header, wc.contextPropagators) if err != nil { return nil, err } - if resp.QueryRejected != nil { - return &QueryWorkflowWithOptionsResponse{ - QueryRejected: resp.QueryRejected, - QueryResult: nil, - }, nil + result, err := wc.interceptor.QueryWorkflow(ctx, &ClientQueryWorkflowInput{ + WorkflowID: request.WorkflowID, + RunID: request.RunID, + QueryType: request.QueryType, + Args: request.Args, + }) + if err != nil { + if err, ok := err.(*queryRejectedError); ok { + return &QueryWorkflowWithOptionsResponse{ + QueryRejected: err.queryRejected, + }, nil + } + return nil, err } return &QueryWorkflowWithOptionsResponse{ - QueryRejected: nil, - QueryResult: newEncodedValue(resp.QueryResult, wc.dataConverter), + QueryResult: result, }, nil } @@ -1724,17 +1714,40 @@ func (w *workflowClientInterceptor) QueryWorkflow( return nil, err } - result, err := w.client.QueryWorkflowWithOptions(ctx, &QueryWorkflowWithOptionsRequest{ - WorkflowID: in.WorkflowID, - RunID: in.RunID, - QueryType: in.QueryType, - Args: in.Args, - Header: header, - }) + var input *commonpb.Payloads + if len(in.Args) > 0 { + var err error + if input, err = encodeArgs(w.client.dataConverter, in.Args); err != nil { + return nil, err + } + } + req := &workflowservice.QueryWorkflowRequest{ + Namespace: w.client.namespace, + Execution: &commonpb.WorkflowExecution{ + WorkflowId: in.WorkflowID, + RunId: in.RunID, + }, + Query: &querypb.WorkflowQuery{ + QueryType: in.QueryType, + QueryArgs: input, + Header: header, + }, + QueryRejectCondition: in.QueryRejectCondition, + } + + grpcCtx, cancel := newGRPCContext(ctx, defaultGrpcRetryParameters(ctx)) + defer cancel() + resp, err := w.client.workflowService.QueryWorkflow(grpcCtx, req) if err != nil { return nil, err } - return result.QueryResult, nil + + if resp.QueryRejected != nil { + return nil, &queryRejectedError{ + queryRejected: resp.QueryRejected, + } + } + return newEncodedValue(resp.QueryResult, w.client.dataConverter), nil } func (w *workflowClientInterceptor) UpdateWorkflow( @@ -1875,3 +1888,7 @@ func (luh *lazyUpdateHandle) Get(ctx context.Context, valuePtr interface{}) erro } return enc.Get(valuePtr) } + +func (q *queryRejectedError) Error() string { + return q.queryRejected.GoString() +} diff --git a/test/integration_test.go b/test/integration_test.go index e7ce822ab..e171231a3 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -1535,6 +1535,17 @@ func (ts *IntegrationTestSuite) TestInterceptorCalls() { ts.NoError(queryVal.Get(&queryRes)) ts.Equal("queryresult(queryarg)", queryRes) + // Query with options + response, err := ts.client.QueryWorkflowWithOptions(ctx, &client.QueryWorkflowWithOptionsRequest{ + WorkflowID: run.GetID(), + RunID: run.GetRunID(), + QueryType: "query", + Args: []interface{}{"queryarg"}, + }) + ts.NoError(err) + ts.NoError(response.QueryResult.Get(&queryRes)) + ts.Equal("queryresult(queryarg)", queryRes) + // Send signal ts.NoError(ts.client.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "finish", "finished"))