From 9247ff0474e9bee94a1ef47a0ab228044816e536 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 9 May 2023 14:55:33 -0700 Subject: [PATCH] Support union and none type in flyteidl (#401) * add support for Union Scalar Signed-off-by: Yubo Wang * support union type and literals Signed-off-by: Yubo Wang * change union type extraction Signed-off-by: Yubo Wang --------- Signed-off-by: Yubo Wang Co-authored-by: Yubo Wang Co-authored-by: Kevin Su --- clients/go/coreutils/extract_literal.go | 9 ++++ clients/go/coreutils/extract_literal_test.go | 40 +++++++++++++++- clients/go/coreutils/literals.go | 48 ++++++++++++++++++++ clients/go/coreutils/literals_test.go | 47 +++++++++++++++++++ go.mod | 4 +- 5 files changed, 145 insertions(+), 3 deletions(-) diff --git a/clients/go/coreutils/extract_literal.go b/clients/go/coreutils/extract_literal.go index afc7fd2a0..e3bf6b25c 100644 --- a/clients/go/coreutils/extract_literal.go +++ b/clients/go/coreutils/extract_literal.go @@ -62,6 +62,15 @@ func ExtractFromLiteral(literal *core.Literal) (interface{}, error) { return scalarValue.Generic, nil case *core.Scalar_StructuredDataset: return scalarValue.StructuredDataset.Uri, nil + case *core.Scalar_Union: + // extract the value of the union but not the actual union object + extractedVal, err := ExtractFromLiteral(scalarValue.Union.Value) + if err != nil { + return nil, err + } + return extractedVal, nil + case *core.Scalar_NoneType: + return nil, nil default: return nil, fmt.Errorf("unsupported literal scalar type %T", scalarValue) } diff --git a/clients/go/coreutils/extract_literal_test.go b/clients/go/coreutils/extract_literal_test.go index 32d392322..39855e8e5 100644 --- a/clients/go/coreutils/extract_literal_test.go +++ b/clients/go/coreutils/extract_literal_test.go @@ -121,7 +121,7 @@ func TestFetchLiteral(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, p.GetScalar()) _, err = ExtractFromLiteral(p) - assert.NotNil(t, err) + assert.Nil(t, err) }) t.Run("Generic", func(t *testing.T) { @@ -199,4 +199,42 @@ func TestFetchLiteral(t *testing.T) { assert.NoError(t, err) assert.Equal(t, literalVal, extractedLiteralVal) }) + + t.Run("Union", func(t *testing.T) { + literalVal := int64(1) + var literalType = &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + {Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + {Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}}, + }, + }, + }, + } + lit, err := MakeLiteralForType(literalType, literalVal) + assert.NoError(t, err) + extractedLiteralVal, err := ExtractFromLiteral(lit) + assert.NoError(t, err) + assert.Equal(t, literalVal, extractedLiteralVal) + }) + + t.Run("Union with None", func(t *testing.T) { + var literalType = &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + {Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + {Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}}, + }, + }, + }, + } + lit, err := MakeLiteralForType(literalType, nil) + + assert.NoError(t, err) + extractedLiteralVal, err := ExtractFromLiteral(lit) + assert.NoError(t, err) + assert.Nil(t, extractedLiteralVal) + }) } diff --git a/clients/go/coreutils/literals.go b/clients/go/coreutils/literals.go index 4ad84f181..fb2105654 100644 --- a/clients/go/coreutils/literals.go +++ b/clients/go/coreutils/literals.go @@ -299,6 +299,28 @@ func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) { return MakeLiteralForType(typ, nil) case *core.LiteralType_Schema: return MakeLiteralForType(typ, nil) + case *core.LiteralType_UnionType: + if len(t.UnionType.Variants) == 0 { + return nil, errors.Errorf("Union type must have at least one variant") + } + // For union types, we just return the default for the first variant + val, err := MakeDefaultLiteralForType(t.UnionType.Variants[0]) + if err != nil { + return nil, errors.Errorf("Failed to create default literal for first union type variant [%v]", t.UnionType.Variants[0]) + } + res := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Type: t.UnionType.Variants[0], + Value: val, + }, + }, + }, + }, + } + return res, nil } return nil, fmt.Errorf("failed to convert to a known Literal. Input Type [%v] not supported", typ.String()) @@ -588,6 +610,32 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro } return MakePrimitiveLiteral(newV) + case *core.LiteralType_UnionType: + // Try different types in the variants, return the first one matched + found := false + for _, subType := range newT.UnionType.Variants { + lv, err := MakeLiteralForType(subType, v) + if err == nil { + l = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Value: lv, + Type: subType, + }, + }, + }, + }, + } + found = true + break + } + } + if !found { + return nil, fmt.Errorf("incorrect union value [%s], supported values %+v", v, newT.UnionType.Variants) + } + default: return nil, fmt.Errorf("unsupported type %s", t.String()) } diff --git a/clients/go/coreutils/literals_test.go b/clients/go/coreutils/literals_test.go index 4b16478ad..61910ac52 100644 --- a/clients/go/coreutils/literals_test.go +++ b/clients/go/coreutils/literals_test.go @@ -261,6 +261,23 @@ func TestMakeDefaultLiteralForType(t *testing.T) { Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "x"}}}}}} assert.Equal(t, expected, l) }) + + t.Run("union", func(t *testing.T) { + l, err := MakeDefaultLiteralForType( + &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + {Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + {Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}}, + }, + }, + }, + }, + ) + assert.NoError(t, err) + assert.Equal(t, "*core.Union", reflect.TypeOf(l.GetScalar().GetUnion()).String()) + }) } func TestMustMakeDefaultLiteralForType(t *testing.T) { @@ -715,4 +732,34 @@ func TestMakeLiteralForType(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expectedVal, actualVal) }) + + t.Run("Union", func(t *testing.T) { + var literalType = &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + {Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + {Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}}, + }, + }, + }, + } + expectedLV := &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}}, + Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_FloatValue{FloatValue: 0.1}}}}}}, + }, + }, + }}} + lv, err := MakeLiteralForType(literalType, float64(0.1)) + assert.NoError(t, err) + assert.Equal(t, expectedLV, lv) + expectedVal, err := ExtractFromLiteral(expectedLV) + assert.NoError(t, err) + actualVal, err := ExtractFromLiteral(lv) + assert.NoError(t, err) + assert.Equal(t, expectedVal, actualVal) + }) } diff --git a/go.mod b/go.mod index 33f85552c..3aacab796 100644 --- a/go.mod +++ b/go.mod @@ -84,8 +84,8 @@ require ( k8s.io/klog/v2 v2.5.0 // indirect ) -// These 2 versions were wrongly published. +// These 2 versions were wrongly published. retract ( - v1.4.0 v1.4.2 + v1.4.0 )