Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for update info #1162

Merged
merged 3 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ type WorkflowOutboundInterceptor interface {
// GetInfo intercepts workflow.GetInfo.
GetInfo(ctx Context) *WorkflowInfo

// GetUpdateInfo intercepts workflow.GetUpdateInfo.
//
// NOTE: Experimental
GetUpdateInfo(ctx Context) *UpdateInfo

// GetLogger intercepts workflow.GetLogger.
GetLogger(ctx Context) log.Logger

Expand Down
5 changes: 5 additions & 0 deletions internal/interceptor_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ func (w *WorkflowOutboundInterceptorBase) GetInfo(ctx Context) *WorkflowInfo {
return w.Next.GetInfo(ctx)
}

// GetUpdateInfo implements WorkflowOutboundInterceptor.GetUpdateInfo.
func (w *WorkflowOutboundInterceptorBase) GetUpdateInfo(ctx Context) *UpdateInfo {
return w.Next.GetUpdateInfo(ctx)
}

// GetLogger implements WorkflowOutboundInterceptor.GetLogger.
func (w *WorkflowOutboundInterceptorBase) GetLogger(ctx Context) log.Logger {
return w.Next.GetLogger(ctx)
Expand Down
8 changes: 4 additions & 4 deletions internal/internal_event_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ type (
cancelHandler func() // A cancel handler to be invoked on a cancel notification
signalHandler func(name string, input *commonpb.Payloads, header *commonpb.Header) error // A signal handler to be invoked on a signal event
queryHandler func(queryType string, queryArgs *commonpb.Payloads, header *commonpb.Header) (*commonpb.Payloads, error)
updateHandler func(name string, args *commonpb.Payloads, header *commonpb.Header, callbacks UpdateCallbacks)
updateHandler func(name string, id string, args *commonpb.Payloads, header *commonpb.Header, callbacks UpdateCallbacks)

logger log.Logger
isReplay bool // flag to indicate if workflow is in replay mode
Expand Down Expand Up @@ -323,8 +323,8 @@ func (wc *workflowEnvironmentImpl) takeOutgoingMessages() []*protocolpb.Message
return retval
}

func (wc *workflowEnvironmentImpl) ScheduleUpdate(name string, args *commonpb.Payloads, hdr *commonpb.Header, callbacks UpdateCallbacks) {
wc.updateHandler(name, args, hdr, callbacks)
func (wc *workflowEnvironmentImpl) ScheduleUpdate(name string, id string, args *commonpb.Payloads, hdr *commonpb.Header, callbacks UpdateCallbacks) {
wc.updateHandler(name, id, args, hdr, callbacks)
}

func withExpectedEventPredicate(pred func(*historypb.HistoryEvent) bool) msgSendOpt {
Expand Down Expand Up @@ -577,7 +577,7 @@ func (wc *workflowEnvironmentImpl) RegisterQueryHandler(
}

func (wc *workflowEnvironmentImpl) RegisterUpdateHandler(
handler func(string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks),
handler func(string, string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks),
) {
wc.updateHandler = handler
}
Expand Down
5 changes: 4 additions & 1 deletion internal/internal_event_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,16 @@ func TestUpdateEvents(t *testing.T) {

var (
gotName string
gotID string
gotArgs *commonpb.Payloads
gotHeader *commonpb.Header
)

weh := &workflowExecutionEventHandlerImpl{
workflowEnvironmentImpl: &workflowEnvironmentImpl{
updateHandler: func(name string, args *commonpb.Payloads, header *commonpb.Header, cb UpdateCallbacks) {
updateHandler: func(name string, id string, args *commonpb.Payloads, header *commonpb.Header, cb UpdateCallbacks) {
gotName = name
gotID = id
gotArgs = args
gotHeader = header
},
Expand Down Expand Up @@ -468,6 +470,7 @@ func TestUpdateEvents(t *testing.T) {
require.NoError(t, err)

require.Equal(t, input.Name, gotName)
require.Equal(t, t.Name()+"-id", gotID)
require.True(t, proto.Equal(input.Header, gotHeader))
require.True(t, proto.Equal(input.Args, gotArgs))

Expand Down
13 changes: 9 additions & 4 deletions internal/internal_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ type (

// updateProtocol wraps an updateEnv and some protocol metadata to
// implement the UpdateCallbacks abstraction. It handles callbacks by
// sending protocol lmessages.
// sending protocol messages.
updateProtocol struct {
protoInstanceID string
clientIdentity string
requestMsgID string
requestSeqID int64
scheduleUpdate func(name string, args *commonpb.Payloads, header *commonpb.Header, callbacks UpdateCallbacks)
scheduleUpdate func(name string, id string, args *commonpb.Payloads, header *commonpb.Header, callbacks UpdateCallbacks)
env updateEnv
state updateState
}
Expand All @@ -114,7 +114,7 @@ type (
// update callbacks.
func newUpdateProtocol(
protoInstanceID string,
scheduleUpdate func(name string, args *commonpb.Payloads, header *commonpb.Header, callbacks UpdateCallbacks),
scheduleUpdate func(name string, id string, args *commonpb.Payloads, header *commonpb.Header, callbacks UpdateCallbacks),
env updateEnv,
) *updateProtocol {
return &updateProtocol{
Expand Down Expand Up @@ -143,7 +143,7 @@ func (up *updateProtocol) HandleMessage(msg *protocolpb.Message) error {
up.requestMsgID = msg.GetId()
up.requestSeqID = msg.GetEventId()
input := req.GetInput()
up.scheduleUpdate(input.GetName(), input.GetArgs(), input.GetHeader(), up)
up.scheduleUpdate(input.GetName(), req.GetMeta().GetUpdateId(), input.GetArgs(), input.GetHeader(), up)
up.state = updateStateRequestInitiated
return nil
}
Expand Down Expand Up @@ -241,6 +241,7 @@ func (up *updateProtocol) checkAcceptedEvent(e *historypb.HistoryEvent) bool {
func defaultUpdateHandler(
rootCtx Context,
name string,
id string,
serializedArgs *commonpb.Payloads,
header *commonpb.Header,
callbacks UpdateCallbacks,
Expand All @@ -253,6 +254,10 @@ func defaultUpdateHandler(
return
}
scheduler.Spawn(ctx, name, func(ctx Context) {
ctx = WithValue(ctx, updateInfoContextKey, &UpdateInfo{
ID: id,
})

eo := getWorkflowEnvOptions(ctx)

// If we suspect that handler registration has not occurred (e.g.
Expand Down
22 changes: 11 additions & 11 deletions internal/internal_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func TestDefaultUpdateHandler(t *testing.T) {
UpdateHandlerOptions{},
)
var rejectErr error
defaultUpdateHandler(ctx, "will_not_be_found", args, hdr, &testUpdateCallbacks{
defaultUpdateHandler(ctx, "will_not_be_found", "testID", args, hdr, &testUpdateCallbacks{
RejectImpl: func(err error) { rejectErr = err },
}, runOnCallingThread)
require.ErrorContains(t, rejectErr, "unknown update")
Expand All @@ -221,7 +221,7 @@ func TestDefaultUpdateHandler(t *testing.T) {
)
junkArgs := &commonpb.Payloads{Payloads: []*commonpb.Payload{&commonpb.Payload{}}}
var rejectErr error
defaultUpdateHandler(ctx, t.Name(), junkArgs, hdr, &testUpdateCallbacks{
defaultUpdateHandler(ctx, t.Name(), "testID", junkArgs, hdr, &testUpdateCallbacks{
RejectImpl: func(err error) { rejectErr = err },
}, runOnCallingThread)
require.ErrorContains(t, rejectErr, "unable to decode")
Expand All @@ -238,7 +238,7 @@ func TestDefaultUpdateHandler(t *testing.T) {
UpdateHandlerOptions{Validator: validatorFunc},
)
var rejectErr error
defaultUpdateHandler(ctx, t.Name(), args, hdr, &testUpdateCallbacks{
defaultUpdateHandler(ctx, t.Name(), "testID", args, hdr, &testUpdateCallbacks{
RejectImpl: func(err error) { rejectErr = err },
}, runOnCallingThread)
require.Equal(t, validatorFunc(ctx, argStr), rejectErr)
Expand All @@ -252,7 +252,7 @@ func TestDefaultUpdateHandler(t *testing.T) {
accepted bool
result interface{}
)
defaultUpdateHandler(ctx, t.Name(), args, hdr, &testUpdateCallbacks{
defaultUpdateHandler(ctx, t.Name(), "testID", args, hdr, &testUpdateCallbacks{
AcceptImpl: func() { accepted = true },
CompleteImpl: func(success interface{}, err error) {
resultErr = err
Expand All @@ -272,7 +272,7 @@ func TestDefaultUpdateHandler(t *testing.T) {
accepted bool
result interface{}
)
defaultUpdateHandler(ctx, t.Name(), args, hdr, &testUpdateCallbacks{
defaultUpdateHandler(ctx, t.Name(), "testID", args, hdr, &testUpdateCallbacks{
AcceptImpl: func() { accepted = true },
CompleteImpl: func(success interface{}, err error) {
resultErr = err
Expand Down Expand Up @@ -323,7 +323,7 @@ func TestDefaultUpdateHandler(t *testing.T) {
mustSetUpdateHandler(t, ctx, t.Name(), updateFunc, UpdateHandlerOptions{})
},
}
defaultUpdateHandler(ctx, t.Name(), args, hdr, &testUpdateCallbacks{
defaultUpdateHandler(ctx, t.Name(), "testID", args, hdr, &testUpdateCallbacks{
RejectImpl: func(err error) { rejectErr = err },
AcceptImpl: func() { accepted = true },
CompleteImpl: func(success interface{}, err error) {
Expand All @@ -344,7 +344,7 @@ func TestDefaultUpdateHandler(t *testing.T) {

func TestInvalidUpdateStateTransitions(t *testing.T) {
// these would all reflect programming errors so we expect panics
stubUpdateHandler := func(string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks) {}
stubUpdateHandler := func(string, string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks) {}
requestMsg := protocolpb.Message{
Id: t.Name() + "-id",
ProtocolInstanceId: t.Name() + "-proto-id",
Expand Down Expand Up @@ -412,8 +412,8 @@ func TestInvalidUpdateStateTransitions(t *testing.T) {
}

func TestCompletedEventPredicate(t *testing.T) {
updateID := t.Name() + "-updaet-id"
stubUpdateHandler := func(string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks) {}
updateID := t.Name() + "-update-id"
stubUpdateHandler := func(string, string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks) {}
requestMsg := protocolpb.Message{
Id: t.Name() + "-id",
ProtocolInstanceId: updateID,
Expand Down Expand Up @@ -450,10 +450,10 @@ func TestCompletedEventPredicate(t *testing.T) {
}

func TestAcceptedEventPredicate(t *testing.T) {
updateID := t.Name() + "-updaet-id"
updateID := t.Name() + "-update-id"
requestMsgID := t.Name() + "request-msg-id"
requestSeqID := int64(1234)
stubUpdateHandler := func(string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks) {}
stubUpdateHandler := func(string, string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks) {}
request := updatepb.Request{
Meta: &updatepb.Meta{UpdateId: updateID},
}
Expand Down
2 changes: 1 addition & 1 deletion internal/internal_worker_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ type (
handler func(queryType string, queryArgs *commonpb.Payloads, header *commonpb.Header) (*commonpb.Payloads, error),
)
RegisterUpdateHandler(
handler func(string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks),
handler func(string, string, *commonpb.Payloads, *commonpb.Header, UpdateCallbacks),
)
IsReplaying() bool
MutableSideEffect(id string, f func() interface{}, equals func(a, b interface{}) bool) converter.EncodedValue
Expand Down
5 changes: 3 additions & 2 deletions internal/internal_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ const (
workflowResultContextKey = "workflowResult"
coroutinesContextKey = "coroutines"
workflowEnvOptionsContextKey = "wfEnvOptions"
updateInfoContextKey = "updateInfo"
)

// Assert that structs do indeed implement the interfaces
Expand Down Expand Up @@ -541,8 +542,8 @@ func (d *syncWorkflowDefinition) Execute(env WorkflowEnvironment, header *common
)

getWorkflowEnvironment(d.rootCtx).RegisterUpdateHandler(
func(name string, serializedArgs *commonpb.Payloads, header *commonpb.Header, callbacks UpdateCallbacks) {
defaultUpdateHandler(d.rootCtx, name, serializedArgs, header, callbacks, coroScheduler{d.dispatcher})
func(name string, id string, serializedArgs *commonpb.Payloads, header *commonpb.Header, callbacks UpdateCallbacks) {
defaultUpdateHandler(d.rootCtx, name, id, serializedArgs, header, callbacks, coroScheduler{d.dispatcher})
})

getWorkflowEnvironment(d.rootCtx).RegisterQueryHandler(
Expand Down
8 changes: 4 additions & 4 deletions internal/internal_workflow_testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ type (
workflowCancelHandler func()
signalHandler func(name string, input *commonpb.Payloads, header *commonpb.Header) error
queryHandler func(string, *commonpb.Payloads, *commonpb.Header) (*commonpb.Payloads, error)
updateHandler func(name string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks)
updateHandler func(name string, id string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks)
startedHandler func(r WorkflowExecution, e error)

isWorkflowCompleted bool
Expand Down Expand Up @@ -2021,7 +2021,7 @@ func (env *testWorkflowEnvironmentImpl) RegisterSignalHandler(
}

func (env *testWorkflowEnvironmentImpl) RegisterUpdateHandler(
handler func(name string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks),
handler func(name string, id string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks),
) {
env.updateHandler = handler
}
Expand Down Expand Up @@ -2361,12 +2361,12 @@ func (env *testWorkflowEnvironmentImpl) queryWorkflow(queryType string, args ...
return newEncodedValue(blob, env.GetDataConverter()), nil
}

func (env *testWorkflowEnvironmentImpl) updateWorkflow(name string, uc UpdateCallbacks, args ...interface{}) {
func (env *testWorkflowEnvironmentImpl) updateWorkflow(name string, id string, uc UpdateCallbacks, args ...interface{}) {
data, err := encodeArgs(env.GetDataConverter(), args)
if err != nil {
panic(err)
}
env.updateHandler(name, data, nil, uc)
env.updateHandler(name, id, data, nil, uc)
}

func (env *testWorkflowEnvironmentImpl) queryWorkflowByID(workflowID, queryType string, args ...interface{}) (converter.EncodedValue, error) {
Expand Down
19 changes: 19 additions & 0 deletions internal/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,11 @@ type WorkflowInfo struct {
currentHistoryLength int
}

// UpdateInfo information about a currently running update
type UpdateInfo struct {
ID string
}

// GetBinaryChecksum return binary checksum.
func (wInfo *WorkflowInfo) GetBinaryChecksum() string {
if wInfo.BinaryChecksum == "" {
Expand All @@ -1017,6 +1022,20 @@ func (wc *workflowEnvironmentInterceptor) GetInfo(ctx Context) *WorkflowInfo {
return wc.env.WorkflowInfo()
}

// GetUpdateInfo extracts info of a currently running update from a context.
func GetUpdateInfo(ctx Context) *UpdateInfo {
i := getWorkflowOutboundInterceptor(ctx)
return i.GetUpdateInfo(ctx)
}

func (wc *workflowEnvironmentInterceptor) GetUpdateInfo(ctx Context) *UpdateInfo {
uc := ctx.Value(updateInfoContextKey)
if uc == nil {
panic("getWorkflowOutboundInterceptor: No update associated with this context")
}
return uc.(*UpdateInfo)
}

// GetLogger returns a logger to be used in workflow's context
func GetLogger(ctx Context) log.Logger {
i := getWorkflowOutboundInterceptor(ctx)
Expand Down
4 changes: 2 additions & 2 deletions internal/workflow_testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,8 +828,8 @@ func (e *TestWorkflowEnvironment) QueryWorkflow(queryType string, args ...interf
return e.impl.queryWorkflow(queryType, args...)
}

func (e *TestWorkflowEnvironment) UpdateWorkflow(name string, uc UpdateCallbacks, args ...interface{}) {
e.impl.updateWorkflow(name, uc, args...)
func (e *TestWorkflowEnvironment) UpdateWorkflow(name string, id string, uc UpdateCallbacks, args ...interface{}) {
e.impl.updateWorkflow(name, id, uc, args...)
}

// QueryWorkflowByID queries a child workflow by its ID and returns the result synchronously
Expand Down
32 changes: 32 additions & 0 deletions test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,38 @@ func (ts *IntegrationTestSuite) TestInspectLocalActivityInfoLocalActivityWorkerO
ts.Nil(err)
}

func (ts *IntegrationTestSuite) TestUpdateInfo() {
ctx := context.Background()
run, err := ts.client.ExecuteWorkflow(ctx,
ts.startWorkflowOptions("test-update-info"), ts.workflows.UpdateInfoWorkflow)
ts.Nil(err)
// Send an update request with a know update ID
handler, err := ts.client.UpdateWorkflowWithOptions(ctx, &client.UpdateWorkflowWithOptionsRequest{
UpdateID: "testID",
WorkflowID: run.GetID(),
RunID: run.GetRunID(),
UpdateName: "update",
})
ts.NoError(err)
// Verify the upate handler can access the update info and return the updateID
var result string
ts.NoError(handler.Get(ctx, &result))
ts.Equal("testID", result)
// Test the update validator can also use the update info
handler, err = ts.client.UpdateWorkflowWithOptions(ctx, &client.UpdateWorkflowWithOptionsRequest{
UpdateID: "notTestID",
WorkflowID: run.GetID(),
RunID: run.GetRunID(),
UpdateName: "update",
})
ts.NoError(err)
err = handler.Get(ctx, nil)
ts.Error(err)
// complete workflow
ts.NoError(ts.client.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "finish", "finished"))
ts.NoError(run.Get(ctx, nil))
}

func (ts *IntegrationTestSuite) TestBasicSession() {
var expected []string
err := ts.executeWorkflow("test-basic-session", ts.workflows.BasicSession, &expected)
Expand Down
19 changes: 19 additions & 0 deletions test/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,24 @@ func (w *Workflows) ActivityRetryOnHBTimeout(ctx workflow.Context) ([]string, er
return []string{"heartbeatAndSleep", "heartbeatAndSleep", "heartbeatAndSleep"}, nil
}

func (w *Workflows) UpdateInfoWorkflow(ctx workflow.Context) error {
err := workflow.SetUpdateHandlerWithOptions(ctx, "update", func(ctx workflow.Context) (string, error) {
return workflow.GetUpdateInfo(ctx).ID, nil
}, workflow.UpdateHandlerOptions{
Validator: func(ctx workflow.Context) error {
if workflow.GetUpdateInfo(ctx).ID != "testID" {
return errors.New("invalid update ID")
}
return nil
},
})
if err != nil {
return errors.New("failed to register update handler")
}
workflow.GetSignalChannel(ctx, "finish").Receive(ctx, nil)
return nil
}

func (w *Workflows) ActivityHeartbeatWithRetry(ctx workflow.Context) (heartbeatCounts int, err error) {
// Make retries fast
opts := w.defaultActivityOptions()
Expand Down Expand Up @@ -2268,6 +2286,7 @@ func (w *Workflows) register(worker worker.Worker) {
worker.RegisterWorkflow(w.WorkflowWithParallelSideEffects)
worker.RegisterWorkflow(w.WorkflowWithParallelMutableSideEffects)
worker.RegisterWorkflow(w.LocalActivityStaleCache)
worker.RegisterWorkflow(w.UpdateInfoWorkflow)
worker.RegisterWorkflow(w.SignalWorkflow)
worker.RegisterWorkflow(w.CronWorkflow)
worker.RegisterWorkflow(w.CancelTimerConcurrentWithOtherCommandWorkflow)
Expand Down
Loading
Loading