Skip to content

Commit

Permalink
use .value for enums to handle dynamic case
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Oct 21, 2024
1 parent 99d6d61 commit 903b071
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
3 changes: 2 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,8 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
raise ValueError(f"Enum transformer cannot reverse {literal_type}")

def assert_type(self, t: Type[enum.Enum], v: T):
if v not in t:
val = v.value if isinstance(v, enum.Enum) else v
if val not in t:
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")


Expand Down
36 changes: 36 additions & 0 deletions tests/flytekit/unit/core/test_unions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing
from dataclasses import dataclass
from enum import Enum

import pytest

Expand Down Expand Up @@ -36,3 +37,38 @@ class C:

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, 3, guessed, lt)


def test_asserting_enum():
class Color(Enum):
RED = "one"
GREEN = "two"
BLUE = "blue"

lt = TypeEngine.to_literal_type(Color)
guessed = TypeEngine.guess_python_type(lt)
tf = TypeEngine.get_transformer(guessed)
tf.assert_type(guessed, "one")
tf.assert_type(guessed, guessed("two"))
tf.assert_type(Color, "one")

guessed2 = TypeEngine.guess_python_type(lt)
tf.assert_type(guessed, guessed2("two"))


@pytest.mark.sandbox_test
def test_with_remote():
from flytekit.remote.remote import FlyteRemote
from typing_extensions import Annotated, get_args
from flytekit.configuration import Config, Image, ImageConfig, SerializationSettings

r = FlyteRemote(
Config.auto(config_file="/Users/ytong/.flyte/config-sandbox.yaml"),
default_project="flytesnacks",
default_domain="development",
)
lp = r.fetch_launch_plan(name="yt_dbg.scratchpad.union_enums.wf", version="oppOd5jst-LWExhTLM0F2w")
guessed_union_type = TypeEngine.guess_python_type(lp.interface.inputs["x"].type)
guessed_enum = get_args(guessed_union_type)[0]
val = guessed_enum("one")
r.execute(lp, inputs={"x": val})

0 comments on commit 903b071

Please sign in to comment.