From fb111011496a61a0f045e275228048169c73f64c Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 22 May 2024 16:46:51 +0800 Subject: [PATCH] support any and non any workflow Signed-off-by: Future-Outlier --- flytekit/core/type_engine.py | 5 -- flytekit/types/pickle/pickle.py | 23 ------ tests/flytekit/unit/core/test_type_engine.py | 75 +++++++++++++++++++- 3 files changed, 74 insertions(+), 29 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 3a679269c8f..4b1d144c885 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1166,7 +1166,6 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ Converts a Literal value with an expected python type into a python value. """ - # print("Expected Python Type: ", expected_python_type) transformer = cls.get_transformer(expected_python_type) return transformer.to_python_value(ctx, lv, expected_python_type) @@ -1223,13 +1222,9 @@ def literal_map_to_kwargs( kwargs = {} for i, k in enumerate(lm.literals): try: - # print("converting input: ", k, " with value: ", lm.literals[k]) - # print("Type 1: ", python_interface_inputs[k]) kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_interface_inputs[k]) - # print("kwargs[k]:", kwargs[k]) except TypeTransformerFailedError as exc: raise TypeTransformerFailedError(f"Error converting input '{k}' at position {i}:\n {exc}") from exc - return kwargs @classmethod diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index c5f9e8ac6e8..b82980e9573 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -87,30 +87,7 @@ def assert_type(self, t: Type[T], v: T): # Every type can serialize to pickle, so we don't need to check the type here. ... - # def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - # try: - # uri = lv.scalar.blob.uri - # return FlytePickle.from_pickle(uri) - # except Exception as e: - # from datetime import datetime, timedelta - - # if lv.scalar: - # if lv.scalar.primitive: - # if lv.scalar.primitive.integer: - # return TypeEngine.to_python_value(ctx, lv, int) - # elif lv.scalar.primitive.float_value: - # return TypeEngine.to_python_value(ctx, lv, float) - # elif lv.scalar.primitive.string_value: - # return TypeEngine.to_python_value(ctx, lv, str) - # elif lv.scalar.primitive.boolean: - # return TypeEngine.to_python_value(ctx, lv, bool) - # elif lv.scalar.primitive.datetime: - # return TypeEngine.to_python_value(ctx, lv, datetime) - # elif lv.scalar.primitive.duration: - # return TypeEngine.to_python_value(ctx, lv, timedelta) - # raise None def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - print("lv:", lv) primitive = lv.scalar.primitive if primitive: from datetime import datetime, timedelta diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index d4db4f34fe0..9bdae71b6a0 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass, field from datetime import timedelta from enum import Enum, auto -from typing import List, Optional, Type +from typing import Any, List, Optional, Tuple, Type import mock import pyarrow as pa @@ -2017,6 +2017,79 @@ def __init__(self, number: int): TypeEngine.to_literal(ctx, 1, typing.Optional[typing.Any], lt) +def test_non_any_as_any_input_workflow(): + @task + def foo(a: Any) -> int: + if type(a) == int: + return a + 1 + return 0 + + @workflow + def wf_int(a: int) -> int: + return foo(a=a) + + @workflow + def wf_float(a: float) -> int: + return foo(a=a) + + @workflow + def wf_str(a: str) -> int: + return foo(a=a) + + @workflow + def wf_bool(a: bool) -> int: + return foo(a=a) + + @workflow + def wf_datetime(a: datetime.datetime) -> int: + return foo(a=a) + + @workflow + def wf_duration(a: datetime.timedelta) -> int: + return foo(a=a) + + assert wf_int(a=1) == 2 + assert wf_float(a=1.0) == 0 + assert wf_str(a="1") == 0 + assert wf_bool(a=True) == 0 + assert wf_datetime(a=datetime.datetime.now()) == 0 + assert wf_duration(a=datetime.timedelta(seconds=1)) == 0 + + +def test_non_any_as_any_output_workflow(): + now = datetime.datetime.now(datetime.timezone.utc) + + @task + def foo_int() -> int: + return 1 + + @task + def foo_float() -> float: + return 1.0 + + @task + def foo_str() -> str: + return "1" + + @task + def foo_bool() -> bool: + return True + + @task + def foo_datetime() -> datetime.datetime: + return now + + @task + def foo_duration() -> datetime.timedelta: + return datetime.timedelta(seconds=1) + + @workflow + def wf() -> Tuple[Any, Any, Any, Any, Any, Any]: + return foo_int(), foo_float(), foo_str(), foo_bool(), foo_datetime(), foo_duration() + + assert wf() == (1, 1.0, "1", True, now, datetime.timedelta(seconds=1)) + + def test_enum_in_dataclass(): @dataclass class Datum(DataClassJsonMixin):