Pydantic Transformer V2 #2792
base: master
Conversation
Signed-off-by: Future-Outlier <[email protected]>
if lv.scalar.primitive.float_value is not None:
    logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
    return int(lv.scalar.primitive.float_value)
This is for cases where the input comes from the Flyte Console and you use attribute access directly: you have to convert the float to an int. Since JavaScript has only `number`, it can't tell the difference between int and float, and when Golang (propeller) does the attribute access, it doesn't have the expected Python type.
class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32

@workflow
def wf(cfg: TrainConfig) -> TrainConfig:
    return t_args(a=cfg.lr, batch_size=cfg.batch_size)
The JavaScript issue and the attribute access issue are orthogonal, right? This should only be a JavaScript problem; attribute access should work, since msgpack preserves float/int even in attribute access, correct?
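As a quick standalone sanity check of the claim that msgpack preserves the int/float distinction (not part of this PR's code):

```python
import msgpack

# msgpack tags ints and floats differently on the wire, so the distinction
# survives a round trip (unlike JSON/JavaScript, which has a single number type)
packed = msgpack.packb({"a": 1, "b": 1.0})
obj = msgpack.unpackb(packed)
assert isinstance(obj["a"], int)
assert isinstance(obj["b"], float)
```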
Yes, attribute access works well. The problem is that JavaScript passes a float to Golang, and Golang passes that float on to Python.
> this should only be a javascript problem. attribute access should work since msgpack preserves float/int even in attribute access correct?

Yes, but when you access a simple type, you have to change the behavior of SimpleTransformer. For the Pydantic transformer, we pass strict=False to model_validate so Pydantic coerces the value to the right type.
def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
    if binary_idl_object.tag == MESSAGEPACK:
        dict_obj = msgpack.loads(binary_idl_object.value)
        python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
        return python_val
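For illustration, a minimal standalone sketch of what strict=False buys here (the model and values are hypothetical):

```python
from pydantic import BaseModel

class TrainConfig(BaseModel):  # hypothetical model for illustration
    lr: float = 1e-3
    batch_size: int = 32

# A console-originated payload where the int arrived as a float.
# strict=True would reject 32.0 for an int field; strict=False (lax mode)
# coerces an integral float back to an int.
cfg = TrainConfig.model_validate({"lr": 0.01, "batch_size": 32.0}, strict=False)
assert cfg.batch_size == 32
assert isinstance(cfg.batch_size, int)
```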
So we can delete this part after console is updated right?
If we can guarantee the console generates an integer rather than a float for integer inputs, then we can delete it.
how is this going to work though? Do we also do a version check of the backend?
After the console does the right thing, won't this value come in through the binary value instead of lv.scalar.primitive.integer/float?
@lukas503
Hi @Future-Outlier, thanks for working on the Pydantic transformer! I've been testing the code locally and wondered about a behavior related to caching. Specifically, I'm curious whether schema changes invalidate the cache. Here's an example:

```python
from flytekit import task, workflow
from pydantic import BaseModel

class Config(BaseModel):
    x: int = 1
    # y: int = 4

@task(cache=True, cache_version="v1")
def task1(val: int) -> Config:
    return Config()

@task(cache=True, cache_version="v1")
def task2(cfg: Config) -> Config:
    print("CALLED!", cfg)
    return cfg

@workflow
def my_workflow():
    config = task1(val=5)
    task2(cfg=config)

if __name__ == "__main__":
    print(Config.model_json_schema())
    my_workflow()
```

When I run the workflow for the first time, nothing is cached. On the second run, the results are cached, as expected. However, if I uncomment `y: int = 4`, the cached results are still used. Is this the expected behavior? Shouldn't schema changes like this invalidate the cache?
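To illustrate the question, here is a hedged sketch (not flytekit's actual implementation): if a local cache key were derived only from the task name, cache version, and inputs, then an output-schema change would not change the key, and stale outputs would still be served.

```python
import hashlib
import json

def cache_key(task_name: str, cache_version: str, inputs: dict) -> str:
    # Hypothetical key: the *output* schema is not part of the hash,
    # so adding `y: int = 4` to Config would not invalidate old entries.
    payload = json.dumps(
        {"task": task_name, "version": cache_version, "inputs": inputs},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()

# Same inputs -> same key, regardless of what the task now returns.
assert cache_key("task1", "v1", {"val": 5}) == cache_key("task1", "v1", {"val": 5})
# Bumping cache_version is the explicit way to invalidate.
assert cache_key("task1", "v1", {"val": 5}) != cache_key("task1", "v2", {"val": 5})
```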
good question, will test this out and ask other maintainers if I don't know what happened, thank you <3

@lukas503
Codecov Report — Attention: Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #2792      +/-   ##
===========================================
+ Coverage   45.53%   76.31%   +30.77%
===========================================
  Files         196      199        +3
  Lines       20418    20743      +325
  Branches     2647     2666       +19
===========================================
+ Hits         9298    15829     +6531
+ Misses      10658     4200     -6458
- Partials      462      714      +252
```

☔ View full report in Codecov by Sentry.
if lv.scalar:
    if lv.scalar.binary:
        return self.from_binary_idl(lv.scalar.binary, expected_python_type)
    if lv.scalar.generic:
        return self.from_generic_idl(lv.scalar.generic, expected_python_type)
class DC(BaseModel):
    ff: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))

@task(container_image=image)
def t_args(dc: DC) -> DC:
    with open(dc.ff, "r") as f:
        print(f.read())
    return dc

@task(container_image=image)
def t_ff(ff: FlyteFile) -> FlyteFile:
    with open(ff, "r") as f:
        print(f.read())
    return ff

@workflow
def wf(dc: DC) -> DC:
    t_ff(dc.ff)
    return t_args(dc=dc)
This handles the case where the input comes from Flyte Console.
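A standalone sketch of the Struct-to-BaseModel path that scalar.generic handling enables (the model and field names are hypothetical; the real conversion lives in the transformer's from_generic_idl):

```python
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct
from pydantic import BaseModel

class DC(BaseModel):  # hypothetical model for illustration
    path: str = "s3://my-s3-bucket/example.txt"

# Flyte Console sends inputs as a protobuf Struct (scalar.generic);
# converting it to a plain dict lets Pydantic rebuild the model.
s = Struct()
s.update({"path": "s3://my-s3-bucket/other.txt"})
dc = DC.model_validate(json_format.MessageToDict(s), strict=False)
assert dc.path == "s3://my-s3-bucket/other.txt"
```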
Thanks for updating the PR. I now understand the underlying issue better. It appears the caching mechanism is ignoring the output types/schema. What's unclear to me is why the output types/schema aren't factored into the hash used for caching. In my opinion, any interface change could invalidate the cache, even changes to the outputs; I don't see how the old cached outputs can remain valid after an interface change. That said, this concern isn't directly related to the current PR, so feel free to proceed as is.

Update: It works as expected when remote Flyte is used. The faulty behavior I described happens only locally.
todo: add this to flytesnacks example

# Flytekit Pydantic Plugin

Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations.

The plugin adds type support for pydantic models.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-pydantic-v2
```

## Type Example

from enum import Enum
import os
from typing import Dict, List, Optional

import pandas as pd
from pydantic import BaseModel, Field

from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset
from flytekit.types.file import FlyteFile
from flytekit.types.directory import FlyteDirectory
from flytekit import task, workflow, ImageSpec


image = ImageSpec(
    packages=["flytekitplugins-pydantic-v2", "pandas", "pyarrow"],
    registry="localhost:30000",
)


class Status(Enum):
    PENDING = "pending"
    APPROVED = "approved"
    REJECTED = "rejected"


class InnerBM(BaseModel):
    a: int = -1
    b: float = 2.1
    c: str = "Hello, Flyte"
    d: bool = False
    e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2])
    f: List[FlyteFile] = Field(default_factory=lambda: [FlyteFile("s3://my-s3-bucket/example.txt")])
    g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]])
    h: List[Dict[int, bool]] = Field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
    i: Dict[int, bool] = Field(default_factory=lambda: {0: False, 1: True, -1: False})
    j: Dict[int, FlyteFile] = Field(default_factory=lambda: {0: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                            1: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                            -1: FlyteFile("s3://my-s3-bucket/example.txt")})
    k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]})
    l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}})
    m: dict = Field(default_factory=lambda: {"key": "value"})
    n: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))
    o: FlyteDirectory = Field(default_factory=lambda: FlyteDirectory("s3://my-s3-bucket/s3_flyte_dir"))
    enum_status: Status = Status.PENDING
    sd: StructuredDataset = Field(default_factory=lambda: StructuredDataset(
        uri="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/7f31035fdf92510e40ee9340f9e5bf34",
        file_format="parquet"))
    fsc: FlyteSchema = Field(default_factory=lambda: FlyteSchema(
        remote_path="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/ab3aef21302d0529daef8c43825c3fdf"))


class BM(BaseModel):
    a: int = -1
    b: float = 2.1
    c: str = "Hello, Flyte"
    d: bool = False
    e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2])
    f: List[FlyteFile] = Field(default_factory=lambda: [FlyteFile("s3://my-s3-bucket/example.txt")])
    g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]])
    h: List[Dict[int, bool]] = Field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
    i: Dict[int, bool] = Field(default_factory=lambda: {0: False, 1: True, -1: False})
    j: Dict[int, FlyteFile] = Field(default_factory=lambda: {0: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                            1: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                            -1: FlyteFile("s3://my-s3-bucket/example.txt")})
    k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]})
    l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}})
    m: dict = Field(default_factory=lambda: {"key": "value"})
    n: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))
    o: FlyteDirectory = Field(default_factory=lambda: FlyteDirectory("s3://my-s3-bucket/s3_flyte_dir"))
    inner_dc: InnerBM = Field(default_factory=lambda: InnerBM())
    enum_status: Status = Status.PENDING
    sd: StructuredDataset = Field(default_factory=lambda: StructuredDataset(
        uri="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/7f31035fdf92510e40ee9340f9e5bf34",
        file_format="parquet"))
    fsc: FlyteSchema = Field(default_factory=lambda: FlyteSchema(
        remote_path="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/ab3aef21302d0529daef8c43825c3fdf"))


@task(container_image=image)
def t_dc(dc: BM) -> BM:
    return dc


@task(container_image=image)
def t_inner(inner_dc: InnerBM):
    assert isinstance(inner_dc, InnerBM)
    expected_file_content = "Default content"

    # f: List[FlyteFile]
    for ff in inner_dc.f:
        assert isinstance(ff, FlyteFile)
        with open(ff, "r") as f:
            assert f.read() == expected_file_content
    # j: Dict[int, FlyteFile]
    for _, ff in inner_dc.j.items():
        assert isinstance(ff, FlyteFile)
        with open(ff, "r") as f:
            assert f.read() == expected_file_content
    # n: FlyteFile
    assert isinstance(inner_dc.n, FlyteFile)
    with open(inner_dc.n, "r") as f:
        assert f.read() == expected_file_content
    # o: FlyteDirectory
    assert isinstance(inner_dc.o, FlyteDirectory)
    assert not inner_dc.o.downloaded
    with open(os.path.join(inner_dc.o, "example.txt"), "r") as fh:
        assert fh.read() == expected_file_content
    assert inner_dc.o.downloaded
    print("Test InnerBM Successfully Passed")
    # enum: Status
    assert inner_dc.enum_status == Status.PENDING


@task(container_image=image)
def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]],
                          h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile],
                          k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict,
                          n: FlyteFile, o: FlyteDirectory,
                          enum_status: Status,
                          sd: StructuredDataset,
                          fsc: FlyteSchema,
                          ):
    # Strict type checks for simple types
    assert isinstance(a, int), f"a is not int, it's {type(a)}"
    assert a == -1
    assert isinstance(b, float), f"b is not float, it's {type(b)}"
    assert isinstance(c, str), f"c is not str, it's {type(c)}"
    assert isinstance(d, bool), f"d is not bool, it's {type(d)}"
    # Strict type checks for List[int]
    assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]"
    # Strict type checks for List[FlyteFile]
    assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]"
    # Strict type checks for List[List[int]]
    assert isinstance(g, list) and all(
        isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]"
    # Strict type checks for List[Dict[int, bool]]
    assert isinstance(h, list) and all(
        isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h
    ), "h is not List[Dict[int, bool]]"
    # Strict type checks for Dict[int, bool]
    assert isinstance(i, dict) and all(
        isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]"
    # Strict type checks for Dict[int, FlyteFile]
    assert isinstance(j, dict) and all(
        isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]"
    # Strict type checks for Dict[int, List[int]]
    assert isinstance(k, dict) and all(
        isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in
        k.items()), "k is not Dict[int, List[int]]"
    # Strict type checks for Dict[int, Dict[int, int]]
    assert isinstance(l, dict) and all(
        isinstance(k, int) and isinstance(v, dict) and all(
            isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items())
        for k, v in l.items()), "l is not Dict[int, Dict[int, int]]"
    # Strict type check for a generic dict
    assert isinstance(m, dict), "m is not dict"
    # Strict type check for FlyteFile
    assert isinstance(n, FlyteFile), "n is not FlyteFile"
    # Strict type check for FlyteDirectory
    assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory"
    # Strict type check for Enum
    assert isinstance(enum_status, Status), "enum_status is not Status"
    assert isinstance(sd, StructuredDataset), "sd is not StructuredDataset"
    print("sd:", sd.open(pd.DataFrame).all())
    assert isinstance(fsc, FlyteSchema), "fsc is not FlyteSchema"
    print("fsc: ", fsc.open().all())
    print("All attributes passed strict type checks.")


@workflow
def wf(dc: BM):
    t_dc(dc=dc)
    t_inner(inner_dc=dc.inner_dc)
    t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c,
                          d=dc.d, e=dc.e, f=dc.f,
                          g=dc.g, h=dc.h, i=dc.i,
                          j=dc.j, k=dc.k, l=dc.l,
                          m=dc.m, n=dc.n, o=dc.o,
                          enum_status=dc.enum_status,
                          sd=dc.sd,
                          fsc=dc.fsc,
                          )
    t_test_all_attributes(a=dc.inner_dc.a, b=dc.inner_dc.b, c=dc.inner_dc.c,
                          d=dc.inner_dc.d, e=dc.inner_dc.e, f=dc.inner_dc.f,
                          g=dc.inner_dc.g, h=dc.inner_dc.h, i=dc.inner_dc.i,
                          j=dc.inner_dc.j, k=dc.inner_dc.k, l=dc.inner_dc.l,
                          m=dc.inner_dc.m, n=dc.inner_dc.n, o=dc.inner_dc.o,
                          enum_status=dc.inner_dc.enum_status,
                          sd=dc.inner_dc.sd,
                          fsc=dc.inner_dc.fsc,
                          )
.github/workflows/pythonbuild.yml
Outdated
# TODO: remove pydantic v1 plugin, since v2 is in core already
# flytekit-pydantic
Since we are going to remove the pydantic v1 plugin in the future, and this will fail when the pydantic version is > 2 (CI uses pydantic > 2), let's comment it out.
Can we remove it?
flytekit/__init__.py
Outdated
from flytekit.deck import Deck
from flytekit.extras import pydantic
If we move to lazily importing the transformer, the custom serialize and deserialize behavior fails to apply; still investigating the root cause.
Signed-off-by: Future-Outlier <[email protected]> Co-authored-by: pingsutw <[email protected]>
Just a few minor things, otherwise, it's looking pretty good.
flytekit/extras/pydantic/__init__.py
Outdated
logger.info(f"Meet error when importing pydantic: `{e}`")
logger.info("Flytekit only support pydantic version > 2.")
nit: those should be warnings.
from pydantic import model_serializer, model_validator

except ImportError:
    logger.info(
ditto.
FuncType = TypeVar("FuncType", bound=Callable[..., Any])

from typing_extensions import Literal as typing_literal

def model_serializer(
    __f: Union[Callable[..., Any], None] = None,
    *,
    mode: typing_literal["plain", "wrap"] = "plain",
    when_used: typing_literal["always", "unless-none", "json", "json-unless-none"] = "always",
    return_type: Any = None,
) -> Callable[[Any], Any]:
    """Placeholder decorator for Pydantic model_serializer."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args, **kwargs):
            raise Exception(
                "Pydantic is not installed.\n"
                "Please install Pydantic version > 2 to use this feature."
            )

        return wrapper

    # If no function (__f) is provided, return the decorator
    if __f is None:
        return decorator
    # If __f is provided, directly decorate the function
    return decorator(__f)

def model_validator(
    *,
    mode: typing_literal["wrap", "before", "after"],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Placeholder decorator for Pydantic model_validator."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args, **kwargs):
            raise Exception(
                "Pydantic is not installed.\n"
                "Please install Pydantic version > 2 to use this feature."
            )

        return wrapper

    return decorator
Aren't we supporting only pydantic v2? Why do we have these fallback definitions?
We want the code here to work; this syntax is more readable than setattr:
https://github.com/flyteorg/flytekit/pull/2792/files#diff-22cf9c7153b54371b4a77331ddf276a082cf4b3c5e7bd1595dd67232288594fdR168-R176
It's to support the case where pydantic is not installed at all. It looks nicer in the real File/Directory class, but we also want it to not fail, of course.
flytekit/core/type_engine.py
Outdated
@@ -215,16 +215,32 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
        )

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
        """
        TODO: Add more comments to explain the lifecycle of attribute access.
fill in TODO
@@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

-plugin_requires = ["flytekit>=1.7.0b0", "pydantic"]
+plugin_requires = ["flytekit>=1.7.0b0", "pydantic<2"]
Can we also leave a warning in the README.md explaining that we're deprecating this plugin?
> how is this going to work though? Do we also do a version check of the backend?

No, this is just for supporting the case I've mentioned above. We didn't support this before, and I think we should.
Let's merge it.
🙇 thank you @Future-Outlier 🙏 this is going to be fantastic.
@@ -1124,6 +1194,8 @@ def lazy_import_transformers(cls):
        from flytekit.extras import pytorch  # noqa: F401
    if is_imported("sklearn"):
        from flytekit.extras import sklearn  # noqa: F401
    if is_imported("pydantic"):
        from flytekit.extras import pydantic  # noqa: F401
Can we change the name of this folder? pydantic can get confusing because the real library is also called pydantic, right?
@@ -2194,6 +2304,34 @@ def _check_and_covert_float(lv: Literal) -> float:
    raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float")


def _handle_flyte_console_float_input_to_int(lv: Literal) -> int:
    """
    Flyte Console is written by JavaScript and JavaScript has only one number type which is float.
Technically, JavaScript's number type is Number, but yeah, sometimes it keeps track of trailing 0s and sometimes it doesn't.
super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)

def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
    schema = t.model_json_schema()
In a future PR, can we add some unit tests to ensure that we're correctly extracting default values into the schema?
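A sketch of such a unit test (the model is hypothetical), relying on the fact that model_json_schema() records simple field defaults under each property:

```python
from pydantic import BaseModel

class Config(BaseModel):  # hypothetical model for illustration
    x: int = 1
    name: str = "flyte"

schema = Config.model_json_schema()
# Simple defaults show up under properties.<field>.default,
# which the transformer could surface in the Flyte literal type.
assert schema["properties"]["x"]["default"] == 1
assert schema["properties"]["name"]["default"] == "flyte"
```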
bm = BM()
wf(bm=bm)
can you add two new lines between tests? I know it doesn't matter, but pycharm complains.
def test_flytetypes_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory):
    class InnerBM(BaseModel):
        flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
A little bit confused about pydantic here: are you supposed to use dataclasses.field here instead of pydantic.Field?
def test_protocol():
    assert get_protocol("s3://my-s3-bucket/file") == "s3"
    assert get_protocol("/file") == "file"


def generate_pandas() -> pd.DataFrame:
    return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})
keep spaces plz
flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
flytedir: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))

class BM(BaseModel):
Maybe in a different file, but can we add pydantic models that contain dataclasses, and also dataclasses that contain pydantic models? I know some people have been asking for that; it'd be good to have some tests for it. Thank you!
Tracking issue
flyteorg/flyte#5033
flyteorg/flyte#5318
How to test it by others

Not sure:
- Which pydantic version should we use as the lower bound?
- This case will fail in the Flyte Console (file tree structure).

Why didn't we integrate with pydantic v1 BaseModel (i.e., let you run v1 and v2 BaseModels at the same time)?

This is an issue from pydantic: pydantic/pydantic#9919. If it is fixed, we can support both pydantic v1 and v2 at the same time.

Story: Kevin and I wanted to support v1 and v2 at the same time, but after learning that this would take a lot of time, we asked Ketan for advice; he said that if users want it, we can try to support it or tell users to support it themselves.
Why are the changes needed?

Flyte Console input needs to handle flyte types. When handling the input below, and when doing attribute access on a flyte type, we need to teach flyte types how to convert a protobuf struct into flyte types. Take FlyteFile as an example (lifecycle).

Flyte Console input needs to handle the float issue. This is needed in the following example.

For enum classes: I tried BaseModel -> dict obj -> msgpack bytes first. To make this happen, you need to call BaseModel.model_dump, but that function can't interpret Enum. However, BaseModel.model_dump_json can.

Why @model_serializer and @model_validator(mode="after")? You can understand them as _serialize and _deserialize in FlyteTypes, which use SerializableType to customize the serialize/deserialize behavior for flyte types. Related PRs: #2554
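The Enum limitation above can be seen directly in a standalone sketch (the model is hypothetical):

```python
from enum import Enum
from pydantic import BaseModel

class Status(Enum):
    PENDING = "pending"

class M(BaseModel):  # hypothetical model for illustration
    s: Status = Status.PENDING

# model_dump() keeps the raw Enum member, which msgpack can't serialize directly
assert M().model_dump()["s"] is Status.PENDING
# model_dump(mode="json") (like model_dump_json) emits the plain value instead
assert M().model_dump(mode="json")["s"] == "pending"
```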
What changes were proposed in this pull request?

Note: we don't support a pydantic BaseModel containing a dataclass with FlyteTypes. We do support a pydantic BaseModel containing a dataclass with primitive types.

How was this patch tested?

Example code (nested cases, flyte types, and attribute access).

Setup process

Local and remote execution, with ImageSpec for the docker image.

Screenshots

Check all the applicable boxes