Skip to content

Commit

Permalink
merge master
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Oct 24, 2024
2 parents 22552f3 + 1a1ee53 commit a67f4b2
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 181 deletions.
15 changes: 1 addition & 14 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast

import typing_extensions
from flyteidl.core import tasks_pb2

from flytekit.configuration import SerializationSettings
Expand All @@ -18,7 +17,7 @@
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.type_engine import TypeEngine, TypeTransformer, is_annotated
from flytekit.core.type_engine import TypeEngine
from flytekit.core.utils import timeit
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
Expand All @@ -27,8 +26,6 @@
from flytekit.models.interface import Variable
from flytekit.models.task import Container, K8sPod, Sql, Task
from flytekit.tools.module_loader import load_object_from_module
from flytekit.types.pickle import pickle
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.utils.asyn import loop_manager

if TYPE_CHECKING:
Expand Down Expand Up @@ -77,16 +74,6 @@ def __init__(
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks."
)

for k, v in actual_task.python_interface.inputs.items():
if bound_inputs and k in bound_inputs:
continue
transformer: TypeTransformer = TypeEngine.get_transformer(v)
if isinstance(transformer, FlytePickleTransformer):
if is_annotated(v):
for annotation in typing_extensions.get_args(v)[1:]:
if isinstance(annotation, pickle.BatchSize):
raise ValueError("Choosing a BatchSize for map tasks inputs is not supported.")

n_outputs = len(actual_task.python_interface.outputs)
if n_outputs > 1:
raise ValueError("Only tasks with a single output are supported in map tasks.")
Expand Down
76 changes: 22 additions & 54 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,18 @@ def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.
def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any:
"""
If any field inside the dataclass is a Flyte type, we should use the Flyte type transformer for that field.
Since Flyte types are already serializable, this function exists so that plain strings can be used in the dataclass instead of constructing Flyte files and directories directly.
The following example shows the lifecycle:
@dataclass
class DC:
ff: FlyteFile
@task
def t1() -> DC:
return DC(ff="s3://path")
Lifecycle: DC(ff="s3://path") -> to_literal() -> DC(ff=FlyteFile(path="s3://path")) -> msgpack -> to_python_value() -> DC(ff=FlyteFile(path="s3://path"))
"""
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
Expand Down Expand Up @@ -1521,52 +1533,17 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
except Exception as e:
raise ValueError(f"Type of Generic List type is not supported, {e}")

@staticmethod
def is_batchable(t: Type):
"""
This function evaluates whether the provided type is batchable or not.
It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle.
"""
from flytekit.types.pickle import FlytePickle

if is_annotated(t):
return ListTransformer.is_batchable(get_args(t)[0])
if get_origin(t) is list:
subtype = get_args(t)[0]
if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle):
return True
return False

async def async_to_literal(
self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType
) -> Literal:
if type(python_val) != list:
raise TypeTransformerFailedError("Expected a list")

if ListTransformer.is_batchable(python_type):
from flytekit.types.pickle.pickle import BatchSize, FlytePickle

batch_size = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if is_annotated(python_type):
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize):
batch_size = annotation.val
break
if batch_size > 0:
lit_list = [
TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type)
for i in range(0, len(python_val), batch_size)
] # type: ignore
else:
lit_list = []
else:
t = self.get_sub_type(python_type)
lit_list = [
asyncio.create_task(TypeEngine.async_to_literal(ctx, x, t, expected.collection_type))
for x in python_val
]
lit_list = await _run_coros_in_chunks(lit_list)
t = self.get_sub_type(python_type)
lit_list = [
asyncio.create_task(TypeEngine.async_to_literal(ctx, x, t, expected.collection_type)) for x in python_val
]
lit_list = await _run_coros_in_chunks(lit_list)

return Literal(collection=LiteralCollection(literals=lit_list))

Expand All @@ -1585,20 +1562,11 @@ async def async_to_python_value( # type: ignore
f"is not a collection (Flyte's representation of Python lists)."
)
)
if self.is_batchable(expected_python_type):
from flytekit.types.pickle import FlytePickle

batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits]
if len(batch_list) > 0 and type(batch_list[0]) is list:
# Keep backward compatibility: the upstream task may use an old version of Flytekit that won't
# merge the elements in the list. Therefore, we should first check whether batch_list[0] is a list.
return [item for batch in batch_list for item in batch]
return batch_list
else:
st = self.get_sub_type(expected_python_type)
result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits]
result = await _run_coros_in_chunks(result)
return result # type: ignore # should be a list, thinks its a tuple

st = self.get_sub_type(expected_python_type)
result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits]
result = await _run_coros_in_chunks(result)
return result # type: ignore # should be a list, thinks its a tuple

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
if literal_type.collection_type:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/pickle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
FlytePickle
"""

from .pickle import BatchSize, FlytePickle
from .pickle import FlytePickle
13 changes: 0 additions & 13 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,6 @@
T = typing.TypeVar("T")


class BatchSize:
"""
Flyte-specific annotation object that holds the batch size, i.e. the number of list items saved in one pickle file
"""

def __init__(self, val: int):
self._val = val

@property
def val(self) -> int:
return self._val


class FlytePickle(typing.Generic[T]):
"""
This type is only used by flytekit internally. User should not use this type.
Expand Down
8 changes: 0 additions & 8 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
LiteralOffloadedMetadata,
)
from flytekit.tools.translator import get_serializable, Options
from flytekit.types.pickle import BatchSize


@pytest.fixture
Expand Down Expand Up @@ -102,13 +101,6 @@ def say_hello(name: str) -> str:


def test_map_task_with_pickle():
@task
def say_hello(name: Annotated[typing.Any, BatchSize(10)]) -> str:
return f"hello {name}!"

with pytest.raises(ValueError, match="Choosing a BatchSize for map tasks inputs is not supported."):
map_task(say_hello)(name=["abc", "def"])

@task
def say_hello(name: typing.Any) -> str:
return f"hello {name}!"
Expand Down
3 changes: 1 addition & 2 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException
from flytekit.models import literals as literal_models
from flytekit.models.types import LiteralType, SimpleType, TypeStructure
from flytekit.types.pickle.pickle import BatchSize


def test_create_and_link_node():
Expand Down Expand Up @@ -141,7 +140,7 @@ class MyDataclass(DataClassJsonMixin):
)
def test_translate_inputs_to_literals(input):
@task
def t1(a: typing.Union[float, MyDataclass, Annotated[typing.List[typing.Any], BatchSize(2)]]):
def t1(a: typing.Union[float, MyDataclass, typing.List[typing.Any]]):
print(a)

ctx = context_manager.FlyteContext.current_context()
Expand Down
84 changes: 1 addition & 83 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from flytekit.types.iterator.iterator import IteratorTransformer
from flytekit.types.iterator.json_iterator import JSONIterator, JSONIteratorTransformer, JSON
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine, PARQUET

Expand Down Expand Up @@ -443,16 +443,6 @@ def test_dir_no_downloader_default():
assert pv.download() == local_dir


def test_dir_with_batch_size():
flyte_dir = Annotated[FlyteDirectory, BatchSize(100)]
val = flyte_dir("s3://bucket/key")
transformer = TypeEngine.get_transformer(flyte_dir)
ctx = FlyteContext.current_context()
lt = transformer.get_literal_type(flyte_dir)
lv = transformer.to_literal(ctx, val, flyte_dir, lt)
assert val.path == transformer.to_python_value(ctx, lv, flyte_dir).remote_source


def test_dict_transformer():
d = DictTransformer()

Expand Down Expand Up @@ -2725,78 +2715,6 @@ def test_file_ext_with_flyte_file_wrong_type():
assert str(e.value) == "Underlying type of File Extension must be of type <str>"


def test_is_batchable():
assert ListTransformer.is_batchable(typing.List[int]) is False
assert ListTransformer.is_batchable(typing.List[str]) is False
assert ListTransformer.is_batchable(typing.List[typing.Dict]) is False
assert ListTransformer.is_batchable(typing.List[typing.Dict[str, FlytePickle]]) is False
assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False

assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True
assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], BatchSize(3)]) is True
assert (
ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)])
is True
)


@pytest.mark.parametrize(
"python_val, python_type, expected_list_length",
[
# Case 1: List of FlytePickle objects with default batch size.
# (By default, the batch_size is set to the length of the whole list.)
# After converting to literal, the result will be [batched_FlytePickle(5 items)].
# Therefore, the expected list length is [1].
([{"foo"}] * 5, typing.List[FlytePickle], [1]),
# Case 2: List of FlytePickle objects with batch size 2.
# After converting to literal, the result will be
# [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)].
# Therefore, the expected list length is [3].
(
["foo"] * 5,
Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)],
[3],
),
# Case 3: Nested list of FlytePickle objects with batch size 2.
# After converting to literal, the result will be
# [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]]
# Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched).
(
[["foo", "foo", "foo"]] * 2,
typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]],
[2, 1],
),
# Case 4: Empty list
([[], typing.List[FlytePickle], []]),
],
)
def test_batch_pickle_list(python_val, python_type, expected_list_length):
ctx = FlyteContext.current_context()
expected = TypeEngine.to_literal_type(python_type)
lv = TypeEngine.to_literal(ctx, python_val, python_type, expected)

tmp_lv = lv
for length in expected_list_length:
# Check that after converting to literal, the length of the literal list is equal to:
# - the length of the original list divided by the batch size if not nested
# - the length of the original list if it contains a nested list
assert len(tmp_lv.collection.literals) == length
tmp_lv = tmp_lv.collection.literals[0]

pv = TypeEngine.to_python_value(ctx, lv, python_type)
# Check that after converting literal to Python value, the result is equal to the original python values.
assert pv == python_val
if get_origin(python_type) is Annotated:
pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0])
# Remove the annotation and check that after converting to Python value, the result is equal
# to the original input values. This is used to simulate the following case:
# @workflow
# def wf():
# data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)]
# task1(data=data) # task1(data: typing.List[FlytePickle])
assert pv == python_val


@pytest.mark.parametrize(
"t,expected",
[
Expand Down
7 changes: 1 addition & 6 deletions tests/flytekit/unit/types/pickle/test_flyte_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flytekit.models.literals import BlobMetadata
from flytekit.models.types import LiteralType
from flytekit.tools.translator import get_serializable
from flytekit.types.pickle.pickle import BatchSize, FlytePickle, FlytePickleTransformer
from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -59,11 +59,6 @@ def test_get_literal_type():
assert lt == expected_lt


def test_batch_size():
bs = BatchSize(5)
assert bs.val == 5


def test_nested():
class Foo(object):
def __init__(self, number: int):
Expand Down

0 comments on commit a67f4b2

Please sign in to comment.