Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier committed Oct 22, 2024
1 parent 971aa47 commit ff2d4a0
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 261 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ repos:
hooks:
# Run the linter.
- id: ruff
args: [--fix, --show-fixes, --output-format=full]
args: [--fix, --show-fixes, --output-format=full, --extend-ignore=I001]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ torch<=1.12.1; python_version<'3.11'
# pytorch 2 supports python 3.11
# pytorch 2 does not support 3.12 yet: https://github.com/pytorch/pytorch/issues/110436
torch; python_version<'3.12'
pydantic

# TODO: Currently, the python-magic library causes build errors on Windows due to its dependency on DLLs for libmagic.
# We have temporarily disabled this feature on Windows and are using python-magic for Mac OS and Linux instead.
Expand Down
3 changes: 3 additions & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
else:
from importlib.metadata import entry_points


from flytekit._version import __version__
from flytekit.configuration import Config
from flytekit.core.array_node_map_task import map_task
Expand Down Expand Up @@ -263,6 +264,8 @@
StructuredDatasetType,
)

from flytekit.core import type_utils # noqa: I001


def current_context() -> ExecutionParameters:
"""
Expand Down
275 changes: 142 additions & 133 deletions flytekit/core/type_utils.py
Original file line number Diff line number Diff line change
@@ -1,133 +1,142 @@
# from typing import Dict
#
# from flytekit.core.context_manager import FlyteContextManager
# from flytekit.models.core import types as _core_types
# from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
# from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
# from flytekit.types.file import FlyteFile, FlyteFilePathTransformer
# from flytekit.types.schema import FlyteSchema, FlyteSchemaTransformer
# from flytekit.types.structured import (
# StructuredDataset,
# StructuredDatasetMetadata,
# StructuredDatasetTransformerEngine,
# StructuredDatasetType,
# )
#
# # Conditional import for Pydantic model_serializer and model_validator
# try:
# from pydantic import model_serializer, model_validator
#
# # Serialize and Deserialize functions
# @model_serializer
# def serialize_flyte_file(self) -> Dict[str, str]:
# lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
# return {"path": lv.scalar.blob.uri}
#
# @model_validator(mode="after")
# def deserialize_flyte_file(self, info) -> FlyteFile:
# if info.context is None or info.context.get("deserialize") is not True:
# return self
# pv = FlyteFilePathTransformer().to_python_value(
# FlyteContextManager.current_context(),
# Literal(
# scalar=Scalar(
# blob=Blob(
# metadata=BlobMetadata(
# type=_core_types.BlobType(
# format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
# )
# ),
# uri=self.path,
# )
# )
# ),
# type(self),
# )
# return pv
#
# @model_serializer
# def serialize_flyte_dir(self) -> Dict[str, str]:
# lv = FlyteDirToMultipartBlobTransformer().to_literal(
# FlyteContextManager.current_context(), self, type(self), None
# )
# return {"path": lv.scalar.blob.uri}
#
# @model_validator(mode="after")
# def deserialize_flyte_dir(self, info) -> FlyteDirectory:
# if info.context is None or info.context.get("deserialize") is not True:
# return self
# pv = FlyteDirToMultipartBlobTransformer().to_python_value(
# FlyteContextManager.current_context(),
# Literal(
# scalar=Scalar(
# blob=Blob(
# metadata=BlobMetadata(
# type=_core_types.BlobType(
# format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
# )
# ),
# uri=self.path,
# )
# )
# ),
# type(self),
# )
# return pv
#
# @model_serializer
# def serialize_flyte_schema(self) -> Dict[str, str]:
# FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
# return {"remote_path": self.remote_path}
#
# @model_validator(mode="after")
# def deserialize_flyte_schema(self, info) -> FlyteSchema:
# if info.context is None or info.context.get("deserialize") is not True:
# return self
# t = FlyteSchemaTransformer()
# return t.to_python_value(
# FlyteContextManager.current_context(),
# Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
# type(self),
# )
#
# @model_serializer
# def serialize_structured_dataset(self) -> Dict[str, str]:
# lv = StructuredDatasetTransformerEngine().to_literal(
# FlyteContextManager.current_context(), self, type(self), None
# )
# sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
# sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
# return {
# "uri": sd.uri,
# "file_format": sd.file_format,
# }
#
# @model_validator(mode="after")
# def deserialize_structured_dataset(self, info) -> StructuredDataset:
# if info.context is None or info.context.get("deserialize") is not True:
# return self
# return StructuredDatasetTransformerEngine().to_python_value(
# FlyteContextManager.current_context(),
# Literal(
# scalar=Scalar(
# structured_dataset=StructuredDataset(
# metadata=StructuredDatasetMetadata(
# structured_dataset_type=StructuredDatasetType(format=self.file_format)
# ),
# uri=self.uri,
# )
# )
# ),
# type(self),
# )
#
# setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
# setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
# setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
# setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
# setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
# setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
# setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
# setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
# except ImportError:
# pass
from typing import Dict

from flytekit.core.context_manager import FlyteContextManager
from flytekit.loggers import logger
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from flytekit.types.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.schema import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured import (
StructuredDataset,
StructuredDatasetMetadata,
StructuredDatasetTransformerEngine,
StructuredDatasetType,
)

# Conditional import for Pydantic model_serializer and model_validator
try:
from pydantic import model_serializer, model_validator

from flytekit.extras.pydantic import transformer # noqa: F401

# Serialize and Deserialize functions
@model_serializer
def serialize_flyte_file(self) -> Dict[str, str]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}

Check warning on line 27 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L26-L27

Added lines #L26 - L27 were not covered by tests

@model_validator(mode="after")
def deserialize_flyte_file(self, info) -> FlyteFile:
if info.context is None or info.context.get("deserialize") is not True:
print("@@@@ Initializing FlyteFile ")
return self

Check warning on line 33 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L32-L33

Added lines #L32 - L33 were not covered by tests
# print("@@@@ Deserializing FlyteFile ")
pv = FlyteFilePathTransformer().to_python_value(

Check warning on line 35 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L35

Added line #L35 was not covered by tests
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
uri=self.path,
)
)
),
type(self),
)
return pv

Check warning on line 51 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L51

Added line #L51 was not covered by tests

@model_serializer
def serialize_flyte_dir(self) -> Dict[str, str]:
lv = FlyteDirToMultipartBlobTransformer().to_literal(

Check warning on line 55 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L55

Added line #L55 was not covered by tests
FlyteContextManager.current_context(), self, type(self), None
)
return {"path": lv.scalar.blob.uri}

Check warning on line 58 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L58

Added line #L58 was not covered by tests

@model_validator(mode="after")
def deserialize_flyte_dir(self, info) -> FlyteDirectory:
if info.context is None or info.context.get("deserialize") is not True:
return self
pv = FlyteDirToMultipartBlobTransformer().to_python_value(

Check warning on line 64 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L63-L64

Added lines #L63 - L64 were not covered by tests
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
)
),
uri=self.path,
)
)
),
type(self),
)
return pv

Check warning on line 80 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L80

Added line #L80 was not covered by tests

@model_serializer
def serialize_flyte_schema(self) -> Dict[str, str]:
FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"remote_path": self.remote_path}

Check warning on line 85 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L84-L85

Added lines #L84 - L85 were not covered by tests

@model_validator(mode="after")
def deserialize_flyte_schema(self, info) -> FlyteSchema:
if info.context is None or info.context.get("deserialize") is not True:
print("@@@@ INITIALIZING FLYTE SCHEMA")
return self
t = FlyteSchemaTransformer()
return t.to_python_value(

Check warning on line 93 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L90-L93

Added lines #L90 - L93 were not covered by tests
FlyteContextManager.current_context(),
Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
type(self),
)

@model_serializer
def serialize_structured_dataset(self) -> Dict[str, str]:
lv = StructuredDatasetTransformerEngine().to_literal(

Check warning on line 101 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L101

Added line #L101 was not covered by tests
FlyteContextManager.current_context(), self, type(self), None
)
sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
return {

Check warning on line 106 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L104-L106

Added lines #L104 - L106 were not covered by tests
"uri": sd.uri,
"file_format": sd.file_format,
}

@model_validator(mode="after")
def deserialize_structured_dataset(self, info) -> StructuredDataset:
if info.context is None or info.context.get("deserialize") is not True:
print("@@@@ INITIALIZING SD")
return self
return StructuredDatasetTransformerEngine().to_python_value(

Check warning on line 116 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L114-L116

Added lines #L114 - L116 were not covered by tests
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=self.file_format)
),
uri=self.uri,
)
)
),
type(self),
)

setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
except ImportError as e:
logger.info("Pydantic V2 not installed, skipping custom serialization/deserialization.")
print("e: ", e)
pass

Check warning on line 142 in flytekit/core/type_utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_utils.py#L139-L142

Added lines #L139 - L142 were not covered by tests
12 changes: 3 additions & 9 deletions flytekit/extras/pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from flytekit.extras.pydantic import custom
from flytekit.loggers import logger

# TODO: abstract this out so that there's an established pattern for registering plugins
Expand All @@ -8,12 +7,7 @@
# model_validator and model_serializer are only available in pydantic > 2
from pydantic import model_serializer, model_validator

_pydantic_installed = True
except (ImportError, OSError):
_pydantic_installed = False


if _pydantic_installed:
from . import custom, transformer
else:
from . import transformer
except (ImportError, OSError) as e:
logger.info("Flytekit only support pydantic version > 2.")
print("error: ", e)

Check warning on line 13 in flytekit/extras/pydantic/__init__.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/pydantic/__init__.py#L11-L13

Added lines #L11 - L13 were not covered by tests
Loading

0 comments on commit ff2d4a0

Please sign in to comment.