diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5f6999cf11..1663240a5d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
     hooks:
       # Run the linter.
       - id: ruff
-        args: [--fix, --show-fixes, --output-format=full]
+        args: [--fix, --show-fixes, --output-format=full, --extend-ignore=I001]
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
diff --git a/dev-requirements.in b/dev-requirements.in
index 27c17ac6d0..20aba11e9d 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -30,6 +30,7 @@ torch<=1.12.1; python_version<'3.11'
 # pytorch 2 supports python 3.11
 # pytorch 2 does not support 3.12 yet: https://github.com/pytorch/pytorch/issues/110436
 torch; python_version<'3.12'
+pydantic
 
 # TODO: Currently, the python-magic library causes build errors on Windows due to its dependency on DLLs for libmagic.
 # We have temporarily disabled this feature on Windows and are using python-magic for Mac OS and Linux instead.
diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index 968e3153eb..9e84e51243 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -216,6 +216,7 @@
 else:
     from importlib.metadata import entry_points
 
+
 from flytekit._version import __version__
 from flytekit.configuration import Config
 from flytekit.core.array_node_map_task import map_task
@@ -263,6 +264,8 @@
     StructuredDatasetType,
 )
 
+from flytekit.core import type_utils  # noqa: I001
+
 
 def current_context() -> ExecutionParameters:
     """
diff --git a/flytekit/core/type_utils.py b/flytekit/core/type_utils.py
index 6119ca804d..8cb42332c9 100644
--- a/flytekit/core/type_utils.py
+++ b/flytekit/core/type_utils.py
@@ -1,133 +1,142 @@
-# from typing import Dict
-#
-# from flytekit.core.context_manager import FlyteContextManager
-# from flytekit.models.core import types as _core_types
-# from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
-# from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
-# from flytekit.types.file import FlyteFile, FlyteFilePathTransformer
-# from flytekit.types.schema import FlyteSchema, FlyteSchemaTransformer
-# from flytekit.types.structured import (
-#     StructuredDataset,
-#     StructuredDatasetMetadata,
-#     StructuredDatasetTransformerEngine,
-#     StructuredDatasetType,
-# )
-#
-# # Conditional import for Pydantic model_serializer and model_validator
-# try:
-#     from pydantic import model_serializer, model_validator
-#
-#     # Serialize and Deserialize functions
-#     @model_serializer
-#     def serialize_flyte_file(self) -> Dict[str, str]:
-#         lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
-#         return {"path": lv.scalar.blob.uri}
-#
-#     @model_validator(mode="after")
-#     def deserialize_flyte_file(self, info) -> FlyteFile:
-#         if info.context is None or info.context.get("deserialize") is not True:
-#             return self
-#         pv = FlyteFilePathTransformer().to_python_value(
-#             FlyteContextManager.current_context(),
-#             Literal(
-#                 scalar=Scalar(
-#                     blob=Blob(
-#                         metadata=BlobMetadata(
-#                             type=_core_types.BlobType(
-#                                 format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
-#                             )
-#                         ),
-#                         uri=self.path,
-#                     )
-#                 )
-#             ),
-#             type(self),
-#         )
-#         return pv
-#
-#     @model_serializer
-#     def serialize_flyte_dir(self) -> Dict[str, str]:
-#         lv = FlyteDirToMultipartBlobTransformer().to_literal(
-#             FlyteContextManager.current_context(), self, type(self), None
-#         )
-#         return {"path": lv.scalar.blob.uri}
-#
-#     @model_validator(mode="after")
-#     def deserialize_flyte_dir(self, info) -> FlyteDirectory:
-#         if info.context is None or info.context.get("deserialize") is not True:
-#             return self
-#         pv = FlyteDirToMultipartBlobTransformer().to_python_value(
-#             FlyteContextManager.current_context(),
-#             Literal(
-#                 scalar=Scalar(
-#                     blob=Blob(
-#                         metadata=BlobMetadata(
-#                             type=_core_types.BlobType(
-#                                 format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
-#                             )
-#                         ),
-#                         uri=self.path,
-#                     )
-#                 )
-#             ),
-#             type(self),
-#         )
-#         return pv
-#
-#     @model_serializer
-#     def serialize_flyte_schema(self) -> Dict[str, str]:
-#         FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
-#         return {"remote_path": self.remote_path}
-#
-#     @model_validator(mode="after")
-#     def deserialize_flyte_schema(self, info) -> FlyteSchema:
-#         if info.context is None or info.context.get("deserialize") is not True:
-#             return self
-#         t = FlyteSchemaTransformer()
-#         return t.to_python_value(
-#             FlyteContextManager.current_context(),
-#             Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
-#             type(self),
-#         )
-#
-#     @model_serializer
-#     def serialize_structured_dataset(self) -> Dict[str, str]:
-#         lv = StructuredDatasetTransformerEngine().to_literal(
-#             FlyteContextManager.current_context(), self, type(self), None
-#         )
-#         sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
-#         sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
-#         return {
-#             "uri": sd.uri,
-#             "file_format": sd.file_format,
-#         }
-#
-#     @model_validator(mode="after")
-#     def deserialize_structured_dataset(self, info) -> StructuredDataset:
-#         if info.context is None or info.context.get("deserialize") is not True:
-#             return self
-#         return StructuredDatasetTransformerEngine().to_python_value(
-#             FlyteContextManager.current_context(),
-#             Literal(
-#                 scalar=Scalar(
-#                     structured_dataset=StructuredDataset(
-#                         metadata=StructuredDatasetMetadata(
-#                             structured_dataset_type=StructuredDatasetType(format=self.file_format)
-#                         ),
-#                         uri=self.uri,
-#                     )
-#                 )
-#             ),
-#             type(self),
-#         )
-#
-#     setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
-#     setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
-#     setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
-#     setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
-#     setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
-#     setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
-#     setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
-#     setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
-# except ImportError:
-#     pass
+from typing import Dict
+
+from flytekit.core.context_manager import FlyteContextManager
+from flytekit.loggers import logger
+from flytekit.models.core import types as _core_types
+from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
+from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
+from flytekit.types.file import FlyteFile, FlyteFilePathTransformer
+from flytekit.types.schema import FlyteSchema, FlyteSchemaTransformer
+from flytekit.types.structured import (
+    StructuredDataset,
+    StructuredDatasetMetadata,
+    StructuredDatasetTransformerEngine,
+    StructuredDatasetType,
+)
+
+# Conditional import for Pydantic model_serializer and model_validator
+try:
+    from pydantic import model_serializer, model_validator
+
+    from flytekit.extras.pydantic import transformer  # noqa: F401
+
+    # Serialize and Deserialize functions
+    @model_serializer
+    def serialize_flyte_file(self) -> Dict[str, str]:
+        lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
+        return {"path": lv.scalar.blob.uri}
+
+    @model_validator(mode="after")
+    def deserialize_flyte_file(self, info) -> FlyteFile:
+        if info.context is None or info.context.get("deserialize") is not True:
+            return self
+        pv = FlyteFilePathTransformer().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    blob=Blob(
+                        metadata=BlobMetadata(
+                            type=_core_types.BlobType(
+                                format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
+                            )
+                        ),
+                        uri=self.path,
+                    )
+                )
+            ),
+            type(self),
+        )
+        return pv
+
+    @model_serializer
+    def serialize_flyte_dir(self) -> Dict[str, str]:
+        lv = FlyteDirToMultipartBlobTransformer().to_literal(
+            FlyteContextManager.current_context(), self, type(self), None
+        )
+        return {"path": lv.scalar.blob.uri}
+
+    @model_validator(mode="after")
+    def deserialize_flyte_dir(self, info) -> FlyteDirectory:
+        if info.context is None or info.context.get("deserialize") is not True:
+            return self
+        pv = FlyteDirToMultipartBlobTransformer().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    blob=Blob(
+                        metadata=BlobMetadata(
+                            type=_core_types.BlobType(
+                                format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+                            )
+                        ),
+                        uri=self.path,
+                    )
+                )
+            ),
+            type(self),
+        )
+        return pv
+
+    @model_serializer
+    def serialize_flyte_schema(self) -> Dict[str, str]:
+        FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
+        return {"remote_path": self.remote_path}
+
+    @model_validator(mode="after")
+    def deserialize_flyte_schema(self, info) -> FlyteSchema:
+        if info.context is None or info.context.get("deserialize") is not True:
+            return self
+        t = FlyteSchemaTransformer()
+        return t.to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
+            type(self),
+        )
+
+    @model_serializer
+    def serialize_structured_dataset(self) -> Dict[str, str]:
+        lv = StructuredDatasetTransformerEngine().to_literal(
+            FlyteContextManager.current_context(), self, type(self), None
+        )
+        sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
+        sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+        return {
+            "uri": sd.uri,
+            "file_format": sd.file_format,
+        }
+
+    @model_validator(mode="after")
+    def deserialize_structured_dataset(self, info) -> StructuredDataset:
+        if info.context is None or info.context.get("deserialize") is not True:
+            return self
+        return StructuredDatasetTransformerEngine().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    structured_dataset=StructuredDataset(
+                        metadata=StructuredDatasetMetadata(
+                            structured_dataset_type=StructuredDatasetType(format=self.file_format)
+                        ),
+                        uri=self.uri,
+                    )
+                )
+            ),
+            type(self),
+        )
+
+    setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
+    setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
+    setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
+    setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
+    setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
+    setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
+    setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
+    setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
+except ImportError:
+    logger.info("Pydantic V2 not installed, skipping custom serialization/deserialization.")
diff --git a/flytekit/extras/pydantic/__init__.py b/flytekit/extras/pydantic/__init__.py
index e3b2da13ac..a17d8e9c36 100644
--- a/flytekit/extras/pydantic/__init__.py
+++ b/flytekit/extras/pydantic/__init__.py
@@ -1,4 +1,3 @@
-from flytekit.extras.pydantic import custom
 from flytekit.loggers import logger
 
 # TODO: abstract this out so that there's an established pattern for registering plugins
@@ -8,12 +7,7 @@
     # model_validator and model_serializer are only available in pydantic > 2
     from pydantic import model_serializer, model_validator
 
-    _pydantic_installed = True
-except (ImportError, OSError):
-    _pydantic_installed = False
-
-
-if _pydantic_installed:
-    from . import custom, transformer
-else:
+    from . import transformer
+except (ImportError, OSError):
     logger.info("Flytekit only support pydantic version > 2.")
diff --git a/flytekit/extras/pydantic/custom.py b/flytekit/extras/pydantic/custom.py
index caf9466647..74b6a78ff1 100644
--- a/flytekit/extras/pydantic/custom.py
+++ b/flytekit/extras/pydantic/custom.py
@@ -1,8 +1,7 @@
 from typing import Dict
 
-from pydantic import model_serializer, model_validator
-
 from flytekit.core.context_manager import FlyteContextManager
+from flytekit.loggers import logger
 from flytekit.models.core import types as _core_types
 from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
 from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
@@ -15,122 +14,124 @@
     StructuredDatasetType,
 )
 
-
-# Serialize and Deserialize functions
-@model_serializer
-def serialize_flyte_file(self) -> Dict[str, str]:
-    lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
-    return {"path": lv.scalar.blob.uri}
-
-
-@model_validator(mode="after")
-def deserialize_flyte_file(self, info) -> FlyteFile:
-    if info.context is None or info.context.get("deserialize") is not True:
-        print("@@@@ initializing flyte file")
-        return self
-    print("@@@@ convert to flyte file")
-    pv = FlyteFilePathTransformer().to_python_value(
-        FlyteContextManager.current_context(),
-        Literal(
-            scalar=Scalar(
-                blob=Blob(
-                    metadata=BlobMetadata(
-                        type=_core_types.BlobType(
-                            format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
-                        )
-                    ),
-                    uri=self.path,
-                )
-            )
-        ),
-        type(self),
-    )
-    return pv
-
-
-@model_serializer
-def serialize_flyte_dir(self) -> Dict[str, str]:
-    lv = FlyteDirToMultipartBlobTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
-    return {"path": lv.scalar.blob.uri}
-
-
-@model_validator(mode="after")
-def deserialize_flyte_dir(self, info) -> FlyteDirectory:
-    if info.context is None or info.context.get("deserialize") is not True:
-        return self
-    pv = FlyteDirToMultipartBlobTransformer().to_python_value(
-        FlyteContextManager.current_context(),
-        Literal(
-            scalar=Scalar(
-                blob=Blob(
-                    metadata=BlobMetadata(
-                        type=_core_types.BlobType(
-                            format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
-                        )
-                    ),
-                    uri=self.path,
-                )
-            )
-        ),
-        type(self),
-    )
-    return pv
-
-
-@model_serializer
-def serialize_flyte_schema(self) -> Dict[str, str]:
-    FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
-    return {"remote_path": self.remote_path}
-
-
-@model_validator(mode="after")
-def deserialize_flyte_schema(self, info) -> FlyteSchema:
-    if info.context is None or info.context.get("deserialize") is not True:
-        return self
-    t = FlyteSchemaTransformer()
-    return t.to_python_value(
-        FlyteContextManager.current_context(),
-        Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
-        type(self),
-    )
-
-
-@model_serializer
-def serialize_structured_dataset(self) -> Dict[str, str]:
-    lv = StructuredDatasetTransformerEngine().to_literal(FlyteContextManager.current_context(), self, type(self), None)
-    sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
-    sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
-    return {
-        "uri": sd.uri,
-        "file_format": sd.file_format,
-    }
-
-
-@model_validator(mode="after")
-def deserialize_structured_dataset(self, info) -> StructuredDataset:
-    if info.context is None or info.context.get("deserialize") is not True:
-        return self
-    return StructuredDatasetTransformerEngine().to_python_value(
-        FlyteContextManager.current_context(),
-        Literal(
-            scalar=Scalar(
-                structured_dataset=StructuredDataset(
-                    metadata=StructuredDatasetMetadata(
-                        structured_dataset_type=StructuredDatasetType(format=self.file_format)
-                    ),
-                    uri=self.uri,
-                )
-            )
-        ),
-        type(self),
-    )
-
-
-setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
-setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
-setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
-setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
-setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
-setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
-setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
-setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
+try:
+    from pydantic import model_serializer, model_validator
+
+    # Serialize and Deserialize functions
+    @model_serializer
+    def serialize_flyte_file(self) -> Dict[str, str]:
+        lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
+        return {"path": lv.scalar.blob.uri}
+
+    @model_validator(mode="after")
+    def deserialize_flyte_file(self, info) -> FlyteFile:
+        if info.context is None or info.context.get("deserialize") is not True:
+            return self
+        pv = FlyteFilePathTransformer().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    blob=Blob(
+                        metadata=BlobMetadata(
+                            type=_core_types.BlobType(
+                                format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
+                            )
+                        ),
+                        uri=self.path,
+                    )
+                )
+            ),
+            type(self),
+        )
+        return pv
+
+    @model_serializer
+    def serialize_flyte_dir(self) -> Dict[str, str]:
+        lv = FlyteDirToMultipartBlobTransformer().to_literal(
+            FlyteContextManager.current_context(), self, type(self), None
+        )
+        return {"path": lv.scalar.blob.uri}
+
+    @model_validator(mode="after")
+    def deserialize_flyte_dir(self, info) -> FlyteDirectory:
+        if info.context is None or info.context.get("deserialize") is not True:
+            return self
+        pv = FlyteDirToMultipartBlobTransformer().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    blob=Blob(
+                        metadata=BlobMetadata(
+                            type=_core_types.BlobType(
+                                format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+                            )
+                        ),
+                        uri=self.path,
+                    )
+                )
+            ),
+            type(self),
+        )
+        return pv
+
+    @model_serializer
+    def serialize_flyte_schema(self) -> Dict[str, str]:
+        FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
+        return {"remote_path": self.remote_path}
+
+    @model_validator(mode="after")
+    def deserialize_flyte_schema(self, info) -> FlyteSchema:
+        if info.context is None or info.context.get("deserialize") is not True:
+            return self
+        t = FlyteSchemaTransformer()
+        return t.to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
+            type(self),
+        )
+
+    @model_serializer
+    def serialize_structured_dataset(self) -> Dict[str, str]:
+        lv = StructuredDatasetTransformerEngine().to_literal(
+            FlyteContextManager.current_context(), self, type(self), None
+        )
+        sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
+        sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+        return {
+            "uri": sd.uri,
+            "file_format": sd.file_format,
+        }
+
+    @model_validator(mode="after")
+    def deserialize_structured_dataset(self, info) -> StructuredDataset:
+        if info.context is None or info.context.get("deserialize") is not True:
+            return self
+        return StructuredDatasetTransformerEngine().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    structured_dataset=StructuredDataset(
+                        metadata=StructuredDatasetMetadata(
+                            structured_dataset_type=StructuredDatasetType(format=self.file_format)
+                        ),
+                        uri=self.uri,
+                    )
+                )
+            ),
+            type(self),
+        )
+
+    setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
+    setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
+    setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
+    setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
+    setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
+    setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
+    setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
+    setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
+except ImportError:
+    logger.info("Pydantic V2 not installed, skipping custom serialization/deserialization.")
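
For context, a minimal usage sketch of what the serializer/validator hooks in this patch are meant to enable, assuming pydantic v2 is installed and that FlyteFile is accepted as a model field type. The model name, the ConfigDict settings, and the URI below are illustrative only and are not part of the change itself.

from pydantic import BaseModel, ConfigDict

from flytekit.types.file import FlyteFile


class TrainingInput(BaseModel):  # hypothetical model, not defined in this patch
    # arbitrary_types_allowed is an assumption; the patch does not show any model config.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    dataset: FlyteFile


inp = TrainingInput(dataset=FlyteFile("s3://my-bucket/data.csv"))  # illustrative URI

# serialize_flyte_file attached above is expected to run here and emit {"path": <uri>}.
payload = inp.model_dump()

# The validators only rehydrate Flyte types when the validation context carries
# {"deserialize": True}, matching the info.context.get("deserialize") check above.
restored = TrainingInput.model_validate(payload, context={"deserialize": True})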