diff --git a/flyteadmin/pkg/repositories/transformers/node_execution.go b/flyteadmin/pkg/repositories/transformers/node_execution.go index 107e9efb70..1f17f0f131 100644 --- a/flyteadmin/pkg/repositories/transformers/node_execution.go +++ b/flyteadmin/pkg/repositories/transformers/node_execution.go @@ -42,6 +42,7 @@ func addNodeRunningState(request *admin.NodeExecutionEventRequest, nodeExecution "failed to marshal occurredAt into a timestamp proto with error: %v", err) } closure.StartedAt = startedAtProto + closure.DeckUri = request.GetEvent().GetDeckUri() return nil } diff --git a/flyteadmin/pkg/repositories/transformers/node_execution_test.go b/flyteadmin/pkg/repositories/transformers/node_execution_test.go index e37d312612..dff9a5dba5 100644 --- a/flyteadmin/pkg/repositories/transformers/node_execution_test.go +++ b/flyteadmin/pkg/repositories/transformers/node_execution_test.go @@ -51,6 +51,7 @@ var childExecutionID = &core.WorkflowExecutionIdentifier{ const dynamicWorkflowClosureRef = "s3://bucket/admin/metadata/workflow" const testInputURI = "fake://bucket/inputs.pb" +const DeckURI = "fake://bucket/deck.html" var testInputs = &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -65,6 +66,7 @@ func TestAddRunningState(t *testing.T) { Event: &event.NodeExecutionEvent{ Phase: core.NodeExecution_RUNNING, OccurredAt: startedAtProto, + DeckUri: DeckURI, }, } nodeExecutionModel := models.NodeExecution{} @@ -73,6 +75,7 @@ func TestAddRunningState(t *testing.T) { assert.Nil(t, err) assert.Equal(t, startedAt, *nodeExecutionModel.StartedAt) assert.True(t, proto.Equal(startedAtProto, closure.GetStartedAt())) + assert.Equal(t, DeckURI, closure.GetDeckUri()) } func TestAddTerminalState_OutputURI(t *testing.T) { @@ -84,6 +87,7 @@ func TestAddTerminalState_OutputURI(t *testing.T) { OutputUri: outputURI, }, OccurredAt: occurredAtProto, + DeckUri: DeckURI, }, } startedAt := occurredAt.Add(-time.Minute) @@ -99,6 +103,7 @@ func TestAddTerminalState_OutputURI(t *testing.T) { assert.Nil(t, err) assert.EqualValues(t, outputURI, closure.GetOutputUri()) assert.Equal(t, time.Minute, nodeExecutionModel.Duration) + assert.Equal(t, DeckURI, closure.GetDeckUri()) } func TestAddTerminalState_OutputData(t *testing.T) { @@ -193,6 +198,36 @@ func TestAddTerminalState_Error(t *testing.T) { assert.Equal(t, time.Minute, nodeExecutionModel.Duration) } +func TestAddTerminalState_DeckURIInFailedExecution(t *testing.T) { + error := &core.ExecutionError{ + Code: "foo", + } + request := admin.NodeExecutionEventRequest{ + Event: &event.NodeExecutionEvent{ + Phase: core.NodeExecution_FAILED, + OutputResult: &event.NodeExecutionEvent_Error{ + Error: error, + }, + OccurredAt: occurredAtProto, + DeckUri: DeckURI, + }, + } + startedAt := occurredAt.Add(-time.Minute) + startedAtProto, _ := ptypes.TimestampProto(startedAt) + nodeExecutionModel := models.NodeExecution{ + StartedAt: &startedAt, + } + closure := admin.NodeExecutionClosure{ + StartedAt: startedAtProto, + } + err := addTerminalState(context.TODO(), &request, &nodeExecutionModel, &closure, + interfaces.InlineEventDataPolicyStoreInline, commonMocks.GetMockStorageClient()) + assert.Nil(t, err) + assert.True(t, proto.Equal(error, closure.GetError())) + assert.Equal(t, time.Minute, nodeExecutionModel.Duration) + assert.Equal(t, DeckURI, closure.GetDeckUri()) +} + func TestCreateNodeExecutionModel(t *testing.T) { parentTaskExecID := uint(8) request := &admin.NodeExecutionEventRequest{ diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 000d6bd7e7..4e7e752f62 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -41,6 +41,14 @@ import ( const pluginContextKey = contextutils.Key("plugin") +type DeckStatus int + +const ( + DeckUnknown DeckStatus = iota + DeckEnabled + DeckDisabled +) + type metrics struct { pluginPanics labeled.Counter unsupportedTaskType labeled.Counter @@ -71,10 +79,43 @@ func getPluginMetricKey(pluginID, taskType string) string { return taskType + "_" + pluginID } -func (p *pluginRequestedTransition) CacheHit(outputPath storage.DataReference, deckPath *storage.DataReference, entry catalog.Entry) { +func (p *pluginRequestedTransition) AddDeckURI(tCtx *taskExecutionContext) { + var deckURI *storage.DataReference + deckURIValue := tCtx.ow.GetDeckPath() + deckURI = &deckURIValue + + if p.execInfo.OutputInfo == nil { + p.execInfo.OutputInfo = &handler.OutputInfo{} + } + + p.execInfo.OutputInfo.DeckURI = deckURI +} + +func (p *pluginRequestedTransition) RemoveDeckURIIfDeckNotExists(ctx context.Context, tCtx *taskExecutionContext) error { + reader := tCtx.ow.GetReader() + if reader == nil { + return nil + } + + exists, err := reader.DeckExists(ctx) + if err != nil { + if p.execInfo.OutputInfo != nil { + p.execInfo.OutputInfo.DeckURI = nil + } + return regErrors.Wrapf(err, "failed to check existence of deck file") + } + + if !exists && p.execInfo.OutputInfo != nil { + p.execInfo.OutputInfo.DeckURI = nil + } + + return nil +} + +func (p *pluginRequestedTransition) CacheHit(outputPath storage.DataReference, entry catalog.Entry) { p.ttype = handler.TransitionTypeEphemeral p.pInfo = pluginCore.PhaseInfoSuccess(nil) - p.ObserveSuccess(outputPath, deckPath, &event.TaskNodeMetadata{CacheStatus: entry.GetStatus().GetCacheStatus(), CatalogKey: entry.GetStatus().GetMetadata()}) + p.ObserveSuccess(outputPath, &event.TaskNodeMetadata{CacheStatus: entry.GetStatus().GetCacheStatus(), CatalogKey: entry.GetStatus().GetMetadata()}) } func (p *pluginRequestedTransition) PopulateCacheInfo(entry catalog.Entry) { @@ -144,10 +185,13 @@ func (p *pluginRequestedTransition) FinalTaskEvent(input ToTaskExecutionEventInp return ToTaskExecutionEvent(input) } -func (p *pluginRequestedTransition) ObserveSuccess(outputPath storage.DataReference, deckPath *storage.DataReference, taskMetadata *event.TaskNodeMetadata) { - p.execInfo.OutputInfo = &handler.OutputInfo{ - OutputURI: outputPath, - DeckURI: deckPath, +func (p *pluginRequestedTransition) ObserveSuccess(outputPath storage.DataReference, taskMetadata *event.TaskNodeMetadata) { + if p.execInfo.OutputInfo == nil { + p.execInfo.OutputInfo = &handler.OutputInfo{ + OutputURI: outputPath, + } + } else { + p.execInfo.OutputInfo.OutputURI = outputPath } p.execInfo.TaskNodeInfo = &handler.TaskNodeInfo{ @@ -171,7 +215,8 @@ func (p *pluginRequestedTransition) FinalTransition(ctx context.Context) (handle } logger.Debugf(ctx, "Task still running") - return handler.DoTransition(p.ttype, handler.PhaseInfoRunning(nil)), nil + // Here will send the deck uri to flyteadmin + return handler.DoTransition(p.ttype, handler.PhaseInfoRunning(&p.execInfo)), nil } // The plugin interface available especially for testing. @@ -380,6 +425,40 @@ func (t Handler) fetchPluginTaskMetrics(pluginID, taskType string) (*taskMetrics return t.taskMetricsMap[metricNameKey], nil } +func GetDeckStatus(ctx context.Context, tCtx *taskExecutionContext) (DeckStatus, error) { + // GetDeckStatus determines whether a task generates a deck based on its execution context. + // + // This function evaluates the current condition of the task to determine the deck status: + // + // | Condition Description | Has Deck | + // |--------------------------------|----------| + // | Enabled and Running | Yes | + // | Unknown State with Deck | Yes | + // | Unknown State without Deck | No | + // | Enabled and Succeeded | Yes | + // | Enabled but Memory Exceeded | No | + // | Disabled | No | + // + // The lifecycle of deck generation is as follows: + // - During task execution, the condition is checked to determine if a deck should be generated. + // - In terminal states, if the status is DeckUnknown or DeckEnabled, a HEAD request can be made to verify the existence of the deck file. + template, err := tCtx.tr.Read(ctx) + if err != nil { + return DeckUnknown, regErrors.Wrapf(err, "failed to read task template") + } + + deckValue := template.GetMetadata().GetGeneratesDeck() + if deckValue == nil { + return DeckUnknown, nil + } + + if deckValue.GetValue() { + return DeckEnabled, nil + } + + return DeckDisabled, nil +} + func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *taskExecutionContext, ts handler.TaskNodeState) (*pluginRequestedTransition, error) { pluginTrns := &pluginRequestedTransition{} @@ -464,8 +543,30 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta } } + // Regardless of the observed phase, we always add the DeckUri to support real-time deck functionality. + // The deck should be accessible even if the task is still running or has failed. + deckStatus, err := GetDeckStatus(ctx, tCtx) + if err != nil { + return nil, err + } + + if deckStatus == DeckEnabled { + pluginTrns.AddDeckURI(tCtx) + } + + defer func() { + if (deckStatus == DeckUnknown || deckStatus == DeckEnabled) && pluginTrns.pInfo.Phase().IsTerminal() { + if err := pluginTrns.RemoveDeckURIIfDeckNotExists(ctx, tCtx); err != nil { + logger.Errorf(ctx, "Failed to remove deck URI if deck does not exist. Error: %v", err) + } + } + }() + switch pluginTrns.pInfo.Phase() { case pluginCore.PhaseSuccess: + if deckStatus == DeckUnknown { + pluginTrns.AddDeckURI(tCtx) + } // ------------------------------------- // TODO: @kumare create Issue# Remove the code after we use closures to handle dynamic nodes // This code only exists to support Dynamic tasks. Eventually dynamic tasks will use closure nodes to execute @@ -501,18 +602,7 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta CheckpointUri: tCtx.ow.GetCheckpointPrefix().String(), }) } else { - var deckURI *storage.DataReference - if tCtx.ow.GetReader() != nil { - exists, err := tCtx.ow.GetReader().DeckExists(ctx) - if err != nil { - logger.Errorf(ctx, "Failed to check deck file existence. Error: %v", err) - return pluginTrns, regErrors.Wrapf(err, "failed to check existence of deck file") - } else if exists { - deckURIValue := tCtx.ow.GetDeckPath() - deckURI = &deckURIValue - } - } - pluginTrns.ObserveSuccess(tCtx.ow.GetOutputPath(), deckURI, + pluginTrns.ObserveSuccess(tCtx.ow.GetOutputPath(), &event.TaskNodeMetadata{ CheckpointUri: tCtx.ow.GetCheckpointPrefix().String(), })