diff --git a/flytepropeller/pkg/controller/executors/failure_node_lookup.go b/flytepropeller/pkg/controller/executors/failure_node_lookup.go index f4f0989d7b..535a2774dc 100644 --- a/flytepropeller/pkg/controller/executors/failure_node_lookup.go +++ b/flytepropeller/pkg/controller/executors/failure_node_lookup.go @@ -37,8 +37,8 @@ func (f FailureNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, erro return nil, nil } -func (f FailureNodeLookup) GetOriginalError() *core.ExecutionError { - return f.OriginalError +func (f FailureNodeLookup) GetOriginalError() (*core.ExecutionError, error) { + return f.OriginalError, nil } func NewFailureNodeLookup(nodeLookup NodeLookup, failureNode v1alpha1.ExecutableNode, failureNodeStatus v1alpha1.ExecutableNodeStatus, originalError *core.ExecutionError) NodeLookup { diff --git a/flytepropeller/pkg/controller/executors/failure_node_lookup_test.go b/flytepropeller/pkg/controller/executors/failure_node_lookup_test.go index e9d6857ec4..5dbcfedf47 100644 --- a/flytepropeller/pkg/controller/executors/failure_node_lookup_test.go +++ b/flytepropeller/pkg/controller/executors/failure_node_lookup_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" ) @@ -26,7 +27,10 @@ func TestNewFailureNodeLookup(t *testing.T) { nl := nl{} en := en{} ns := ns{} - nodeLoopUp := NewFailureNodeLookup(nl, en, ns) + execErr := &core.ExecutionError{ + Message: "node failure", + } + nodeLoopUp := NewFailureNodeLookup(nl, en, ns, execErr) assert.NotNil(t, nl) typed := nodeLoopUp.(FailureNodeLookup) assert.Equal(t, nl, typed.NodeLookup) @@ -38,6 +42,9 @@ func TestNewTestFailureNodeLookup(t *testing.T) { n := &mocks.ExecutableNode{} ns := &mocks.ExecutableNodeStatus{} failureNodeID := "fn1" + originalErr := &core.ExecutionError{ + Message: "node failure", + } nl := NewTestNodeLookup( map[string]v1alpha1.ExecutableNode{v1alpha1.StartNodeID: n, failureNodeID: n}, map[string]v1alpha1.ExecutableNodeStatus{v1alpha1.StartNodeID: ns, failureNodeID: ns}, @@ -45,7 +52,7 @@ func TestNewTestFailureNodeLookup(t *testing.T) { assert.NotNil(t, nl) - failureNodeLookup := NewFailureNodeLookup(nl, n, ns) + failureNodeLookup := NewFailureNodeLookup(nl, n, ns, originalErr).(FailureNodeLookup) r, ok := failureNodeLookup.GetNode(v1alpha1.StartNodeID) assert.True(t, ok) assert.Equal(t, n, r) @@ -64,4 +71,9 @@ func TestNewTestFailureNodeLookup(t *testing.T) { nodeIDs, err = failureNodeLookup.FromNode(failureNodeID) assert.Nil(t, nodeIDs) assert.Nil(t, err) + + oe, err := failureNodeLookup.GetOriginalError() + assert.NotNil(t, oe) + assert.Equal(t, originalErr, oe) + assert.Nil(t, err) } diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 8d638ae641..f6fec1a123 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -768,9 +768,9 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur // Resolve error input if current node is an on failure node failureNodeLookup, ok := nCtx.ContextualNodeLookup().(executors.FailureNodeLookup) if ok { - originalErr := failureNodeLookup.GetOriginalError() + originalErr, _ := failureNodeLookup.GetOriginalError() if originalErr != nil { - ResolveErrorInput(ctx, nodeInputs, node.GetID(), originalErr) + ResolveOnFailureNodeInput(ctx, nodeInputs, node.GetID(), originalErr) } } p := common.CheckOffloadingCompat(ctx, nCtx, nodeInputs.GetLiterals(), node, c.literalOffloadingConfig) diff --git a/flytepropeller/pkg/controller/nodes/resolve.go b/flytepropeller/pkg/controller/nodes/resolve.go index 5e364040a2..2cb7aef5c2 100644 --- a/flytepropeller/pkg/controller/nodes/resolve.go +++ b/flytepropeller/pkg/controller/nodes/resolve.go @@ -106,11 +106,10 @@ func Resolve(ctx context.Context, outputResolver OutputResolver, nl executors.No }, nil } -func ResolveErrorInput(ctx context.Context, nodeInputs *core.LiteralMap, nodeID v1alpha1.NodeID, execErr *core.ExecutionError) { +func ResolveOnFailureNodeInput(ctx context.Context, nodeInputs *core.LiteralMap, nodeID v1alpha1.NodeID, execErr *core.ExecutionError) { literals := nodeInputs.GetLiterals() if literal, exists := literals["err"]; exists { // make new Scalar for literal map - logger.Debugf(ctx, "Processing literal for key 'err'") errorUnion := &core.Scalar_Union{ Union: &core.Union{ Value: &core.Literal{ diff --git a/flytepropeller/pkg/controller/nodes/resolve_test.go b/flytepropeller/pkg/controller/nodes/resolve_test.go index 10b9e4e45d..ff7af5278e 100644 --- a/flytepropeller/pkg/controller/nodes/resolve_test.go +++ b/flytepropeller/pkg/controller/nodes/resolve_test.go @@ -467,3 +467,65 @@ func TestResolve(t *testing.T) { }) } + +func TestResolveErrorInput(t *testing.T) { + ctx := context.Background() + t.Run("ResolveErrorInputs", func(t *testing.T) { + noneLiteral, _ := coreutils.MakeLiteral(nil) + inputLiterals := make(map[string]*core.Literal, 1) + inputLiterals["err"] = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Value: noneLiteral, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_NONE, + }, + Structure: &core.TypeStructure{ + Tag: "none", + }, + }, + }, + }, + }, + }, + } + inputLiteralMap := &core.LiteralMap{ + Literals: inputLiterals, + } + nID := "fn" + execErr := &core.ExecutionError{ + Message: "node failure", + } + expectedLiterals := make(map[string]*core.Literal, 1) + errorLiteral, _ := coreutils.MakeLiteral(&core.Error{Message: execErr.Message, FailedNodeId: nID,}) + expectedLiterals["err"] = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Value: errorLiteral, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_ERROR, + }, + Structure: &core.TypeStructure{ + Tag: "FlyteError", + }, + }, + }, + }, + }, + }, + } + expectedLiteralMap := &core.LiteralMap{ + Literals: expectedLiterals, + } + // Execute resolve + ResolveOnFailureNodeInput(ctx, inputLiteralMap, nID, execErr) + flyteassert.EqualLiteralMap(t, expectedLiteralMap, inputLiteralMap) + }) + +} diff --git a/flytepropeller/pkg/utils/assert/literals.go b/flytepropeller/pkg/utils/assert/literals.go index c0fac675ed..31c5d2e846 100644 --- a/flytepropeller/pkg/utils/assert/literals.go +++ b/flytepropeller/pkg/utils/assert/literals.go @@ -9,6 +9,21 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" ) +func EqualLiteralType(t *testing.T, lt1 *core.LiteralType, lt2 *core.LiteralType) { + if !assert.Equal(t, lt1 == nil, lt2 == nil) { + assert.FailNow(t, "One of the values is nil") + } + assert.Equal(t, reflect.TypeOf(lt1.GetType()), reflect.TypeOf(lt2.GetType())) + switch lt1.GetType().(type) { + case *core.LiteralType_Simple: + assert.Equal(t, lt1.GetType().(*core.LiteralType_Simple).Simple, lt2.GetType().(*core.LiteralType_Simple).Simple) + default: + assert.FailNow(t, "Not yet implemented for types %v", reflect.TypeOf(lt1.GetType())) + } + + assert.Equal(t, lt1.GetStructure().GetTag(), lt2.GetStructure().GetTag()) +} + func EqualPrimitive(t *testing.T, p1 *core.Primitive, p2 *core.Primitive) { if !assert.Equal(t, p1 == nil, p2 == nil) { assert.FailNow(t, "One of the values is nil") @@ -27,6 +42,23 @@ func EqualPrimitive(t *testing.T, p1 *core.Primitive, p2 *core.Primitive) { } } +func EqualError(t *testing.T, e1 *core.Error, e2 *core.Error) { + if !assert.Equal(t, e1 == nil, e2 == nil) { + assert.FailNow(t, "One of the values is nil") + } + assert.Equal(t, e1.GetMessage(), e2.GetMessage()) + assert.Equal(t, e1.GetFailedNodeId(), e2.GetFailedNodeId()) +} + +func EqualUnion(t *testing.T, u1 *core.Union, u2 *core.Union) { + if !assert.Equal(t, u1 == nil, u2 == nil) { + assert.FailNow(t, "One of the values is nil") + } + assert.Equal(t, reflect.TypeOf(u1.GetValue()), reflect.TypeOf(u2.GetValue())) + EqualLiterals(t, u1.GetValue(), u2.GetValue()) + EqualLiteralType(t, u1.GetType(), u2.GetType()) +} + func EqualScalar(t *testing.T, p1 *core.Scalar, p2 *core.Scalar) { if !assert.Equal(t, p1 == nil, p2 == nil) { assert.FailNow(t, "One of the values is nil") @@ -38,6 +70,10 @@ func EqualScalar(t *testing.T, p1 *core.Scalar, p2 *core.Scalar) { switch p1.GetValue().(type) { case *core.Scalar_Primitive: EqualPrimitive(t, p1.GetPrimitive(), p2.GetPrimitive()) + case *core.Scalar_Error: + EqualError(t, p1.GetError(), p2.GetError()) + case *core.Scalar_Union: + EqualUnion(t, p1.GetUnion(), p2.GetUnion()) default: assert.FailNow(t, "Not yet implemented for types %v", reflect.TypeOf(p1.GetValue())) }