diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 6721e9afff..8230cf22c8 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -356,7 +356,6 @@ jobs: - flytekit-pandera - flytekit-papermill - flytekit-polars - - flytekit-pydantic - flytekit-ray - flytekit-snowflake - flytekit-spark diff --git a/dev-requirements.in b/dev-requirements.in index 27c17ac6d0..20aba11e9d 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -30,6 +30,7 @@ torch<=1.12.1; python_version<'3.11' # pytorch 2 supports python 3.11 # pytorch 2 does not support 3.12 yet: https://github.com/pytorch/pytorch/issues/110436 torch; python_version<'3.12' +pydantic # TODO: Currently, the python-magic library causes build errors on Windows due to its dependency on DLLs for libmagic. # We have temporarily disabled this feature on Windows and are using python-magic for Mac OS and Linux instead. diff --git a/dev-requirements.txt b/dev-requirements.txt index 5fd363804e..002f5421c4 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -22,6 +22,8 @@ aiosignal==1.3.1 # via aiohttp annotated-types==0.7.0 # via pydantic +appnope==0.1.4 + # via ipykernel asn1crypto==1.5.1 # via snowflake-connector-python asttokens==2.4.1 @@ -31,7 +33,6 @@ attrs==23.2.0 # aiohttp # hypothesis # jsonlines - # visions autoflake==2.3.1 # via -r dev-requirements.in azure-core==1.30.1 @@ -73,8 +74,8 @@ cloudpickle==3.0.0 # via flytekit codespell==2.3.0 # via -r dev-requirements.in -contourpy==1.3.0 - # via matplotlib +comm==0.2.2 + # via ipykernel coverage[toml]==7.5.3 # via # -r dev-requirements.in @@ -89,12 +90,10 @@ cryptography==43.0.1 # pyjwt # pyopenssl # snowflake-connector-python -cycler==0.12.1 - # via matplotlib -dacite==1.8.1 - # via ydata-profiling dataclasses-json==0.5.9 # via flytekit +debugpy==1.8.7 + # via ipykernel decorator==5.1.1 # via # gcsfs @@ -119,8 +118,6 @@ flyteidl @ git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteid # via # -r dev-requirements.in # flytekit -fonttools==4.54.1 - # via matplotlib frozenlist==1.4.1 # via # aiohttp @@ -185,8 +182,6 @@ grpcio-status==1.62.2 # via # flytekit # google-api-core -htmlmin==0.1.12 - # via ydata-profiling hypothesis==6.103.0 # via -r dev-requirements.in icdiff==2.0.7 @@ -198,16 +193,16 @@ idna==3.7 # requests # snowflake-connector-python # yarl -imagehash==4.3.1 - # via - # visions - # ydata-profiling importlib-metadata==7.1.0 # via flytekit iniconfig==2.0.0 # via pytest -ipython==8.25.0 +ipykernel==6.29.5 # via -r dev-requirements.in +ipython==8.25.0 + # via + # -r dev-requirements.in + # ipykernel isodate==0.6.1 # via azure-storage-blob jaraco-classes==3.4.0 @@ -222,38 +217,35 @@ jaraco-functools==4.0.1 # via keyring jedi==0.19.1 # via ipython -jinja2==3.1.4 - # via ydata-profiling jmespath==1.0.1 # via botocore joblib==1.4.2 # via # -r dev-requirements.in # flytekit - # phik # scikit-learn jsonlines==4.0.0 # via flytekit jsonpickle==3.0.4 # via flytekit +jupyter-client==8.6.3 + # via + # -r dev-requirements.in + # ipykernel +jupyter-core==5.7.2 + # via + # ipykernel + # jupyter-client keyring==25.2.1 # via flytekit keyrings-alt==5.0.1 # via -r dev-requirements.in -kiwisolver==1.4.7 - # via matplotlib kubernetes==29.0.0 # via -r dev-requirements.in -llvmlite==0.43.0 - # via numba -markdown==3.7 - # via -r dev-requirements.in markdown-it-py==3.0.0 # via # flytekit # rich -markupsafe==2.1.5 - # via jinja2 marshmallow==3.21.2 # via # dataclasses-json @@ -267,14 +259,10 @@ 
marshmallow-jsonschema==0.13.0 # via flytekit mashumaro==3.13 # via flytekit -matplotlib==3.9.2 - # via - # phik - # seaborn - # wordcloud - # ydata-profiling matplotlib-inline==0.1.7 - # via ipython + # via + # ipykernel + # ipython mdurl==0.1.2 # via markdown-it-py mock==5.1.0 @@ -290,45 +278,29 @@ msal==1.28.0 # msal-extensions msal-extensions==1.1.0 # via azure-identity +msgpack==1.1.0 + # via flytekit multidict==6.0.5 # via # aiohttp # yarl -multimethod==1.12 - # via - # visions - # ydata-profiling mypy==1.6.1 # via -r dev-requirements.in mypy-extensions==1.0.0 # via # mypy # typing-inspect -networkx==3.3 - # via visions +nest-asyncio==1.6.0 + # via ipykernel nodeenv==1.9.0 # via pre-commit -numba==0.60.0 - # via ydata-profiling numpy==1.26.4 # via # -r dev-requirements.in - # contourpy - # imagehash - # matplotlib - # numba # pandas - # patsy - # phik # pyarrow - # pywavelets # scikit-learn # scipy - # seaborn - # statsmodels - # visions - # wordcloud - # ydata-profiling oauthlib==3.2.2 # via # kubernetes @@ -339,43 +311,25 @@ packaging==24.0 # via # docker # google-cloud-bigquery + # ipykernel # marshmallow - # matplotlib # msal-extensions - # plotly # pytest # setuptools-scm # snowflake-connector-python - # statsmodels pandas==2.2.2 - # via - # -r dev-requirements.in - # phik - # seaborn - # statsmodels - # visions - # ydata-profiling + # via -r dev-requirements.in parso==0.8.4 # via jedi -patsy==0.5.6 - # via statsmodels pexpect==4.9.0 # via ipython -phik==0.12.4 - # via ydata-profiling pillow==10.3.0 - # via - # -r dev-requirements.in - # imagehash - # matplotlib - # visions - # wordcloud + # via -r dev-requirements.in platformdirs==4.2.2 # via + # jupyter-core # snowflake-connector-python # virtualenv -plotly==5.24.1 - # via -r dev-requirements.in pluggy==1.5.0 # via pytest portalocker==2.8.2 @@ -405,6 +359,8 @@ protobuf==4.25.3 # protoc-gen-openapiv2 protoc-gen-openapiv2==0.0.1 # via flyteidl +psutil==6.1.0 + # via ipykernel ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 @@ -420,14 +376,13 @@ pyasn1-modules==0.4.0 pycparser==2.22 # via cffi pydantic==2.9.2 - # via ydata-profiling + # via -r dev-requirements.in pydantic-core==2.23.4 # via pydantic pyflakes==3.2.0 # via autoflake pygments==2.18.0 # via - # -r dev-requirements.in # flytekit # ipython # rich @@ -437,8 +392,6 @@ pyjwt[crypto]==2.8.0 # snowflake-connector-python pyopenssl==24.2.1 # via snowflake-connector-python -pyparsing==3.1.4 - # via matplotlib pytest==8.2.1 # via # -r dev-requirements.in @@ -465,8 +418,8 @@ python-dateutil==2.9.0.post0 # botocore # croniter # google-cloud-bigquery + # jupyter-client # kubernetes - # matplotlib # pandas python-json-logger==2.0.7 # via flytekit @@ -479,14 +432,15 @@ pytz==2024.1 # croniter # pandas # snowflake-connector-python -pywavelets==1.7.0 - # via imagehash pyyaml==6.0.1 # via # flytekit # kubernetes # pre-commit - # ydata-profiling +pyzmq==26.2.0 + # via + # ipykernel + # jupyter-client requests==2.32.3 # via # azure-core @@ -501,7 +455,6 @@ requests==2.32.3 # msal # requests-oauthlib # snowflake-connector-python - # ydata-profiling requests-oauthlib==2.0.0 # via # google-auth-oauthlib @@ -519,14 +472,7 @@ s3fs==2024.5.0 scikit-learn==1.5.0 # via -r dev-requirements.in scipy==1.13.1 - # via - # imagehash - # phik - # scikit-learn - # statsmodels - # ydata-profiling -seaborn==0.13.2 - # via ydata-profiling + # via scikit-learn setuptools-scm==8.1.0 # via -r dev-requirements.in six==1.16.0 @@ -535,7 +481,6 @@ six==1.16.0 # azure-core # isodate # kubernetes - # patsy # 
python-dateutil snowflake-connector-python==3.12.1 # via -r dev-requirements.in @@ -547,22 +492,22 @@ stack-data==0.6.3 # via ipython statsd==3.3.0 # via flytekit -statsmodels==0.14.3 - # via ydata-profiling -tenacity==9.0.0 - # via plotly threadpoolctl==3.5.0 # via scikit-learn tomlkit==0.13.2 # via snowflake-connector-python -tqdm==4.66.5 - # via ydata-profiling +tornado==6.4.1 + # via + # ipykernel + # jupyter-client traitlets==5.14.3 # via + # comm + # ipykernel # ipython + # jupyter-client + # jupyter-core # matplotlib-inline -typeguard==4.3.0 - # via ydata-profiling types-croniter==2.0.0.20240423 # via -r dev-requirements.in types-decorator==5.1.8.20240310 @@ -584,7 +529,6 @@ typing-extensions==4.12.0 # pydantic-core # rich-click # snowflake-connector-python - # typeguard # typing-inspect typing-inspect==0.9.0 # via dataclasses-json @@ -600,22 +544,16 @@ urllib3==2.2.1 # types-requests virtualenv==20.26.2 # via pre-commit -visions[type-image-path]==0.7.6 - # via ydata-profiling wcwidth==0.2.13 # via prompt-toolkit websocket-client==1.8.0 # via # docker # kubernetes -wordcloud==1.9.3 - # via ydata-profiling wrapt==1.16.0 # via aiobotocore yarl==1.9.4 # via aiohttp -ydata-profiling==4.10.0 - # via -r dev-requirements.in zipp==3.19.1 # via importlib-metadata diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 968e3153eb..a6eb70004b 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -216,6 +216,7 @@ else: from importlib.metadata import entry_points + from flytekit._version import __version__ from flytekit.configuration import Config from flytekit.core.array_node_map_task import map_task diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 30bc43a106..23ab244d90 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -60,7 +60,7 @@ # This is relevant for cases like Dict[int, str]. 
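# For example, msgpack.unpackb(msgpack.packb({1: "a"})) raises ValueError under the default strict_map_key=True, while msgpack.unpackb(msgpack.packb({1: "a"}), strict_map_key=False) returns {1: "a"}.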
# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed. def _default_msgpack_decoder(data: bytes) -> Any: - return msgpack.unpackb(data, raw=False, strict_map_key=False) + return msgpack.unpackb(data, strict_map_key=False) class BatchSize: @@ -215,16 +215,41 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: ) def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + """ + This function primarily handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and attribute access. + + For untyped dict, dataclass, and pydantic basemodel: + Life Cycle (Untyped Dict as example): + python val -> msgpack bytes -> binary literal scalar -> msgpack bytes -> python val + (to_literal) (from_binary_idl) + + For attribute access: + Life Cycle: + python val -> msgpack bytes -> binary literal scalar -> resolved golang value -> binary literal scalar -> msgpack bytes -> python val + (to_literal) (propeller attribute access) (from_binary_idl) + """ if binary_idl_object.tag == MESSAGEPACK: try: decoder = self._msgpack_decoder[expected_python_type] except KeyError: decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder) self._msgpack_decoder[expected_python_type] = decoder - return decoder.decode(binary_idl_object.value) + python_val = decoder.decode(binary_idl_object.value) + + return python_val else: raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`") + def from_generic_idl(self, generic: Struct, expected_python_type: Type[T]) -> Optional[T]: + """ + TODO: Support all Flyte Types. + This is for dataclass attribute access from input created from the Flyte Console. + + Note: + - This can be removed in the future when the Flyte Console supports generating Binary IDL Scalars as input. + """ + raise NotImplementedError(f"Conversion from generic idl to python type {expected_python_type} not implemented") + def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div @@ -321,6 +346,51 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp ) return self._to_literal_transformer(python_val) + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + if binary_idl_object.tag == MESSAGEPACK: + if expected_python_type in [datetime.date, datetime.datetime, datetime.timedelta]: + """ + MessagePack doesn't support datetime, date, and timedelta. + However, mashumaro's MessagePackEncoder and MessagePackDecoder can convert them to str and vice versa. + That's why we need to use mashumaro's MessagePackDecoder here. + """ + try: + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder) + self._msgpack_decoder[expected_python_type] = decoder + python_val = decoder.decode(binary_idl_object.value) + else: + python_val = msgpack.loads(binary_idl_object.value) + """ + In the case below, when using the Union Transformer with the Simple Transformer, `a` + could be converted to int, bool, str, or float if we used MessagePackDecoder[expected_python_type].
+ + Life Cycle: + 1 -> msgpack bytes -> (1, true, "1", 1.0) + + Example Code: + @dataclass + class DC: + a: Union[int, bool, str, float] + b: Union[int, bool, str, float] + + @task(container_image=custom_image) + def add(a: Union[int, bool, str, float], b: Union[int, bool, str, float]) -> Union[int, bool, str, float]: + return a + b + + @workflow + def wf(dc: DC) -> Union[int, bool, str, float]: + return add(dc.a, dc.b) + + wf(DC(1, 1)) + """ + assert type(python_val) == expected_python_type + + return python_val + else: + raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`") + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: expected_python_type = get_underlying_type(expected_python_type) @@ -1124,6 +1194,8 @@ def lazy_import_transformers(cls): from flytekit.extras import pytorch # noqa: F401 if is_imported("sklearn"): from flytekit.extras import sklearn # noqa: F401 + if is_imported("pydantic"): + from flytekit.extras import pydantic # noqa: F401 if is_imported("pandas"): try: from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter # noqa: F401 @@ -1776,9 +1848,6 @@ async def async_to_python_value( ) -> Optional[typing.Any]: expected_python_type = get_underlying_type(expected_python_type) - if lv.scalar is not None and lv.scalar.binary is not None: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) - union_tag = None union_type = None if lv.scalar is not None and lv.scalar.union is not None: @@ -1807,9 +1876,15 @@ async def async_to_python_value( assert lv.scalar.union is not None # type checker if isinstance(trans, AsyncTypeTransformer): - res = await trans.async_to_python_value(ctx, lv.scalar.union.value, v) + if lv.scalar.binary: + res = await trans.async_to_python_value(ctx, lv, v) + else: + res = await trans.async_to_python_value(ctx, lv.scalar.union.value, v) else: - res = trans.to_python_value(ctx, lv.scalar.union.value, v) + if lv.scalar.binary: + res = trans.to_python_value(ctx, lv, v) + else: + res = trans.to_python_value(ctx, lv.scalar.union.value, v) if isinstance(res, asyncio.Future): res = await res @@ -2010,7 +2085,42 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p return await FlytePickle.from_pickle(uri) try: - return json.loads(_json_format.MessageToJson(lv.scalar.generic)) + """ + Handles the case where Flyte Console provides input as a protobuf struct. + When resolving an attribute like 'dc.dict_int_ff', FlytePropeller retrieves a dictionary. + Mashumaro's decoder can convert this dictionary to the expected Python object if the correct type is provided. + Since Flyte Types handle their own deserialization, the dictionary is automatically converted to the expected Python object. 
+ + Example Code: + @dataclass + class DC: + dict_int_ff: Dict[int, FlyteFile] + + @workflow + def wf(dc: DC): + t_ff(dc.dict_int_ff) + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> dictionary -> expected Python object + (console user input) (console output) (propeller) (flytekit dict transformer) (mashumaro decoder) + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + - Title: Binary IDL With MessagePack + - Link: https://github.com/flyteorg/flytekit/pull/2760 + """ + + dict_obj = json.loads(_json_format.MessageToJson(lv.scalar.generic)) + msgpack_bytes = msgpack.dumps(dict_obj) + + try: + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder) + self._msgpack_decoder[expected_python_type] = decoder + + return decoder.decode(msgpack_bytes) except TypeError: raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") @@ -2194,6 +2304,34 @@ def _check_and_covert_float(lv: Literal) -> float: raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float") +def _handle_flyte_console_float_input_to_int(lv: Literal) -> int: + """ + Flyte Console is written in JavaScript, and JavaScript has only one number type: float. + We therefore have to convert the float back to an int, as in the following example. + + Example Code: + @dataclass + class DC: + a: int + + @workflow + def wf(dc: DC): + t_int(a=dc.a) + + Life Cycle: + json str -> protobuf struct -> resolved float -> float -> int + (console user input) (console output) (propeller) (flytekit simple transformer) (_handle_flyte_console_float_input_to_int) + """ + if lv.scalar.primitive.integer is not None: + return lv.scalar.primitive.integer + + if lv.scalar.primitive.float_value is not None: + logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, which might lose precision.") + return int(lv.scalar.primitive.float_value) + + raise TypeTransformerFailedError(f"Cannot convert literal {lv} to int") + + def _check_and_convert_void(lv: Literal) -> None: if lv.scalar.none_type is None: raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None") @@ -2205,7 +2343,7 @@ def _check_and_convert_void(lv: Literal) -> None: int, _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER), lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))), - lambda x: x.scalar.primitive.integer, + _handle_flyte_console_float_input_to_int, ) FloatTransformer = SimpleTransformer( diff --git a/flytekit/extras/pydantic/__init__.py b/flytekit/extras/pydantic/__init__.py new file mode 100644 index 0000000000..3f7744fe2f --- /dev/null +++ b/flytekit/extras/pydantic/__init__.py @@ -0,0 +1,11 @@ +from flytekit.loggers import logger + +try: + # isolate the exception to the pydantic import + # model_validator and model_serializer are only available in pydantic > 2 + from pydantic import model_serializer, model_validator + + from .
import transformer +except (ImportError, OSError) as e: + logger.warning(f"Failed to import pydantic: `{e}`") + logger.warning("Flytekit only supports pydantic version > 2.") diff --git a/flytekit/extras/pydantic/decorator.py b/flytekit/extras/pydantic/decorator.py new file mode 100644 index 0000000000..ec61710dcf --- /dev/null +++ b/flytekit/extras/pydantic/decorator.py @@ -0,0 +1,57 @@ +import logging +from typing import Any, Callable, TypeVar, Union + +logger = logging.getLogger(__name__) + +try: + # isolate the exception to the pydantic import + # model_validator and model_serializer are only available in pydantic > 2 + from pydantic import model_serializer, model_validator + +except ImportError: + logger.warning( + "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use FlyteTypes in pydantic BaseModel." + ) + + FuncType = TypeVar("FuncType", bound=Callable[..., Any]) + + from typing_extensions import Literal as typing_literal + + def model_serializer( + __f: Union[Callable[..., Any], None] = None, + *, + mode: typing_literal["plain", "wrap"] = "plain", + when_used: typing_literal["always", "unless-none", "json", "json-unless-none"] = "always", + return_type: Any = None, + ) -> Callable[[Any], Any]: + """Placeholder decorator for Pydantic model_serializer.""" + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args, **kwargs): + raise Exception( + "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use this feature." + ) + + return wrapper + + # If no function (__f) is provided, return the decorator + if __f is None: + return decorator + # If __f is provided, directly decorate the function + return decorator(__f) + + def model_validator( + *, + mode: typing_literal["wrap", "before", "after"], + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Placeholder decorator for Pydantic model_validator.""" + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args, **kwargs): + raise Exception( + "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use this feature." + ) + + return wrapper + + return decorator diff --git a/flytekit/extras/pydantic/transformer.py b/flytekit/extras/pydantic/transformer.py new file mode 100644 index 0000000000..4abefcc298 --- /dev/null +++ b/flytekit/extras/pydantic/transformer.py @@ -0,0 +1,81 @@ +import json +from typing import Type + +import msgpack +from google.protobuf import json_format as _json_format +from pydantic import BaseModel + +from flytekit import FlyteContext +from flytekit.core.constants import MESSAGEPACK +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.loggers import logger +from flytekit.models import types +from flytekit.models.literals import Binary, Literal, Scalar +from flytekit.models.types import LiteralType, TypeStructure + + +class PydanticTransformer(TypeTransformer[BaseModel]): + def __init__(self): + super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False) + + def get_literal_type(self, t: Type[BaseModel]) -> LiteralType: + schema = t.model_json_schema() + literal_type = {} + fields = t.__annotations__.items() + + for name, python_type in fields: + try: + literal_type[name] = TypeEngine.to_literal_type(python_type) + except Exception as e: + logger.warning( + "Field {} of type {} cannot be converted to a literal type.
Error: {}".format(name, python_type, e) + ) + + ts = TypeStructure(tag="", dataclass_type=literal_type) + + return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts) + + def to_literal( + self, + ctx: FlyteContext, + python_val: BaseModel, + python_type: Type[BaseModel], + expected: types.LiteralType, + ) -> Literal: + """ + For pydantic basemodel, we have to go through json first. + This is for handling enum in basemodel. + More details: https://github.com/flyteorg/flytekit/pull/2792 + """ + json_str = python_val.model_dump_json() + dict_obj = json.loads(json_str) + msgpack_bytes = msgpack.dumps(dict_obj) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK))) + + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel: + if binary_idl_object.tag == MESSAGEPACK: + dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) + json_str = json.dumps(dict_obj) + python_val = expected_python_type.model_validate_json( + json_data=json_str, strict=False, context={"deserialize": True} + ) + return python_val + else: + raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel: + """ + There will have 2 kinds of literal values: + 1. protobuf Struct (From Flyte Console) + 2. binary scalar (Others) + Hence we have to handle 2 kinds of cases. + """ + if lv and lv.scalar and lv.scalar.binary is not None: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + + json_str = _json_format.MessageToJson(lv.scalar.generic) + python_val = expected_python_type.model_validate_json(json_str, strict=False, context={"deserialize": True}) + return python_val + + +TypeEngine.register(PydanticTransformer()) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 1d038b0319..8cc2cc21cf 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -2,6 +2,7 @@ import datetime import enum import importlib +import importlib.util import json import logging import os @@ -39,7 +40,10 @@ def is_pydantic_basemodel(python_type: typing.Type) -> bool: return False else: try: - from pydantic.v1 import BaseModel + from pydantic import BaseModel as BaseModelV2 + from pydantic.v1 import BaseModel as BaseModelV1 + + return issubclass(python_type, BaseModelV1) or issubclass(python_type, BaseModelV2) except ImportError: from pydantic import BaseModel @@ -374,7 +378,24 @@ def has_nested_dataclass(t: typing.Type) -> bool: return parsed_value if is_pydantic_basemodel(self._python_type): - return self._python_type.parse_raw(json.dumps(parsed_value)) # type: ignore + """ + This function supports backward compatibility for the Pydantic v1 plugin. + If the class is a Pydantic BaseModel, it attempts to parse JSON input using + the appropriate version of Pydantic (v1 or v2). + """ + try: + if importlib.util.find_spec("pydantic.v1") is not None: + from pydantic import BaseModel as BaseModelV2 + + if issubclass(self._python_type, BaseModelV2): + return self._python_type.model_validate_json( + json.dumps(parsed_value), strict=False, context={"deserialize": True} + ) + except ImportError: + pass + + # The behavior of the Pydantic v1 plugin. 
+ return self._python_type.parse_raw(json.dumps(parsed_value)) # Ensure that the python type has `from_json` function if not hasattr(self._python_type, "from_json"): diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index 87b494d0ae..83bb0c8fa8 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -16,7 +16,7 @@ import typing -from .types import FlyteDirectory +from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer # The following section provides some predefined aliases for commonly used FlyteDirectory formats. diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 7e22879126..d7126b7367 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -1,28 +1,32 @@ from __future__ import annotations +import json import os import pathlib import random import typing from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Generator, Tuple +from typing import Any, Dict, Generator, Tuple from uuid import UUID import fsspec import msgpack from dataclasses_json import DataClassJsonMixin, config from fsspec.utils import get_protocol +from google.protobuf import json_format as _json_format +from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.types import SerializableType -from flytekit import BlobType from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size from flytekit.exceptions.user import FlyteAssertion +from flytekit.extras.pydantic.decorator import model_serializer, model_validator from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types +from flytekit.models.core.types import BlobType from flytekit.models.literals import Binary, Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.types.file import FileExt, FlyteFile @@ -131,12 +135,21 @@ def _serialize(self) -> typing.Dict[str, str]: @classmethod def _deserialize(cls, value) -> "FlyteDirectory": - path = value.get("path", None) + return FlyteDirToMultipartBlobTransformer().dict_to_flyte_directory(dict_obj=value, expected_python_type=cls) - if path is None: - raise ValueError("FlyteDirectory's path should not be None") + @model_serializer + def serialize_flyte_dir(self) -> Dict[str, str]: + lv = FlyteDirToMultipartBlobTransformer().to_literal( + FlyteContextManager.current_context(), self, type(self), None + ) + return {"path": lv.scalar.blob.uri} - return FlyteDirToMultipartBlobTransformer().to_python_value( + @model_validator(mode="after") + def deserialize_flyte_dir(self, info) -> FlyteDirectory: + if info.context is None or info.context.get("deserialize") is not True: + return self + + pv = FlyteDirToMultipartBlobTransformer().to_python_value( FlyteContextManager.current_context(), Literal( scalar=Scalar( @@ -146,12 +159,13 @@ def _deserialize(cls, value) -> "FlyteDirectory": format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART ) ), - uri=path, + uri=self.path, ) ) ), - cls, + type(self), ) + return pv def __init__( self, @@ -532,42 +546,106 @@ async def async_to_literal( else: return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path))) + def dict_to_flyte_directory( + self, dict_obj: typing.Dict[str, str], 
expected_python_type: typing.Type[FlyteDirectory] + ) -> FlyteDirectory: + path = dict_obj.get("path", None) + + if path is None: + raise ValueError("FlyteDirectory's path should not be None") + + return self.to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=_core_types.BlobType( + format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART + ) + ), + uri=path, + ) + ) + ), + expected_python_type, + ) + def from_binary_idl( self, binary_idl_object: Binary, expected_python_type: typing.Type[FlyteDirectory] ) -> FlyteDirectory: + """ + If the input is from flytekit, the Life Cycle will be as follows: + + Life Cycle: + binary IDL -> resolved binary -> bytes -> expected Python object + (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized + serialization) deserialization) + + Example Code: + @dataclass + class DC: + fd: FlyteDirectory + + @workflow + def wf(dc: DC): + t_fd(dc.fd) + + Note: + - The deserialization is the same as putting a flyte directory in a dataclass, which is deserialized via mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ if binary_idl_object.tag == MESSAGEPACK: python_val = msgpack.loads(binary_idl_object.value) - path = python_val.get("path", None) - - if path is None: - raise ValueError("FlyteDirectory's path should not be None") - - return FlyteDirToMultipartBlobTransformer().to_python_value( - FlyteContextManager.current_context(), - Literal( - scalar=Scalar( - blob=Blob( - metadata=BlobMetadata( - type=_core_types.BlobType( - format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART - ) - ), - uri=path, - ) - ) - ), - expected_python_type, - ) + return self.dict_to_flyte_directory(python_val, expected_python_type) else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + def from_generic_idl(self, generic: Struct, expected_python_type: typing.Type[FlyteDirectory]) -> FlyteDirectory: + """ + If the input is from Flyte Console, the Life Cycle will be as follows: + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> expected Python object + (console user input) (console output) (propeller) (flytekit customized deserialization) + + Example Code: + @dataclass + class DC: + fd: FlyteDirectory + + @workflow + def wf(dc: DC): + t_fd(dc.fd) + + Note: + - The deserialization is the same as putting a flyte directory in a dataclass, which is deserialized via mashumaro's API.
+ + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + json_str = _json_format.MessageToJson(generic) + python_val = json.loads(json_str) + return self.dict_to_flyte_directory(python_val, expected_python_type) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory] ) -> FlyteDirectory: - if lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) - - uri = lv.scalar.blob.uri + # Handle dataclass attribute access + if lv.scalar: + if lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar.generic: + return self.from_generic_idl(lv.scalar.generic, expected_python_type) + + try: + uri = lv.scalar.blob.uri + except AttributeError: + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") if lv.scalar.blob.metadata.type.dimensionality != BlobType.BlobDimensionality.MULTIPART: raise TypeTransformerFailedError(f"{lv.scalar.blob.uri} is not a directory.") diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py index 838516f33d..8d69247e10 100644 --- a/flytekit/types/file/__init__.py +++ b/flytekit/types/file/__init__.py @@ -25,7 +25,7 @@ from typing_extensions import Annotated, get_args, get_origin -from .file import FlyteFile +from .file import FlyteFile, FlyteFilePathTransformer class FileExt: diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index a087af11eb..4da08a48b0 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -1,16 +1,19 @@ from __future__ import annotations +import json import mimetypes import os import pathlib import typing from contextlib import contextmanager from dataclasses import dataclass, field -from typing import cast +from typing import Dict, cast from urllib.parse import unquote import msgpack from dataclasses_json import config +from google.protobuf import json_format as _json_format +from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType @@ -24,6 +27,7 @@ get_underlying_type, ) from flytekit.exceptions.user import FlyteAssertion +from flytekit.extras.pydantic.decorator import model_serializer, model_validator from flytekit.loggers import logger from flytekit.models.core import types as _core_types from flytekit.models.core.types import BlobType @@ -159,12 +163,19 @@ def _serialize(self) -> typing.Dict[str, str]: @classmethod def _deserialize(cls, value) -> "FlyteFile": - path = value.get("path", None) + return FlyteFilePathTransformer().dict_to_flyte_file(dict_obj=value, expected_python_type=cls) - if path is None: - raise ValueError("FlyteFile's path should not be None") + @model_serializer + def serialize_flyte_file(self) -> Dict[str, str]: + lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) + return {"path": lv.scalar.blob.uri} + + @model_validator(mode="after") + def deserialize_flyte_file(self, info) -> "FlyteFile": + if info.context is None or info.context.get("deserialize") is not True: + return self - return FlyteFilePathTransformer().to_python_value( + pv = FlyteFilePathTransformer().to_python_value( FlyteContextManager.current_context(), Literal( scalar=Scalar( @@ -174,12 +185,13 @@ def _deserialize(cls, value) -> 
"FlyteFile": format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE ) ), - uri=path, + uri=self.path, ) ) ), - cls, + type(self), ) + return pv @classmethod def extension(cls) -> str: @@ -548,41 +560,103 @@ def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, s return {"ContentEncoding": "gzip"} return {} + def dict_to_flyte_file( + self, dict_obj: typing.Dict[str, str], expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] + ) -> FlyteFile: + path = dict_obj.get("path", None) + + if path is None: + raise ValueError("FlyteFile's path should not be None") + + return self.to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=_core_types.BlobType( + format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ), + uri=path, + ) + ) + ), + expected_python_type, + ) + def from_binary_idl( self, binary_idl_object: Binary, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] ) -> FlyteFile: + """ + If the input is from flytekit, the Life Cycle will be as follows: + + Life Cycle: + binary IDL -> resolved binary -> bytes -> expected Python object + (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized + serialization) deserialization) + + Example Code: + @dataclass + class DC: + ff: FlyteFile + + @workflow + def wf(dc: DC): + t_ff(dc.ff) + + Note: + - The deserialization is the same as put a flyte file in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ if binary_idl_object.tag == MESSAGEPACK: python_val = msgpack.loads(binary_idl_object.value) - path = python_val.get("path", None) - - if path is None: - raise ValueError("FlyteFile's path should not be None") - - return FlyteFilePathTransformer().to_python_value( - FlyteContextManager.current_context(), - Literal( - scalar=Scalar( - blob=Blob( - metadata=BlobMetadata( - type=_core_types.BlobType( - format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE - ) - ), - uri=path, - ) - ) - ), - expected_python_type, - ) + return self.dict_to_flyte_file(dict_obj=python_val, expected_python_type=expected_python_type) else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + def from_generic_idl( + self, generic: Struct, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] + ) -> FlyteFile: + """ + If the input is from Flyte Console, the Life Cycle will be as follows: + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> expected Python object + (console user input) (console output) (propeller) (flytekit customized deserialization) + + Example Code: + @dataclass + class DC: + ff: FlyteFile + + @workflow + def wf(dc: DC): + t_ff(dc.ff) + + Note: + - The deserialization is the same as put a flyte file in a dataclass, which will deserialize by the mashumaro's API. 
+ + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + json_str = _json_format.MessageToJson(generic) + python_val = json.loads(json_str) + return self.dict_to_flyte_file(dict_obj=python_val, expected_python_type=expected_python_type) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] ) -> FlyteFile: # Handle dataclass attribute access - if lv.scalar and lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar: + if lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar.generic: + return self.from_generic_idl(lv.scalar.generic, expected_python_type) try: uri = lv.scalar.blob.uri diff --git a/flytekit/types/schema/__init__.py b/flytekit/types/schema/__init__.py index 080927021a..33ee8ef72c 100644 --- a/flytekit/types/schema/__init__.py +++ b/flytekit/types/schema/__init__.py @@ -1,5 +1,6 @@ from .types import ( FlyteSchema, + FlyteSchemaTransformer, LocalIOSchemaReader, LocalIOSchemaWriter, SchemaEngine, diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 28a2c542ef..7992368613 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -1,16 +1,19 @@ from __future__ import annotations import datetime +import json import os import typing from abc import abstractmethod from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Type +from typing import Dict, Optional, Type import msgpack from dataclasses_json import config +from google.protobuf import json_format as _json_format +from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType @@ -18,6 +21,7 @@ from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError +from flytekit.extras.pydantic.decorator import model_serializer, model_validator from flytekit.loggers import logger from flytekit.models.literals import Binary, Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType @@ -185,7 +189,7 @@ class FlyteSchema(SerializableType, DataClassJSONMixin): This is the main schema class that users should use. 
""" - def _serialize(self) -> typing.Dict[str, typing.Optional[str]]: + def _serialize(self) -> Dict[str, Optional[str]]: FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) return {"remote_path": self.remote_path} @@ -203,6 +207,23 @@ def _deserialize(cls, value) -> "FlyteSchema": cls, ) + @model_serializer + def serialize_flyte_schema(self) -> Dict[str, Optional[str]]: + FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) + return {"remote_path": self.remote_path} + + @model_validator(mode="after") + def deserialize_flyte_schema(self, info) -> FlyteSchema: + if info.context is None or info.context.get("deserialize") is not True: + return self + + t = FlyteSchemaTransformer() + return t.to_python_value( + FlyteContextManager.current_context(), + Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))), + type(self), + ) + @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: return {} @@ -445,29 +466,89 @@ async def async_to_literal( ) return Literal(scalar=Scalar(schema=Schema(schema.remote_path, self._get_schema_type(python_type)))) + def dict_to_flyte_schema( + self, dict_obj: typing.Dict[str, str], expected_python_type: Type[FlyteSchema] + ) -> FlyteSchema: + remote_path = dict_obj.get("remote_path", None) + + if remote_path is None: + raise ValueError("FlyteSchema's path should not be None") + + t = FlyteSchemaTransformer() + return t.to_python_value( + FlyteContextManager.current_context(), + Literal(scalar=Scalar(schema=Schema(remote_path, t._get_schema_type(expected_python_type)))), + expected_python_type, + ) + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: - if binary_idl_object.tag == MESSAGEPACK: - python_val = msgpack.loads(binary_idl_object.value) - remote_path = python_val.get("remote_path", None) + """ + If the input is from flytekit, the Life Cycle will be as follows: - if remote_path is None: - raise ValueError("FlyteSchema's path should not be None") + Life Cycle: + binary IDL -> resolved binary -> bytes -> expected Python object + (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized + serialization) deserialization) - t = FlyteSchemaTransformer() - return t.to_python_value( - FlyteContextManager.current_context(), - Literal(scalar=Scalar(schema=Schema(remote_path, t._get_schema_type(expected_python_type)))), - expected_python_type, - ) + Example Code: + @dataclass + class DC: + fs: FlyteSchema + + @workflow + def wf(dc: DC): + t_fs(dc.fs) + + Note: + - The deserialization is the same as put a flyte schema in a dataclass, which will deserialize by the mashumaro's API. 
+ + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + if binary_idl_object.tag == MESSAGEPACK: + python_val = msgpack.loads(binary_idl_object.value) + return self.dict_to_flyte_schema(dict_obj=python_val, expected_python_type=expected_python_type) else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + def from_generic_idl(self, generic: Struct, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: + """ + If the input is from Flyte Console, the Life Cycle will be as follows: + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> expected Python object + (console user input) (console output) (propeller) (flytekit customized deserialization) + + Example Code: + @dataclass + class DC: + fs: FlyteSchema + + @workflow + def wf(dc: DC): + t_fs(dc.fs) + + Note: + - The deserialization is the same as putting a flyte schema in a dataclass, which is deserialized via mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + json_str = _json_format.MessageToJson(generic) + python_val = json.loads(json_str) + return self.dict_to_flyte_schema(dict_obj=python_val, expected_python_type=expected_python_type) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[FlyteSchema] ) -> FlyteSchema: # Handle dataclass attribute access - if lv.scalar and lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar: + if lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar.generic: + return self.from_generic_idl(lv.scalar.generic, expected_python_type) def downloader(x, y): ctx.file_access.get_data(x, y, is_multipart=True) diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 05d1fa86e3..254ff16721 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -7,9 +7,9 @@ :template: custom.rst :toctree: generated/ - StructuredDataset - StructuredDatasetEncoder - StructuredDatasetDecoder + StructuredDataset + StructuredDatasetDecoder + StructuredDatasetEncoder """ from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer @@ -19,7 +19,9 @@ StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, + StructuredDatasetMetadata, StructuredDatasetTransformerEngine, + StructuredDatasetType, ) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 57c028e71c..0404e3b380 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -2,6 +2,7 @@ import _datetime import collections +import json import types import typing from abc import ABC, abstractmethod @@ -11,6 +12,8 @@ import msgpack from dataclasses_json import config from fsspec.utils import get_protocol +from google.protobuf import json_format as _json_format +from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType @@ -21,6 +24,7 @@ from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine,
TypeTransformerFailedError from flytekit.deck.renderer import Renderable +from flytekit.extras.pydantic.decorator import model_serializer, model_validator from flytekit.loggers import developer_logger, logger from flytekit.models import literals from flytekit.models import types as type_models @@ -91,6 +95,38 @@ def _deserialize(cls, value) -> "StructuredDataset": cls, ) + @model_serializer + def serialize_structured_dataset(self) -> Dict[str, Optional[str]]: + lv = StructuredDatasetTransformerEngine().to_literal( + FlyteContextManager.current_context(), self, type(self), None + ) + sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri) + sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format + return { + "uri": sd.uri, + "file_format": sd.file_format, + } + + @model_validator(mode="after") + def deserialize_structured_dataset(self, info) -> StructuredDataset: + if info.context is None or info.context.get("deserialize") is not True: + return self + + return StructuredDatasetTransformerEngine().to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + structured_dataset=StructuredDataset( + metadata=StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType(format=self.file_format) + ), + uri=self.uri, + ) + ) + ), + type(self), + ) + @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: return {} @@ -724,34 +760,93 @@ def encode( sd._already_uploaded = True return lit + def dict_to_structured_dataset( + self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | StructuredDataset + ) -> T | StructuredDataset: + uri = dict_obj.get("uri", None) + file_format = dict_obj.get("file_format", None) + + if uri is None: + raise ValueError("StructuredDataset's uri should not be None") + + return StructuredDatasetTransformerEngine().to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + structured_dataset=StructuredDataset( + metadata=StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType(format=file_format) + ), + uri=uri, + ) + ) + ), + expected_python_type, + ) + def from_binary_idl( self, binary_idl_object: Binary, expected_python_type: Type[T] | StructuredDataset ) -> T | StructuredDataset: + """ + If the input is from flytekit, the Life Cycle will be as follows: + + Life Cycle: + binary IDL -> resolved binary -> bytes -> expected Python object + (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized + serialization) deserialization) + + Example Code: + @dataclass + class DC: + sd: StructuredDataset + + @workflow + def wf(dc: DC): + t_sd(dc.sd) + + Note: + - The deserialization is the same as putting a structured dataset in a dataclass, which is deserialized via mashumaro's API.
+ + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ if binary_idl_object.tag == MESSAGEPACK: python_val = msgpack.loads(binary_idl_object.value) - uri = python_val.get("uri", None) - file_format = python_val.get("file_format", None) - - if uri is None: - raise ValueError("StructuredDataset's uri and file format should not be None") - - return StructuredDatasetTransformerEngine().to_python_value( - FlyteContextManager.current_context(), - Literal( - scalar=Scalar( - structured_dataset=StructuredDataset( - metadata=StructuredDatasetMetadata( - structured_dataset_type=StructuredDatasetType(format=file_format) - ), - uri=uri, - ) - ) - ), - expected_python_type, - ) + return self.dict_to_structured_dataset(dict_obj=python_val, expected_python_type=expected_python_type) else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + def from_generic_idl( + self, generic: Struct, expected_python_type: Type[T] | StructuredDataset + ) -> T | StructuredDataset: + """ + If the input is from Flyte Console, the Life Cycle will be as follows: + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> expected Python object + (console user input) (console output) (propeller) (flytekit customized deserialization) + + Example Code: + @dataclass + class DC: + sd: StructuredDataset + + @workflow + def wf(dc: DC): + t_sd(dc.sd) + + Note: + - The deserialization is the same as putting a structured dataset in a dataclass, which is deserialized via mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + json_str = _json_format.MessageToJson(generic) + python_val = json.loads(json_str) + return self.dict_to_structured_dataset(dict_obj=python_val, expected_python_type=expected_python_type) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset ) -> T | StructuredDataset: @@ -786,8 +881,11 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ... +-----------------------------+-----------------------------------------+--------------------------------------+ """ # Handle dataclass attribute access - if lv.scalar and lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar: + if lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar.generic: + return self.from_generic_idl(lv.scalar.generic, expected_python_type) # Detect annotations and extract out all the relevant information that the user might supply expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type) diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md index 3f42c9cd21..8fef623d03 100644 --- a/plugins/flytekit-pydantic/README.md +++ b/plugins/flytekit-pydantic/README.md @@ -1,5 +1,10 @@ # Flytekit Pydantic Plugin +## Warning +This plugin is deprecated and will be removed in the future. +Please install `pydantic` directly and use `BaseModel` in your Flyte tasks. + +## Introduction Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations. The plugin adds type support for pydantic models.
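For example, a v2 `BaseModel` can now be passed to tasks directly, without this plugin; a minimal sketch (the model, task, and workflow names are illustrative):

```python
from pydantic import BaseModel

from flytekit import task, workflow


class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32


@task
def train(cfg: TrainConfig) -> int:
    # cfg arrives as a validated BaseModel instance
    return cfg.batch_size


@workflow
def wf(cfg: TrainConfig) -> int:
    return train(cfg=cfg)
```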
diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py index 23e7e341bd..491bd8c9c4 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -1,4 +1,11 @@ +from flytekit.loggers import logger + from .basemodel_transformer import BaseModelTransformer from .deserialization import set_validators_on_supported_flyte_types as _set_validators_on_supported_flyte_types _set_validators_on_supported_flyte_types() # enables you to use flytekit.types in pydantic model +logger.warning( + "The Flytekit Pydantic V1 plugin is deprecated.\n" + "Please uninstall `flytekitplugins-pydantic` and install Pydantic directly.\n" + "You can now use Pydantic V2 BaseModels in Flytekit tasks." +) diff --git a/plugins/flytekit-pydantic/setup.py b/plugins/flytekit-pydantic/setup.py index 63e2c941e7..7001506a70 100644 --- a/plugins/flytekit-pydantic/setup.py +++ b/plugins/flytekit-pydantic/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.7.0b0", "pydantic"] +plugin_requires = ["flytekit>=1.7.0b0", "pydantic<2"] __version__ = "0.0.0+develop" diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index aa7e7dca4f..4eac1a1296 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -22,7 +22,10 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import LiteralMap from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer - +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct +from flytekit.models.literals import Literal, Scalar +import json # Fixture that ensures a dummy local file @pytest.fixture @@ -364,3 +367,13 @@ def my_wf(path: SvgDirectory) -> DC: dc1 = my_wf(path=svg_directory) dc2 = DC(f=svg_directory) assert dc1 == dc2 + +def test_input_from_flyte_console_attribute_access_flytedirectory(local_dummy_directory): + # Flyte Console will send the input data as protobuf Struct + + dict_obj = {"path": local_dummy_directory} + json_str = json.dumps(dict_obj) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, FlyteDirectory) + assert isinstance(downstream_input, FlyteDirectory) + assert downstream_input == FlyteDirectory(local_dummy_directory) diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 352984ca37..33b97aa589 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -1,3 +1,4 @@ +import json import os import pathlib import tempfile @@ -20,7 +21,9 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import LiteralMap from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer - +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct +from flytekit.models.literals import Literal, Scalar # Fixture that
ensures a dummy local file @pytest.fixture @@ -705,3 +708,12 @@ def test_new_remote_file(): nf = FlyteFile.new_remote_file(name="foo.txt") assert isinstance(nf, FlyteFile) assert nf.path.endswith('foo.txt') + +def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_file): + # Flyte Console will send the input data as protobuf Struct + + dict_obj = {"path": local_dummy_file} + json_str = json.dumps(dict_obj) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, FlyteFile) + assert downstream_input == FlyteFile(local_dummy_file) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 7966b00f2c..0db5f10a46 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3287,7 +3287,7 @@ class InnerWorkflowOutput(DataClassJSONMixin): @task def inner_task(input: float) -> float | None: - if input == 0: + if input == 0.0: return None return input @@ -3322,7 +3322,7 @@ def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput: float_value_output = outer_workflow(OuterWorkflowInput(input=1.0)).nullable_output assert float_value_output == 1.0, f"Float value was {float_value_output}, not 1.0 as expected" - none_value_output = outer_workflow(OuterWorkflowInput(input=0)).nullable_output + none_value_output = outer_workflow(OuterWorkflowInput(input=0.0)).nullable_output assert none_value_output is None, f"None value was {none_value_output}, not None as expected" diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py index 986fac0c7c..5651b80450 100644 --- a/tests/flytekit/unit/core/test_type_engine_binary_idl.py +++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from datetime import date, datetime, timedelta from enum import Enum -from typing import Dict, List +from typing import Dict, List, Optional, Union import pytest from google.protobuf import json_format as _json_format @@ -627,7 +627,7 @@ def t_inner(inner_dc: InnerDC): with open(os.path.join(inner_dc.o, "file"), "r") as fh: assert fh.read() == "Hello FlyteDirectory" assert inner_dc.o.downloaded - print("Test InnerDC Successfully Passed") + # enum: Status assert inner_dc.enum_status == Status.PENDING @@ -690,8 +690,6 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li # Strict type check for Enum assert isinstance(enum_status, Status), "enum_status is not Status" - print("All attributes passed strict type checks.") - @workflow def wf(dc: DC): t_inner(dc.inner_dc) @@ -710,6 +708,8 @@ def wf(dc: DC): wf(dc=DC()) def test_backward_compatible_with_dataclass_in_protobuf_struct(local_dummy_file, local_dummy_directory): + # Flyte Console will send the input data as protobuf Struct + # This test also test how Flyte Console with attribute access on the Struct object @dataclass class InnerDC: @@ -777,7 +777,7 @@ def t_inner(inner_dc: InnerDC): with open(os.path.join(inner_dc.o, "file"), "r") as fh: assert fh.read() == "Hello FlyteDirectory" assert inner_dc.o.downloaded - print("Test InnerDC Successfully Passed") + # enum: Status assert inner_dc.enum_status == Status.PENDING @@ -838,8 +838,6 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li # Strict type check for Enum assert 
isinstance(enum_status, Status), "enum_status is not Status" - print("All attributes passed strict type checks.") - # This is the old dataclass serialization behavior. # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29dcd6f04d0553fa1beb/flytekit/core/type_engine.py#L702-L728 dc = DC() @@ -875,3 +873,335 @@ def test_backward_compatible_with_untyped_dict_in_protobuf_struct(): downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, dict) assert dict_input == downstream_input + +def test_flyte_console_input_with_typed_dict_with_flyte_types_in_dataclass_in_protobuf_struct(local_dummy_file, local_dummy_directory): + # TODO: We can add more nested cases for non-flyte types. + """ + Handles the case where Flyte Console provides input as a protobuf struct. + When resolving an attribute like 'dc.dict_int_ff', FlytePropeller retrieves a dictionary. + Mashumaro's decoder can convert this dictionary to the expected Python object if the correct type is provided. + Since Flyte Types handle their own deserialization, the dictionary is automatically converted to the expected Python object. + + Example Code: + @dataclass + class DC: + dict_int_ff: Dict[int, FlyteFile] + + @workflow + def wf(dc: DC): + t_ff(dc.dict_int_ff) + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> dictionary -> expected Python object + (console user input) (console output) (propeller) (flytekit dict transformer) (mashumaro decoder) + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + - Title: Binary IDL With MessagePack + - Link: https://github.com/flyteorg/flytekit/pull/2760 + """ + + dict_int_flyte_file = {"1" : {"path": local_dummy_file}} + json_str = json.dumps(dict_int_flyte_file) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, Dict[int, FlyteFile]) + assert downstream_input == {1: FlyteFile(local_dummy_file)} + + # FlyteConsole trims trailing ".0" when converting float-like strings + dict_float_flyte_file = {"1": {"path": local_dummy_file}} + json_str = json.dumps(dict_float_flyte_file) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteFile]) + assert downstream_input == {1.0: FlyteFile(local_dummy_file)} + + dict_float_flyte_file = {"1.0": {"path": local_dummy_file}} + json_str = json.dumps(dict_float_flyte_file) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteFile]) + assert downstream_input == {1.0: FlyteFile(local_dummy_file)} + + dict_str_flyte_file = {"1": {"path": local_dummy_file}} + json_str = json.dumps(dict_str_flyte_file) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, Dict[str, FlyteFile]) + assert downstream_input == {"1": FlyteFile(local_dummy_file)} + + 
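+    # Why "1" and "1.0" can both land on the float key 1.0 above: Struct/JSON keys are always
+    # strings, and the Dict transformer coerces them back using the declared key type. A minimal
+    # sketch of that idea in plain Python (an illustration for this document, not flytekit's
+    # exact implementation):
+    #
+    #     raw = {"1": "a", "1.0": "b"}
+    #     coerced = {float(k): v for k, v in raw.items()}
+    #     assert coerced == {1.0: "b"}  # both keys coerce to 1.0; the later entry wins
+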
dict_int_flyte_directory = {"1": {"path": local_dummy_directory}} + json_str = json.dumps(dict_int_flyte_directory) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, Dict[int, FlyteDirectory]) + assert downstream_input == {1: FlyteDirectory(local_dummy_directory)} + + # FlyteConsole trims trailing ".0" when converting float-like strings + dict_float_flyte_directory = {"1": {"path": local_dummy_directory}} + json_str = json.dumps(dict_float_flyte_directory) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteDirectory]) + assert downstream_input == {1.0: FlyteDirectory(local_dummy_directory)} + + dict_float_flyte_directory = {"1.0": {"path": local_dummy_directory}} + json_str = json.dumps(dict_float_flyte_directory) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteDirectory]) + assert downstream_input == {1.0: FlyteDirectory(local_dummy_directory)} + + dict_str_flyte_file = {"1": {"path": local_dummy_file}} + json_str = json.dumps(dict_str_flyte_file) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, + Dict[str, FlyteFile]) + assert downstream_input == {"1": FlyteFile(local_dummy_file)} + +def test_all_types_with_optional_in_dataclass_basemodel_wf(local_dummy_file, local_dummy_directory): + @dataclass + class InnerDC: + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + e: Optional[List[int]] = field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: Optional[List[FlyteFile]] = field(default_factory=lambda: [FlyteFile(local_dummy_file)]) + g: Optional[List[List[int]]] = field(default_factory=lambda: [[0], [1], [-1]]) + h: Optional[List[Dict[int, bool]]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Optional[Dict[int, bool]] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + j: Optional[Dict[int, FlyteFile]] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Optional[Dict[int, List[int]]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Optional[Dict[int, Dict[int, int]]] = field(default_factory=lambda: {1: {-1: 0}}) + m: Optional[dict] = field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: Optional[FlyteDirectory] = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + enum_status: Optional[Status] = field(default=Status.PENDING) + + @dataclass + class DC: + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + e: Optional[List[int]] = field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: Optional[List[FlyteFile]] = field(default_factory=lambda: [FlyteFile(local_dummy_file)]) + g: 
Optional[List[List[int]]] = field(default_factory=lambda: [[0], [1], [-1]]) + h: Optional[List[Dict[int, bool]]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Optional[Dict[int, bool]] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + j: Optional[Dict[int, FlyteFile]] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Optional[Dict[int, List[int]]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Optional[Dict[int, Dict[int, int]]] = field(default_factory=lambda: {1: {-1: 0}}) + m: Optional[dict] = field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: Optional[FlyteDirectory] = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + inner_dc: Optional[InnerDC] = field(default_factory=lambda: InnerDC()) + enum_status: Optional[Status] = field(default=Status.PENDING) + + @task + def t_inner(inner_dc: InnerDC): + assert type(inner_dc) is InnerDC + + # f: List[FlyteFile] + for ff in inner_dc.f: # type: ignore + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_dc.j.items(): # type: ignore + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_dc.n) is FlyteFile + with open(inner_dc.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_dc.o) is FlyteDirectory + assert not inner_dc.o.downloaded + with open(os.path.join(inner_dc.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_dc.o.downloaded + + # enum: Status + assert inner_dc.enum_status == Status.PENDING + + @task + def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str], d: Optional[bool], + e: Optional[List[int]], f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], o: Optional[FlyteDirectory], + enum_status: Optional[Status]): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]" + + # Strict type 
checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in + k.items()), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) and isinstance(v, dict) and all( + isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items()) + for k, v in l.items()), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + @workflow + def wf(dc: DC): + t_inner(dc.inner_dc) + t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c, + d=dc.d, e=dc.e, f=dc.f, + g=dc.g, h=dc.h, i=dc.i, + j=dc.j, k=dc.k, l=dc.l, + m=dc.m, n=dc.n, o=dc.o, + enum_status=dc.enum_status) + + wf(dc=DC()) + + +def test_all_types_with_optional_and_none_in_dataclass_wf(): + @dataclass + class InnerDC: + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + enum_status: Optional[Status] = None + + @dataclass + class DC: + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + inner_dc: Optional[InnerDC] = None + enum_status: Optional[Status] = None + + @task + def t_inner(inner_dc: Optional[InnerDC]): + return inner_dc + + @task + def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str], d: Optional[bool], + e: Optional[List[int]], f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], o: Optional[FlyteDirectory], + enum_status: Optional[Status]): + return + + @workflow + def wf(dc: DC): + t_inner(dc.inner_dc) + t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c, + d=dc.d, e=dc.e, f=dc.f, + g=dc.g, h=dc.h, i=dc.i, + j=dc.j, k=dc.k, l=dc.l, + m=dc.m, n=dc.n, o=dc.o, + enum_status=dc.enum_status) + + wf(dc=DC()) + +def test_union_in_dataclass_wf(): + @dataclass + class DC: + a: Union[int, bool, str, 
float] + b: Union[int, bool, str, float] + + @task + def add(a: Union[int, bool, str, float], b: Union[int, bool, str, float]) -> Union[int, bool, str, float]: + return a + b # type: ignore + + @workflow + def wf(dc: DC) -> Union[int, bool, str, float]: + return add(dc.a, dc.b) + + assert wf(dc=DC(a=1, b=2)) == 3 + assert wf(dc=DC(a=True, b=False)) == True + assert wf(dc=DC(a=False, b=False)) == False + assert wf(dc=DC(a="hello", b="world")) == "helloworld" + assert wf(dc=DC(a=1.0, b=2.0)) == 3.0 + + @task + def add(dc1: DC, dc2: DC) -> Union[int, bool, str, float]: + return dc1.a + dc2.b # type: ignore + + @workflow + def wf(dc: DC) -> Union[int, bool, str, float]: + return add(dc, dc) + + assert wf(dc=DC(a=1, b=2)) == 3 + + @workflow + def wf(dc: DC) -> DC: + return dc + + assert wf(dc=DC(a=1, b=2)) == DC(a=1, b=2) diff --git a/tests/flytekit/unit/extras/pydantic/test_pydantic_transformer.py b/tests/flytekit/unit/extras/pydantic/test_pydantic_transformer.py new file mode 100644 index 0000000000..0aa758ce27 --- /dev/null +++ b/tests/flytekit/unit/extras/pydantic/test_pydantic_transformer.py @@ -0,0 +1,775 @@ +import os +import tempfile +from dataclasses import field +from enum import Enum +from typing import Dict, List, Optional, Union +from pydantic import BaseModel, Field +from unittest.mock import patch +import pytest +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct + +from flytekit import task, workflow +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.literals import Literal, Scalar +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile +from flytekit.types.schema import FlyteSchema +from flytekit.types.structured import StructuredDataset + + +class Status(Enum): + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + + +@pytest.fixture +def local_dummy_file(): + fd, path = tempfile.mkstemp() + try: + with os.fdopen(fd, "w") as tmp: + tmp.write("Hello FlyteFile") + yield path + finally: + os.remove(path) + + +@pytest.fixture +def local_dummy_directory(): + temp_dir = tempfile.TemporaryDirectory() + try: + with open(os.path.join(temp_dir.name, "file"), "w") as tmp: + tmp.write("Hello FlyteDirectory") + yield temp_dir.name + finally: + temp_dir.cleanup() + + +def test_flytetypes_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class InnerBM(BaseModel): + flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) + flytedir: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + + class BM(BaseModel): + flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) + flytedir: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + inner_bm: InnerBM = field(default_factory=lambda: InnerBM()) + + @task + def t1(path: FlyteFile) -> FlyteFile: + return path + + @task + def t2(path: FlyteDirectory) -> FlyteDirectory: + return path + + @workflow + def wf(bm: BM) -> (FlyteFile, FlyteFile, FlyteDirectory, FlyteDirectory): + f1 = t1(path=bm.flytefile) + f2 = t1(path=bm.inner_bm.flytefile) + d1 = t2(path=bm.flytedir) + d2 = t2(path=bm.inner_bm.flytedir) + return f1, f2, d1, d2 + + o1, o2, o3, o4 = wf(bm=BM()) + with open(o1, "r") as fh: + assert fh.read() == "Hello FlyteFile" + + with open(o2, "r") as fh: + assert fh.read() == "Hello FlyteFile" + + with 
open(os.path.join(o3, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + + with open(os.path.join(o4, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + +def test_all_types_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class InnerBM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file)]) + g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + m: dict = field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + enum_status: Status = field(default=Status.PENDING) + + class BM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file), ]) + g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + m: dict = field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + inner_bm: InnerBM = field(default_factory=lambda: InnerBM()) + enum_status: Status = field(default=Status.PENDING) + + @task + def t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + # f: List[FlyteFile] + for ff in inner_bm.f: + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_bm.j.items(): + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_bm.n) is FlyteFile + with open(inner_bm.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_bm.o) is FlyteDirectory + assert not inner_bm.o.downloaded + with open(os.path.join(inner_bm.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.o.downloaded + + # enum: Status + assert inner_bm.enum_status == Status.PENDING + + + @task + def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]], + h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile], + 
k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict, + n: FlyteFile, o: FlyteDirectory, enum_status: Status): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]" + + # Strict type checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in + k.items()), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) and isinstance(v, dict) and all( + isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items()) + for k, v in l.items()), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + print("All attributes passed strict type checks.") + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes(a=bm.a, b=bm.b, c=bm.c, + d=bm.d, e=bm.e, f=bm.f, + g=bm.g, h=bm.h, i=bm.i, + j=bm.j, k=bm.k, l=bm.l, + m=bm.m, n=bm.n, o=bm.o, enum_status=bm.enum_status) + + t_test_all_attributes(a=bm.inner_bm.a, b=bm.inner_bm.b, c=bm.inner_bm.c, + d=bm.inner_bm.d, e=bm.inner_bm.e, f=bm.inner_bm.f, + g=bm.inner_bm.g, h=bm.inner_bm.h, i=bm.inner_bm.i, + j=bm.inner_bm.j, k=bm.inner_bm.k, l=bm.inner_bm.l, + m=bm.inner_bm.m, n=bm.inner_bm.n, o=bm.inner_bm.o, enum_status=bm.inner_bm.enum_status) + + wf(bm=BM()) + +def test_all_types_with_optional_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class InnerBM(BaseModel): + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + e: Optional[List[int]] = field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: Optional[List[FlyteFile]] = field(default_factory=lambda: [FlyteFile(local_dummy_file)]) + g: Optional[List[List[int]]] = field(default_factory=lambda: [[0], [1], [-1]]) + h: 
Optional[List[Dict[int, bool]]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Optional[Dict[int, bool]] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + j: Optional[Dict[int, FlyteFile]] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Optional[Dict[int, List[int]]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Optional[Dict[int, Dict[int, int]]] = field(default_factory=lambda: {1: {-1: 0}}) + m: Optional[dict] = field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: Optional[FlyteDirectory] = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + enum_status: Optional[Status] = field(default=Status.PENDING) + + class BM(BaseModel): + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + e: Optional[List[int]] = field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: Optional[List[FlyteFile]] = field(default_factory=lambda: [FlyteFile(local_dummy_file)]) + g: Optional[List[List[int]]] = field(default_factory=lambda: [[0], [1], [-1]]) + h: Optional[List[Dict[int, bool]]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Optional[Dict[int, bool]] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + j: Optional[Dict[int, FlyteFile]] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Optional[Dict[int, List[int]]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Optional[Dict[int, Dict[int, int]]] = field(default_factory=lambda: {1: {-1: 0}}) + m: Optional[dict] = field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: Optional[FlyteDirectory] = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + inner_bm: Optional[InnerBM] = field(default_factory=lambda: InnerBM()) + enum_status: Optional[Status] = field(default=Status.PENDING) + + @task + def t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + # f: List[FlyteFile] + for ff in inner_bm.f: + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_bm.j.items(): + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_bm.n) is FlyteFile + with open(inner_bm.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_bm.o) is FlyteDirectory + assert not inner_bm.o.downloaded + with open(os.path.join(inner_bm.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.o.downloaded + + # enum: Status + assert inner_bm.enum_status == Status.PENDING + + @task + def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str], d: Optional[bool], + e: Optional[List[int]], f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], o: Optional[FlyteDirectory], + enum_status: Optional[Status]): + # Strict type checks for simple types + assert 
isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]" + + # Strict type checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in + k.items()), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) and isinstance(v, dict) and all( + isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items()) + for k, v in l.items()), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes(a=bm.a, b=bm.b, c=bm.c, + d=bm.d, e=bm.e, f=bm.f, + g=bm.g, h=bm.h, i=bm.i, + j=bm.j, k=bm.k, l=bm.l, + m=bm.m, n=bm.n, o=bm.o, + enum_status=bm.enum_status) + + wf(bm=BM()) + + +def test_all_types_with_optional_and_none_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class InnerBM(BaseModel): + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + enum_status: Optional[Status] = None + + class BM(BaseModel): + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, 
FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + inner_bm: Optional[InnerBM] = None + enum_status: Optional[Status] = None + + @task + def t_inner(inner_bm: Optional[InnerBM]): + return inner_bm + + @task + def t_test_all_attributes(a: Optional[int], b: Optional[float], + c: Optional[str], + d: Optional[bool], + e: Optional[List[int]], f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], o: Optional[FlyteDirectory], + enum_status: Optional[Status]): + return + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes(a=bm.a, b=bm.b, + c=bm.c, + d=bm.d, + e=bm.e, f=bm.f, + g=bm.g, h=bm.h, i=bm.i, + j=bm.j, k=bm.k, l=bm.l, + m=bm.m, n=bm.n, o=bm.o, + enum_status=bm.enum_status) + + wf(bm=BM()) + +def test_input_from_flyte_console_pydantic_basemodel(local_dummy_file, local_dummy_directory): + # Flyte Console will send the input data as protobuf Struct + + class InnerBM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file)]) + g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + m: dict = field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + enum_status: Status = field(default=Status.PENDING) + + class BM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file), ]) + g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + m: dict = field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + inner_bm: InnerBM = field(default_factory=lambda: InnerBM()) + enum_status: Status = field(default=Status.PENDING) + + @task + def 
t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + # f: List[FlyteFile] + for ff in inner_bm.f: + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_bm.j.items(): + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_bm.n) is FlyteFile + with open(inner_bm.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_bm.o) is FlyteDirectory + assert not inner_bm.o.downloaded + with open(os.path.join(inner_bm.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.o.downloaded + + # enum: Status + assert inner_bm.enum_status == Status.PENDING + + def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]], + h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile], + k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict, + n: FlyteFile, o: FlyteDirectory, enum_status: Status): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]" + + # Strict type checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in + k.items()), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) and isinstance(v, dict) and all( + isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items()) + for k, v in l.items()), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + print("All attributes passed strict type checks.") + + # This is the old dataclass serialization behavior. 
+    # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29dcd6f04d0553fa1beb/flytekit/core/type_engine.py#L702-L728
+    bm = BM()
+    json_str = bm.model_dump_json()
+    upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))
+
+    downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, BM)
+    t_inner(downstream_input.inner_bm)
+    t_test_all_attributes(a=downstream_input.a, b=downstream_input.b, c=downstream_input.c,
+                          d=downstream_input.d, e=downstream_input.e, f=downstream_input.f,
+                          g=downstream_input.g, h=downstream_input.h, i=downstream_input.i,
+                          j=downstream_input.j, k=downstream_input.k, l=downstream_input.l,
+                          m=downstream_input.m, n=downstream_input.n, o=downstream_input.o,
+                          enum_status=downstream_input.enum_status)
+    t_test_all_attributes(a=downstream_input.inner_bm.a, b=downstream_input.inner_bm.b, c=downstream_input.inner_bm.c,
+                          d=downstream_input.inner_bm.d, e=downstream_input.inner_bm.e, f=downstream_input.inner_bm.f,
+                          g=downstream_input.inner_bm.g, h=downstream_input.inner_bm.h, i=downstream_input.inner_bm.i,
+                          j=downstream_input.inner_bm.j, k=downstream_input.inner_bm.k, l=downstream_input.inner_bm.l,
+                          m=downstream_input.inner_bm.m, n=downstream_input.inner_bm.n, o=downstream_input.inner_bm.o,
+                          enum_status=downstream_input.inner_bm.enum_status)
+
+def test_dataclass_in_pydantic_basemodel():
+    from dataclasses import dataclass
+
+    @dataclass
+    class InnerBM:
+        a: int = -1
+        b: float = 3.14
+        c: str = "Hello, Flyte"
+        d: bool = False
+
+    class BM(BaseModel):
+        a: int = -1
+        b: float = 3.14
+        c: str = "Hello, Flyte"
+        d: bool = False
+        inner_bm: InnerBM = Field(default_factory=lambda: InnerBM())
+
+    @task
+    def t_bm(bm: BM):
+        assert isinstance(bm, BM)
+        assert isinstance(bm.inner_bm, InnerBM)
+
+    @task
+    def t_inner(inner_bm: InnerBM):
+        assert isinstance(inner_bm, InnerBM)
+
+    @task
+    def t_test_primitive_attributes(a: int, b: float, c: str, d: bool):
+        assert isinstance(a, int), f"a is not int, it's {type(a)}"
+        assert a == -1
+        assert isinstance(b, float), f"b is not float, it's {type(b)}"
+        assert b == 3.14
+        assert isinstance(c, str), f"c is not str, it's {type(c)}"
+        assert c == "Hello, Flyte"
+        assert isinstance(d, bool), f"d is not bool, it's {type(d)}"
+        assert d is False
+        print("All primitive attributes passed strict type checks.")
+
+    @workflow
+    def wf(bm: BM):
+        t_bm(bm=bm)
+        t_inner(inner_bm=bm.inner_bm)
+        t_test_primitive_attributes(a=bm.a, b=bm.b, c=bm.c, d=bm.d)
+        t_test_primitive_attributes(a=bm.inner_bm.a, b=bm.inner_bm.b, c=bm.inner_bm.c, d=bm.inner_bm.d)
+
+    bm = BM()
+    wf(bm=bm)
+
+def test_pydantic_dataclass_in_pydantic_basemodel():
+    from pydantic.dataclasses import dataclass
+
+    @dataclass
+    class InnerBM:
+        a: int = -1
+        b: float = 3.14
+        c: str = "Hello, Flyte"
+        d: bool = False
+
+    class BM(BaseModel):
+        a: int = -1
+        b: float = 3.14
+        c: str = "Hello, Flyte"
+        d: bool = False
+        inner_bm: InnerBM = Field(default_factory=lambda: InnerBM())
+
+    @task
+    def t_bm(bm: BM):
+        assert isinstance(bm, BM)
+        assert isinstance(bm.inner_bm, InnerBM)
+
+    @task
+    def t_inner(inner_bm: InnerBM):
+        assert isinstance(inner_bm, InnerBM)
+
+    @task
+    def t_test_primitive_attributes(a: int, b: float, c: str, d: bool):
+        assert isinstance(a, int), f"a is not int, it's {type(a)}"
+        assert a == -1
+        assert isinstance(b, float), f"b is not float, it's {type(b)}"
+        assert b == 3.14
+        assert isinstance(c, str), f"c is not str, it's {type(c)}"
+        assert c == "Hello, Flyte"
+        assert isinstance(d, bool), f"d is not bool, it's {type(d)}"
+        assert d is False
+        print("All primitive attributes passed strict type checks.")
+
+    @workflow
+    def wf(bm: BM):
+        t_bm(bm=bm)
+        t_inner(inner_bm=bm.inner_bm)
+        t_test_primitive_attributes(a=bm.a, b=bm.b, c=bm.c, d=bm.d)
+        t_test_primitive_attributes(a=bm.inner_bm.a, b=bm.inner_bm.b, c=bm.inner_bm.c, d=bm.inner_bm.d)
+
+    bm = BM()
+    wf(bm=bm)
+
+def test_flyte_types_deserialization_not_called_when_using_constructor(local_dummy_file, local_dummy_directory):
+    # Mock to_python_value for the FlyteFile, FlyteDirectory, StructuredDataset, and FlyteSchema transformers
+    with patch('flytekit.types.file.FlyteFilePathTransformer.to_python_value') as mock_file_to_python_value, \
+         patch('flytekit.types.directory.FlyteDirToMultipartBlobTransformer.to_python_value') as mock_directory_to_python_value, \
+         patch('flytekit.types.structured.StructuredDatasetTransformerEngine.to_python_value') as mock_structured_dataset_to_python_value, \
+         patch('flytekit.types.schema.FlyteSchemaTransformer.to_python_value') as mock_schema_to_python_value:
+
+        # Define your Pydantic model
+        class BM(BaseModel):
+            ff: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
+            fd: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))
+            sd: StructuredDataset = field(default_factory=lambda: StructuredDataset())
+            fsc: FlyteSchema = field(default_factory=lambda: FlyteSchema())
+
+        # Creating an instance of BM should not trigger deserialization of the Flyte types
+        BM()
+
+        mock_file_to_python_value.assert_not_called()
+        mock_directory_to_python_value.assert_not_called()
+        mock_structured_dataset_to_python_value.assert_not_called()
+        mock_schema_to_python_value.assert_not_called()
+
+def test_flyte_types_deserialization_called_once_when_using_model_validate_json(local_dummy_file, local_dummy_directory):
+    """
+    It's hard to mock FlyteSchema and StructuredDataset in tests, so we only test FlyteFile and FlyteDirectory here.
+    """
+    with patch('flytekit.types.file.FlyteFilePathTransformer.to_python_value') as mock_file_to_python_value, \
+         patch('flytekit.types.directory.FlyteDirToMultipartBlobTransformer.to_python_value') as mock_directory_to_python_value:
+        # Define your Pydantic model
+        class BM(BaseModel):
+            ff: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
+            fd: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))
+
+        # Create instances of FlyteFile and FlyteDirectory
+        bm = BM(ff=FlyteFile(local_dummy_file), fd=FlyteDirectory(local_dummy_directory))
+
+        # Serialize and deserialize with model_validate_json
+        json_str = bm.model_dump_json()
+        bm.model_validate_json(json_data=json_str, strict=False, context={"deserialize": True})
+
+        # Assert that to_python_value was called exactly once for each type
+        mock_file_to_python_value.assert_called_once()
+        mock_directory_to_python_value.assert_called_once()
+
+def test_union_in_basemodel_wf():
+    class bm(BaseModel):
+        a: Union[int, bool, str, float]
+        b: Union[int, bool, str, float]
+
+    @task
+    def add(a: Union[int, bool, str, float], b: Union[int, bool, str, float]) -> Union[int, bool, str, float]:
+        return a + b  # type: ignore
+
+    @workflow
+    def wf(bm: bm) -> Union[int, bool, str, float]:
+        return add(bm.a, bm.b)
+
+    assert wf(bm=bm(a=1, b=2)) == 3
+    assert wf(bm=bm(a=True, b=False)) == True
+    assert wf(bm=bm(a=False, b=False)) == False
+    assert wf(bm=bm(a="hello", b="world")) == "helloworld"
+    assert wf(bm=bm(a=1.0, b=2.0)) == 3.0
+
+    @task
+    def add_bm(bm1: bm, bm2: bm) ->
Union[int, bool, str, float]: + return bm1.a + bm2.b # type: ignore + + @workflow + def wf_add_bm(bm: bm) -> Union[int, bool, str, float]: + return add_bm(bm, bm) + + assert wf_add_bm(bm=bm(a=1, b=2)) == 3 + + @workflow + def wf_return_bm(bm: bm) -> bm: + return bm + + assert wf_return_bm(bm=bm(a=1, b=2)) == bm(a=1, b=2) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py index f107384b96..9487c6f4c3 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py @@ -16,7 +16,6 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow -from flytekit.exceptions.user import FlyteAssertion from flytekit.lazy_import.lazy_module import is_imported from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata @@ -49,7 +48,6 @@ ) df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - def test_protocol(): assert get_protocol("s3://my-s3-bucket/file") == "s3" assert get_protocol("/file") == "file" @@ -57,8 +55,6 @@ def test_protocol(): def generate_pandas() -> pd.DataFrame: return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]}) - - def test_formats_make_sense(): @task def t1(a: pd.DataFrame) -> pd.DataFrame: