
Pydantic Transformer V2 #2792

Open · wants to merge 100 commits into base: master
Conversation

Future-Outlier (Member) commented Oct 8, 2024

Tracking issue

flyteorg/flyte#5033
flyteorg/flyte#5318

How to test it by others

  1. git clone https://github.com/flyteorg/flytekit
  2. gh pr checkout 2792
  3. make setup-global-uv
  4. cd plugins/flytekit-pydantic-v2 && pip install -e .
  5. test a workflow example

Not Sure

Which pydantic version should we use as the lower bound?

This case will fail in the Flyte Console

@dataclass
class DC:
    a: Union[bool, str, int]
    b: Union[bool, str, int]

@task(container_image=image)
def add(dc1: DC, dc2: DC) -> Union[bool, int, str]:
    return dc1.a + dc2.b  # type: ignore

# input from flyte console to generate generic protobuf struct
# "{\"a\": 1, \"b\": 2}",
@workflow
def wf(dc: DC) -> Union[bool, int, str]:
    return add(dc1=dc, dc2=dc)

file tree structure

  1. The file tree structure is the same as flytekit-pydantic

Why didn't we integrate with pydantic v1 BaseModel? (i.e., let you run v1 and v2 BaseModel at the same time)

This is blocked by an upstream pydantic issue:
pydantic/pydantic#9919
Once it is fixed, we can support pydantic v1 and v2 at the same time.

story:
Kevin and I originally wanted to support v1 and v2 simultaneously, but after realizing it would take a lot of time, we asked Ketan for advice. He said that if users want it, we can either try to support it ourselves or guide users to do it.

Why are the changes needed?

  1. why from_generic_idl?
    To handle flyte types in input coming from the Flyte console.

When handling the input below and doing attribute access on a flyte type, we need to teach flyte types how to convert a protobuf struct into a flyte type.
Take FlyteFile as an example.

lifecycle

json str -> protobuf struct -> attribute access on a flyte type -> send to downstream input
class DC(BaseModel):
    ff: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))

@workflow
def wf(dc: DC) -> DC:
    t_ff(dc.ff)
    return t_args(dc=dc)

# console input: {"ff":{"path":"s3://my-s3-bucket/example.txt"}}
  1. why _check_and_covert_int in the int transformer?
    To handle the float issue with input from the Flyte console.
    It is needed in the following example.

json str -> protobuf struct -> attribute access yields a float (due to JavaScript's single number type) -> convert the float back to int in flytekit
class DC(BaseModel):
    a: int = -1

@workflow
def wf(dc: DC):
    t_int(input_int=dc.a)
  1. why basemodel -> json str -> dict obj -> msgpack bytes?
    For Enum classes.

I tried basemodel -> dict obj -> msgpack bytes first.
To do that, you have to call BaseModel.model_dump, but model_dump can't serialize Enum members.
However, BaseModel.model_dump_json can.
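The difference can be seen in a small standalone example (Status here mirrors the enum used later in this PR):

```python
import json
from enum import Enum
from pydantic import BaseModel

class Status(Enum):
    PENDING = "pending"

class BM(BaseModel):
    s: Status = Status.PENDING

# model_dump keeps the raw Enum member, which msgpack can't serialize directly
assert isinstance(BM().model_dump()["s"], Status)

# model_dump_json turns the Enum into its value, so the resulting dict
# round-trips cleanly through msgpack
assert json.loads(BM().model_dump_json()) == {"s": "pending"}
```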

  1. What are @model_serializer and @model_validator(mode="after")?
    You can think of them as the _serialize and _deserialize methods in FlyteTypes, which use SerializableType to customize serialization/deserialization behavior for flyte types.
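As a minimal sketch of the two decorators (RemoteFile is an illustrative stand-in, not the actual FlyteFile implementation):

```python
from pydantic import BaseModel, model_serializer, model_validator

class RemoteFile(BaseModel):
    path: str

    @model_serializer
    def _serialize(self) -> dict:
        # controls what model_dump()/model_dump_json() emit,
        # analogous to SerializableType._serialize
        return {"path": self.path}

    @model_validator(mode="after")
    def _deserialize(self) -> "RemoteFile":
        # runs after field validation; a hook to rehydrate remote state,
        # analogous to SerializableType._deserialize
        return self

assert RemoteFile(path="s3://bucket/x.txt").model_dump() == {"path": "s3://bucket/x.txt"}
```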

Related PRs: #2554

What changes were proposed in this pull request?

  • attribute access (primitives and flyte types; datetime support is uncertain)
  • flyte types
  • nested cases
  • dataclasses.dataclass in pydantic.BaseModel
  • pydantic.dataclass in pydantic.BaseModel
  • pydantic.BaseModel in pydantic.BaseModel

note: we don't support a pydantic BaseModel that contains a dataclass with FlyteTypes.
We do support a pydantic BaseModel that contains a dataclass with primitive types.

@dataclass
class dc:
    ff: FlyteFile

class DC(BaseModel):
    inner_dc: dc

# This is not supported
# ============================
@dataclass
class dc:
    a: int

class DC(BaseModel):
    inner_dc: dc

# This is supported
# ============================

How was this patch tested?

Example code.
(nested cases, flyte types, and attribute access.)

from pydantic import BaseModel, Field
from typing import Dict, List, Optional

from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset
from flytekit.types.file import FlyteFile
from flytekit.types.directory import FlyteDirectory
from flytekit import task, workflow, ImageSpec, kwtypes
from enum import Enum
import os
import pandas as pd

flytekit_hash = "fb82dd521615039f626c78489b2e83259d7db2a5"
flytekit = f"git+https://github.com/flyteorg/flytekit.git@{flytekit_hash}"
pydantic_plugin = f"git+https://github.com/flyteorg/flytekit.git@{flytekit_hash}#subdirectory=plugins/flytekit-pydantic-v2"

# Define custom image for the task
image = ImageSpec(packages=[
                            flytekit,
                            pydantic_plugin,
                            "pandas",
                            "pyarrow"],
                            apt_packages=["git"],
                            registry="localhost:30000",
                         )

class Status(Enum):
    PENDING = "pending"
    APPROVED = "approved"
    REJECTED = "rejected"

class InnerBM(BaseModel):
    a: int = -1
    b: float = 2.1
    c: str = "Hello, Flyte"
    d: bool = False
    e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2])
    f: List[FlyteFile] = Field(default_factory=lambda: [FlyteFile("s3://my-s3-bucket/example.txt")])
    g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]])
    h: List[Dict[int, bool]] = Field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
    i: Dict[int, bool] = Field(default_factory=lambda: {0: False, 1: True, -1: False})
    j: Dict[int, FlyteFile] = Field(default_factory=lambda: {0: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             1: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             -1: FlyteFile("s3://my-s3-bucket/example.txt")})
    k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]})
    l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}})
    m: dict = Field(default_factory=lambda: {"key": "value"})
    n: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))
    o: FlyteDirectory = Field(default_factory=lambda: FlyteDirectory("s3://my-s3-bucket/s3_flyte_dir"))
    enum_status: Status = Status.PENDING
    sd: StructuredDataset = Field(default_factory=lambda: StructuredDataset(
        uri="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/7f31035fdf92510e40ee9340f9e5bf34",
        file_format="parquet"))
    fsc: FlyteSchema = Field(default_factory=lambda: FlyteSchema(
        remote_path="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/ab3aef21302d0529daef8c43825c3fdf"))

class BM(BaseModel):
    a: int = -1
    b: float = 2.1
    c: str = "Hello, Flyte"
    d: bool = False
    e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2])
    f: List[FlyteFile] = Field(default_factory=lambda: [FlyteFile("s3://my-s3-bucket/example.txt")])
    g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]])
    h: List[Dict[int, bool]] = Field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
    i: Dict[int, bool] = Field(default_factory=lambda: {0: False, 1: True, -1: False})
    j: Dict[int, FlyteFile] = Field(default_factory=lambda: {0: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             1: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             -1: FlyteFile("s3://my-s3-bucket/example.txt")})
    k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]})
    l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}})
    m: dict = Field(default_factory=lambda: {"key": "value"})
    n: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))
    o: FlyteDirectory = Field(default_factory=lambda: FlyteDirectory("s3://my-s3-bucket/s3_flyte_dir"))
    inner_bm: InnerBM = Field(default_factory=lambda: InnerBM())
    enum_status: Status = Status.PENDING
    sd: StructuredDataset = Field(default_factory=lambda: StructuredDataset(
        uri="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/7f31035fdf92510e40ee9340f9e5bf34",
        file_format="parquet"))
    fsc: FlyteSchema = Field(default_factory=lambda: FlyteSchema(remote_path="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/ab3aef21302d0529daef8c43825c3fdf"))

@task(container_image=image)
def t_bm(bm: BM) -> BM:
    return bm

@task(container_image=image)
def t_inner(inner_bm: InnerBM):
    assert isinstance(inner_bm, InnerBM)

    expected_file_content = "Default content"

    # f: List[FlyteFile]
    for ff in inner_bm.f:
        assert isinstance(ff, FlyteFile)
        with open(ff, "r") as f:
            assert f.read() == expected_file_content
    # j: Dict[int, FlyteFile]
    for _, ff in inner_bm.j.items():
        assert isinstance(ff, FlyteFile)
        with open(ff, "r") as f:
            assert f.read() == expected_file_content
    # n: FlyteFile
    assert isinstance(inner_bm.n, FlyteFile)
    with open(inner_bm.n, "r") as f:
        assert f.read() == expected_file_content
    # o: FlyteDirectory
    assert isinstance(inner_bm.o, FlyteDirectory)
    assert not inner_bm.o.downloaded
    with open(os.path.join(inner_bm.o, "example.txt"), "r") as fh:
        assert fh.read() == expected_file_content
    assert inner_bm.o.downloaded
    print("Test InnerBM Successfully Passed")
    # enum: Status
    assert inner_bm.enum_status == Status.PENDING


@task(container_image=image)
def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]],
                          h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile],
                          k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict,
                          n: FlyteFile, o: FlyteDirectory,
                          enum_status: Status,
                          sd: StructuredDataset,
                          fsc: FlyteSchema,
                          ):
    # Strict type checks for simple types
    assert isinstance(a, int), f"a is not int, it's {type(a)}"
    assert a == -1
    assert isinstance(b, float), f"b is not float, it's {type(b)}"
    assert isinstance(c, str), f"c is not str, it's {type(c)}"
    assert isinstance(d, bool), f"d is not bool, it's {type(d)}"

    # Strict type checks for List[int]
    assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]"

    # Strict type checks for List[FlyteFile]
    assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]"

    # Strict type checks for List[List[int]]
    assert isinstance(g, list) and all(
        isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]"

    # Strict type checks for List[Dict[int, bool]]
    assert isinstance(h, list) and all(
        isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h
    ), "h is not List[Dict[int, bool]]"

    # Strict type checks for Dict[int, bool]
    assert isinstance(i, dict) and all(
        isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]"

    # Strict type checks for Dict[int, FlyteFile]
    assert isinstance(j, dict) and all(
        isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]"

    # Strict type checks for Dict[int, List[int]]
    assert isinstance(k, dict) and all(
        isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in
        k.items()), "k is not Dict[int, List[int]]"

    # Strict type checks for Dict[int, Dict[int, int]]
    assert isinstance(l, dict) and all(
        isinstance(k, int) and isinstance(v, dict) and all(
            isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items())
        for k, v in l.items()), "l is not Dict[int, Dict[int, int]]"

    # Strict type check for a generic dict
    assert isinstance(m, dict), "m is not dict"

    # Strict type check for FlyteFile
    assert isinstance(n, FlyteFile), "n is not FlyteFile"

    # Strict type check for FlyteDirectory
    assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory"

    # # Strict type check for Enum
    assert isinstance(enum_status, Status), "enum_status is not Status"

    assert isinstance(sd, StructuredDataset), "sd is not StructuredDataset"
    print("sd:", sd.open(pd.DataFrame).all())

    assert isinstance(fsc, FlyteSchema), "fsc is not FlyteSchema"
    print("fsc: ", fsc.open().all())

    print("All attributes passed strict type checks.")


@workflow
def wf(bm: BM):
    t_bm(bm=bm)
    t_inner(inner_bm=bm.inner_bm)
    t_test_all_attributes(a=bm.a, b=bm.b, c=bm.c,
                          d=bm.d, e=bm.e, f=bm.f,
                          g=bm.g, h=bm.h, i=bm.i,
                          j=bm.j, k=bm.k, l=bm.l,
                          m=bm.m, n=bm.n, o=bm.o,
                          enum_status=bm.enum_status,
                          sd=bm.sd,
                          fsc=bm.fsc,
                          )

    t_test_all_attributes(a=bm.inner_bm.a, b=bm.inner_bm.b, c=bm.inner_bm.c,
                          d=bm.inner_bm.d, e=bm.inner_bm.e, f=bm.inner_bm.f,
                          g=bm.inner_bm.g, h=bm.inner_bm.h, i=bm.inner_bm.i,
                          j=bm.inner_bm.j, k=bm.inner_bm.k, l=bm.inner_bm.l,
                          m=bm.inner_bm.m, n=bm.inner_bm.n, o=bm.inner_bm.o,
                          enum_status=bm.inner_bm.enum_status,
                          sd=bm.inner_bm.sd,
                          fsc=bm.inner_bm.fsc,
                          )

if __name__ == "__main__":
    from flytekit.clis.sdk_in_container import pyflyte
    from click.testing import CliRunner

    runner = CliRunner()
    path = os.path.realpath(__file__)
    input_val = BM().model_dump_json()
    print(input_val)
    result = runner.invoke(pyflyte.main,
                           ["run", path, "wf", "--bm", input_val])
    print("Local Execution: ", result.output)

    result = runner.invoke(pyflyte.main,
                           ["run", "--remote", path, "wf", "--bm", input_val])
    print("Remote Execution: ", result.output)

Setup process

Local and remote execution, using ImageSpec for the Docker image.

Screenshots

  • local execution (screenshots)
  • remote execution (screenshot)
  • remote execution from flyte console input (screenshots)

Check all the applicable boxes

  • I updated the documentation accordingly.
  • All new and existing tests passed.
  • All commits are signed-off.

Signed-off-by: Future-Outlier <[email protected]>
@Future-Outlier Future-Outlier changed the title Pydantic Transformer V2 [wip] Pydantic Transformer V2 Oct 8, 2024
Comment on lines +2059 to +2061
if lv.scalar.primitive.float_value is not None:
logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
return int(lv.scalar.primitive.float_value)
Future-Outlier (Member Author) commented Oct 8, 2024
This is for cases where the input comes from the flyte console and you use attribute access directly: you have to convert the float to an int.
Since JavaScript has only a single number type, it can't tell int and float apart, and when golang (propeller) does the attribute access, it doesn't know the expected Python type.

class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32

@workflow
def wf(cfg: TrainConfig) -> TrainConfig:
    return t_args(a=cfg.lr, batch_size=cfg.batch_size)

Contributor:

the javascript issue and the attribute access issue are orthogonal right?

this should only be a javascript problem. attribute access should work since msgpack preserves float/int even in attribute access correct?

Member Author:

Yes, attribute access itself works well; the problem is that JavaScript passes a float to golang, and golang passes that float on to Python.

Member Author:

> this should only be a javascript problem. attribute access should work since msgpack preserves float/int even in attribute access correct?

Yes, but when you access a simple type, you have to change the behavior of SimpleTransformer.

For the Pydantic Transformer, we pass strict=False so pydantic converts the value to the right type.

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
        if binary_idl_object.tag == MESSAGEPACK:
            dict_obj = msgpack.loads(binary_idl_object.value)
            python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
            return python_val
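A standalone illustration of what strict=False buys us here (not flytekit code):

```python
from pydantic import BaseModel, ValidationError

class C(BaseModel):
    a: int

# lax validation (strict=False) coerces the integral float coming from the console
assert C.model_validate({"a": 1.0}, strict=False).a == 1

# strict validation rejects the same payload
try:
    C.model_validate({"a": 1.0}, strict=True)
    raised = False
except ValidationError:
    raised = True
assert raised
```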

Contributor:

So we can delete this part after console is updated right?

Member Author:

If we can guarantee the console generates an integer rather than a float from the input, then we can delete it.

Collaborator:

how is this going to work though? Do we also do a version check of the backend?

Contributor:

After console does the right thing, won't this value be coming in through the binary value instead? Instead of lv.scalar.primitive.integer/float.

Future-Outlier (Member Author):

@lukas503
Hi, I saw you added an emoji to this PR!
Do you want to help me test this out?
Search for "How to test it by others" above for a guide!

lukas503 commented Oct 8, 2024

Hi @Future-Outlier,

Thanks for working on the Pydantic TypeTransformer! Which "How to test it by others?" guide are you referring to?

I've been testing the code locally and wondered about a behavior related to caching. Specifically, I’m curious if model_json_schema is considered in the hash used for caching.

Here’s an example:

from flytekit import task, workflow
from pydantic import BaseModel

class Config(BaseModel):
    x: int = 1
    # y: int = 4

@task(cache=True, cache_version="v1")
def task1(val: int) -> Config:
    return Config()

@task(cache=True, cache_version="v1")
def task2(cfg: Config) -> Config:
    print("CALLED!", cfg)
    return cfg

@workflow
def my_workflow():
    config = task1(val=5)
    task2(cfg=config)

if __name__ == "__main__":
    print(Config.model_json_schema())
    my_workflow()

When I run the workflow for the first time, nothing is cached. On the second run, the results are cached, as expected. However, if I uncomment y: int = 4, the tasks still remain cached. I would assume that this schema change would trigger a cache bust and re-execute the tasks. This causes failure if I update the attributes and the cache_version of task2.

Is this the expected behavior? Shouldn't schema changes like this invalidate the cache?
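For what it's worth, one way a cache key could take the schema into account would be to hash model_json_schema (a hypothetical sketch, not current flytekit behavior):

```python
import hashlib
import json
from pydantic import BaseModel

def schema_fingerprint(model_cls) -> str:
    # hypothetical helper: a stable hash of the model's JSON schema;
    # adding or removing a field changes the fingerprint
    canonical = json.dumps(model_cls.model_json_schema(), sort_keys=True)
    return hashlib.sha256(canonical.encode()).hexdigest()

class ConfigV1(BaseModel):
    x: int = 1

class ConfigV2(BaseModel):
    x: int = 1
    y: int = 4

# note: the class name is part of the schema too, so identical fields
# under different class names also yield different fingerprints
assert schema_fingerprint(ConfigV1) != schema_fingerprint(ConfigV2)
```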

Future-Outlier (Member Author):

> (quoting lukas503's comment above)

good question, will test this out and ask other maintainers if I don't know what happened, thank you <3

Future-Outlier (Member Author):

> (quoting lukas503's comment above)

@lukas503
sorry can you try again?
I've updated the above description.

codecov bot commented Oct 9, 2024

Codecov Report

Attention: Patch coverage is 56.22776% with 123 lines in your changes missing coverage. Please review.

Project coverage is 76.31%. Comparing base (3fc51af) to head (7735352).
Report is 8 commits behind head on master.

Files with missing lines Patch % Lines
flytekit/extras/pydantic/transformer.py 44.44% 25 Missing ⚠️
flytekit/types/schema/types.py 35.29% 19 Missing and 3 partials ⚠️
flytekit/types/structured/structured_dataset.py 32.25% 18 Missing and 3 partials ⚠️
flytekit/extras/pydantic/decorator.py 21.73% 18 Missing ⚠️
flytekit/types/directory/types.py 75.60% 8 Missing and 2 partials ⚠️
flytekit/types/file/file.py 72.97% 8 Missing and 2 partials ⚠️
flytekit/interaction/click_types.py 30.76% 9 Missing ⚠️
flytekit/core/type_engine.py 89.58% 4 Missing and 1 partial ⚠️
flytekit/extras/pydantic/__init__.py 57.14% 3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2792       +/-   ##
===========================================
+ Coverage   45.53%   76.31%   +30.77%     
===========================================
  Files         196      199        +3     
  Lines       20418    20743      +325     
  Branches     2647     2666       +19     
===========================================
+ Hits         9298    15829     +6531     
+ Misses      10658     4200     -6458     
- Partials      462      714      +252     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +586 to +590
if lv.scalar:
if lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)
if lv.scalar.generic:
return self.from_generic_idl(lv.scalar.generic, expected_python_type)
Member Author:

class DC(BaseModel):
    ff: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))

@task(container_image=image)
def t_args(dc: DC) -> DC:
    with open(dc.ff, "r") as f:
        print(f.read())
    return dc

@task(container_image=image)
def t_ff(ff: FlyteFile) -> FlyteFile:
    with open(ff, "r") as f:
        print(f.read())
    return ff

@workflow
def wf(dc: DC) -> DC:
    t_ff(dc.ff)
    return t_args(dc=dc)

This handles the case above, where the input comes from the flyteconsole.

lukas503 commented Oct 9, 2024

> sorry can you try again?
> I've updated the above description.

Thanks for updating the PR. I now understand the underlying issue better. It appears the caching mechanism ignores the output types/schema. What's unclear to me is why the output types/schema aren't factored into the hash used for caching. In my opinion, any interface change, including a change to the outputs, should invalidate the cache. I don't see how old cached outputs can remain valid after an interface change.

That said, this concern isn’t directly related to the current PR, so feel free to proceed as is.

Update: It works as expected if remote flyte is used. The faulty behavior I described is happening only locally.

Future-Outlier (Member Author):

todo: add this to the flytesnacks examples

# Flytekit Pydantic Plugin

Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations.

The plugin adds type support for pydantic models.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-pydantic-v2
```

Type Example

from enum import Enum
import os
from typing import Dict, List, Optional

import pandas as pd
from pydantic import BaseModel, Field

from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset
from flytekit.types.file import FlyteFile
from flytekit.types.directory import FlyteDirectory
from flytekit import task, workflow, ImageSpec


image = ImageSpec(packages=["flytekitplugins-pydantic-v2",
                            "pandas",
                            "pyarrow"],
                            registry="localhost:30000",
                            )

class Status(Enum):
    PENDING = "pending"
    APPROVED = "approved"
    REJECTED = "rejected"

class InnerBM(BaseModel):
    a: int = -1
    b: float = 2.1
    c: str = "Hello, Flyte"
    d: bool = False
    e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2])
    f: List[FlyteFile] = Field(default_factory=lambda: [FlyteFile("s3://my-s3-bucket/example.txt")])
    g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]])
    h: List[Dict[int, bool]] = Field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
    i: Dict[int, bool] = Field(default_factory=lambda: {0: False, 1: True, -1: False})
    j: Dict[int, FlyteFile] = Field(default_factory=lambda: {0: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             1: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             -1: FlyteFile("s3://my-s3-bucket/example.txt")})
    k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]})
    l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}})
    m: dict = Field(default_factory=lambda: {"key": "value"})
    n: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))
    o: FlyteDirectory = Field(default_factory=lambda: FlyteDirectory("s3://my-s3-bucket/s3_flyte_dir"))
    enum_status: Status = Status.PENDING
    sd: StructuredDataset = Field(default_factory=lambda: StructuredDataset(
        uri="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/7f31035fdf92510e40ee9340f9e5bf34",
        file_format="parquet"))
    fsc: FlyteSchema = Field(default_factory=lambda: FlyteSchema(
        remote_path="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/ab3aef21302d0529daef8c43825c3fdf"))

class BM(BaseModel):
    a: int = -1
    b: float = 2.1
    c: str = "Hello, Flyte"
    d: bool = False
    e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2])
    f: List[FlyteFile] = Field(default_factory=lambda: [FlyteFile("s3://my-s3-bucket/example.txt")])
    g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]])
    h: List[Dict[int, bool]] = Field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
    i: Dict[int, bool] = Field(default_factory=lambda: {0: False, 1: True, -1: False})
    j: Dict[int, FlyteFile] = Field(default_factory=lambda: {0: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             1: FlyteFile("s3://my-s3-bucket/example.txt"),
                                                             -1: FlyteFile("s3://my-s3-bucket/example.txt")})
    k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]})
    l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}})
    m: dict = Field(default_factory=lambda: {"key": "value"})
    n: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))
    o: FlyteDirectory = Field(default_factory=lambda: FlyteDirectory("s3://my-s3-bucket/s3_flyte_dir"))
    inner_dc: InnerBM = Field(default_factory=lambda: InnerBM())
    enum_status: Status = Status.PENDING
    sd: StructuredDataset = Field(default_factory=lambda: StructuredDataset(
        uri="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/7f31035fdf92510e40ee9340f9e5bf34",
        file_format="parquet"))
    fsc: FlyteSchema = Field(default_factory=lambda: FlyteSchema(remote_path="s3://my-s3-bucket/data/uk/ahlg7qw7q5m4np7vwdqm-n0-0/ab3aef21302d0529daef8c43825c3fdf"))

@task(container_image=image)
def t_dc(dc: BM) -> BM:
    return dc

@task(container_image=image)
def t_inner(inner_dc: InnerBM):
    assert isinstance(inner_dc, InnerBM)

    expected_file_content = "Default content"

    # f: List[FlyteFile]
    for ff in inner_dc.f:
        assert isinstance(ff, FlyteFile)
        with open(ff, "r") as f:
            assert f.read() == expected_file_content
    # j: Dict[int, FlyteFile]
    for _, ff in inner_dc.j.items():
        assert isinstance(ff, FlyteFile)
        with open(ff, "r") as f:
            assert f.read() == expected_file_content
    # n: FlyteFile
    assert isinstance(inner_dc.n, FlyteFile)
    with open(inner_dc.n, "r") as f:
        assert f.read() == expected_file_content
    # o: FlyteDirectory
    assert isinstance(inner_dc.o, FlyteDirectory)
    assert not inner_dc.o.downloaded
    with open(os.path.join(inner_dc.o, "example.txt"), "r") as fh:
        assert fh.read() == expected_file_content
    assert inner_dc.o.downloaded
    print("Test InnerBM Successfully Passed")
    # enum: Status
    assert inner_dc.enum_status == Status.PENDING


@task(container_image=image)
def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]],
                          h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile],
                          k: Dict[int, List[int]], l: Dict[int, Dict[int, int]], m: dict,
                          n: FlyteFile, o: FlyteDirectory,
                          enum_status: Status,
                          sd: StructuredDataset,
                          fsc: FlyteSchema,
                          ):
    # Strict type checks for simple types
    assert isinstance(a, int), f"a is not int, it's {type(a)}"
    assert a == -1
    assert isinstance(b, float), f"b is not float, it's {type(b)}"
    assert isinstance(c, str), f"c is not str, it's {type(c)}"
    assert isinstance(d, bool), f"d is not bool, it's {type(d)}"

    # Strict type checks for List[int]
    assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]"

    # Strict type checks for List[FlyteFile]
    assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]"

    # Strict type checks for List[List[int]]
    assert isinstance(g, list) and all(
        isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]"

    # Strict type checks for List[Dict[int, bool]]
    assert isinstance(h, list) and all(
        isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h
    ), "h is not List[Dict[int, bool]]"

    # Strict type checks for Dict[int, bool]
    assert isinstance(i, dict) and all(
        isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]"

    # Strict type checks for Dict[int, FlyteFile]
    assert isinstance(j, dict) and all(
        isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]"

    # Strict type checks for Dict[int, List[int]]
    assert isinstance(k, dict) and all(
        isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in
        k.items()), "k is not Dict[int, List[int]]"

    # Strict type checks for Dict[int, Dict[int, int]]
    assert isinstance(l, dict) and all(
        isinstance(k, int) and isinstance(v, dict) and all(
            isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items())
        for k, v in l.items()), "l is not Dict[int, Dict[int, int]]"

    # Strict type check for a generic dict
    assert isinstance(m, dict), "m is not dict"

    # Strict type check for FlyteFile
    assert isinstance(n, FlyteFile), "n is not FlyteFile"

    # Strict type check for FlyteDirectory
    assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory"

    # # Strict type check for Enum
    assert isinstance(enum_status, Status), "enum_status is not Status"

    assert isinstance(sd, StructuredDataset), "sd is not StructuredDataset"
    print("sd:", sd.open(pd.DataFrame).all())

    assert isinstance(fsc, FlyteSchema), "fsc is not FlyteSchema"
    print("fsc: ", fsc.open().all())

    print("All attributes passed strict type checks.")


@workflow
def wf(dc: BM):
    t_dc(dc=dc)
    t_inner(inner_dc=dc.inner_dc)
    t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c,
                          d=dc.d, e=dc.e, f=dc.f,
                          g=dc.g, h=dc.h, i=dc.i,
                          j=dc.j, k=dc.k, l=dc.l,
                          m=dc.m, n=dc.n, o=dc.o,
                          enum_status=dc.enum_status,
                          sd=dc.sd,
                          fsc=dc.fsc,
                          )

    t_test_all_attributes(a=dc.inner_dc.a, b=dc.inner_dc.b, c=dc.inner_dc.c,
                          d=dc.inner_dc.d, e=dc.inner_dc.e, f=dc.inner_dc.f,
                          g=dc.inner_dc.g, h=dc.inner_dc.h, i=dc.inner_dc.i,
                          j=dc.inner_dc.j, k=dc.inner_dc.k, l=dc.inner_dc.l,
                          m=dc.inner_dc.m, n=dc.inner_dc.n, o=dc.inner_dc.o,
                          enum_status=dc.inner_dc.enum_status,
                          sd=dc.inner_dc.sd,
                          fsc=dc.inner_dc.fsc,
                          )

Signed-off-by: Future-Outlier <[email protected]>
Comment on lines 359 to 360
# TODO: remove pydantic v1 plugin, since v2 is in core already
# flytekit-pydantic
Member Author

Since we are going to remove the pydantic v1 plugin eventually, and it fails when the pydantic version is > 2 (CI uses pydantic > 2), let's comment it out.

Collaborator

Can we remove it?

Comment on lines 246 to 247
from flytekit.deck import Deck
from flytekit.extras import pydantic
Member Author

If we move to lazy-importing the transformer, this will fail to apply the custom serialize and deserialize behavior; still investigating the root cause.

Future-Outlier and others added 6 commits October 23, 2024 14:35
Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: pingsutw  <[email protected]>
Member Author

@Future-Outlier left a comment

Just did a final test and it works.

Collaborator

@eapolinario left a comment

Just a few minor things, otherwise, it's looking pretty good.

Comment on lines 10 to 11
logger.info(f"Meet error when importing pydantic: `{e}`")
logger.info("Flytekit only support pydantic version > 2.")
Collaborator
nit: those should be a warning.

    from pydantic import model_serializer, model_validator

except ImportError:
    logger.info(
Collaborator
ditto.

Comment on lines +16 to +57
FuncType = TypeVar("FuncType", bound=Callable[..., Any])

from typing_extensions import Literal as typing_literal

def model_serializer(
    __f: Union[Callable[..., Any], None] = None,
    *,
    mode: typing_literal["plain", "wrap"] = "plain",
    when_used: typing_literal["always", "unless-none", "json", "json-unless-none"] = "always",
    return_type: Any = None,
) -> Callable[[Any], Any]:
    """Placeholder decorator for Pydantic model_serializer."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args, **kwargs):
            raise Exception(
                "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use this feature."
            )

        return wrapper

    # If no function (__f) is provided, return the decorator
    if __f is None:
        return decorator
    # If __f is provided, directly decorate the function
    return decorator(__f)

def model_validator(
    *,
    mode: typing_literal["wrap", "before", "after"],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Placeholder decorator for Pydantic model_validator."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args, **kwargs):
            raise Exception(
                "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use this feature."
            )

        return wrapper

    return decorator
Collaborator
Aren't we supporting only pydantic v2? Why do we have these fallback definitions?

Contributor
it's to support the case where pydantic is not installed at all. It looks nicer in the real File/Directory class, but we also want it not to fail, of course.


@@ -215,16 +215,32 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
        )

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
        """
        TODO: Add more comments to explain the lifecycle of attribute access.
Collaborator
fill in TODO

Comment on lines +2059 to +2061
if lv.scalar.primitive.float_value is not None:
    logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
    return int(lv.scalar.primitive.float_value)
Collaborator
how is this going to work though? Do we also do a version check of the backend?

@@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

-plugin_requires = ["flytekit>=1.7.0b0", "pydantic"]
+plugin_requires = ["flytekit>=1.7.0b0", "pydantic<2"]
Collaborator
Can we also leave a warning in the README.md explaining that we're deprecating this plugin?

Member Author
@eapolinario

how is this going to work though? Do we also do a version check of the backend?
No, this is just for supporting the case I mentioned above. We didn't support this before, and I think we should do it.

Signed-off-by: Future-Outlier <[email protected]>
@Future-Outlier
Member Author

Let's merge it.
cc @eapolinario @wild-endeavor @pingsutw

Contributor

@wild-endeavor left a comment

🙇 thank you @Future-Outlier 🙏 this is going to be fantastic.

@@ -1124,6 +1194,8 @@ def lazy_import_transformers(cls):
        from flytekit.extras import pytorch  # noqa: F401
    if is_imported("sklearn"):
        from flytekit.extras import sklearn  # noqa: F401
    if is_imported("pydantic"):
        from flytekit.extras import pydantic  # noqa: F401
Contributor
can we change the name of this folder? pydantic can get confusing because the real library is also called pydantic right?

@@ -2194,6 +2304,34 @@ def _check_and_covert_float(lv: Literal) -> float:
    raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float")


def _handle_flyte_console_float_input_to_int(lv: Literal) -> int:
    """
    Flyte Console is written by JavaScript and JavaScript has only one number type which is float.
Contributor
technically javascript's number type is Number but yeah, sometimes it keeps track of trailing 0s and sometimes it doesn't.
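The precision concern mentioned in the logged warning can be demonstrated directly. This is a hypothetical sketch (not code from this PR) showing why converting a console-supplied float back to int is lossless for small integers but lossy beyond 2**53, since JavaScript Numbers are IEEE-754 doubles:

```python
# Small integers round-trip through a double exactly:
assert int(1.0) == 1

# Integers beyond 2**53 cannot be represented exactly as a double,
# which is the precision loss the logger warns about:
big = 2**53 + 1                    # 9007199254740993
assert int(float(big)) == 2**53    # rounds to 9007199254740992
assert int(float(big)) != big
```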

Comment on lines +2059 to +2061
Contributor
After console does the right thing, won't this value be coming in through the binary value instead? Instead of lv.scalar.primitive.integer/float.


        super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)

    def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
        schema = t.model_json_schema()
Contributor
In a future PR, can we add some unit tests to ensure that we're correctly extracting default values into the schema?
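Such a unit test could check that field defaults declared on the model surface in the JSON schema. A minimal sketch (the model here is hypothetical, not the transformer's actual test fixture):

```python
from pydantic import BaseModel


class BM(BaseModel):
    a: int = -1
    c: str = "Hello, Flyte"


# pydantic v2 emits plain field defaults into the JSON schema,
# which is what get_literal_type builds the LiteralType from
schema = BM.model_json_schema()
assert schema["properties"]["a"]["default"] == -1
assert schema["properties"]["c"]["default"] == "Hello, Flyte"
```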


bm = BM()
wf(bm=bm)

Contributor
can you add two new lines between tests? I know it doesn't matter, but PyCharm complains.


def test_flytetypes_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory):
    class InnerBM(BaseModel):
        flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
Contributor
little bit confused about pydantic here. Are you supposed to use dataclasses.field here instead of pydantic.Field?
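For comparison, pydantic v2's own `Field` gives the same `default_factory` behavior on a `BaseModel`. A minimal sketch (model and field names are illustrative, not from this PR):

```python
from pydantic import BaseModel, Field


class Model(BaseModel):
    # pydantic.Field mirrors dataclasses.field's default_factory semantics
    items: list = Field(default_factory=lambda: [1, 2, 3])


m = Model()
assert m.items == [1, 2, 3]

# each instance gets a fresh list from the factory
m2 = Model()
assert m.items is not m2.items
```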

def test_protocol():
    assert get_protocol("s3://my-s3-bucket/file") == "s3"
    assert get_protocol("/file") == "file"


def generate_pandas() -> pd.DataFrame:
    return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})

Contributor
keep spaces plz

    flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file))
    flytedir: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory))

class BM(BaseModel):
Contributor
maybe in a different file, but can we add pydantic models that contain dataclasses, and also dataclasses that contain pydantic models? I know some people have been asking for that; it'd be good to have some tests for it.

Thank you!
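A minimal sketch of the requested mixing, here a stdlib dataclass nested inside a pydantic v2 `BaseModel` (class names are hypothetical, and this is not a test from this PR):

```python
from dataclasses import dataclass

from pydantic import BaseModel, Field


@dataclass
class InnerDC:
    x: int = 0


class OuterBM(BaseModel):
    inner: InnerDC = Field(default_factory=InnerDC)


# pydantic v2 validates stdlib dataclasses used as field types,
# coercing nested dicts into dataclass instances
o = OuterBM.model_validate({"inner": {"x": 5}})
assert o.inner.x == 5
assert OuterBM().inner.x == 0
```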
