From 903b071ee0dcd9e93fedd3311fefa6b378240d80 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 21 Oct 2024 16:03:38 -0700 Subject: [PATCH] use .value for enums to handle dynamic case Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 3 ++- tests/flytekit/unit/core/test_unions.py | 36 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 21b3906a17..e358c02a26 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -916,7 +916,8 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]: raise ValueError(f"Enum transformer cannot reverse {literal_type}") def assert_type(self, t: Type[enum.Enum], v: T): - if v not in t: + val = v.value if isinstance(v, enum.Enum) else v + if val not in t: raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}") diff --git a/tests/flytekit/unit/core/test_unions.py b/tests/flytekit/unit/core/test_unions.py index 3bb38e0ad1..e446d260db 100644 --- a/tests/flytekit/unit/core/test_unions.py +++ b/tests/flytekit/unit/core/test_unions.py @@ -1,5 +1,6 @@ import typing from dataclasses import dataclass +from enum import Enum import pytest @@ -36,3 +37,38 @@ class C: with pytest.raises(TypeTransformerFailedError): TypeEngine.to_literal(ctx, 3, guessed, lt) + + +def test_asserting_enum(): + class Color(Enum): + RED = "one" + GREEN = "two" + BLUE = "blue" + + lt = TypeEngine.to_literal_type(Color) + guessed = TypeEngine.guess_python_type(lt) + tf = TypeEngine.get_transformer(guessed) + tf.assert_type(guessed, "one") + tf.assert_type(guessed, guessed("two")) + tf.assert_type(Color, "one") + + guessed2 = TypeEngine.guess_python_type(lt) + tf.assert_type(guessed, guessed2("two")) + + +@pytest.mark.sandbox_test +def test_with_remote(): + from flytekit.remote.remote import FlyteRemote + from typing_extensions import Annotated, get_args + from flytekit.configuration import Config, Image, ImageConfig, SerializationSettings + + r = FlyteRemote( + Config.auto(config_file="/Users/ytong/.flyte/config-sandbox.yaml"), + default_project="flytesnacks", + default_domain="development", + ) + lp = r.fetch_launch_plan(name="yt_dbg.scratchpad.union_enums.wf", version="oppOd5jst-LWExhTLM0F2w") + guessed_union_type = TypeEngine.guess_python_type(lp.interface.inputs["x"].type) + guessed_enum = get_args(guessed_union_type)[0] + val = guessed_enum("one") + r.execute(lp, inputs={"x": val})