From f08466a1fd5b0bc8ce6228c4d1230fff906dace2 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 23 Oct 2024 18:58:24 -0700 Subject: [PATCH] delete Signed-off-by: Yee Hing Tong --- flytekit/core/array_node_map_task.py | 15 +--- flytekit/core/type_engine.py | 59 ++----------- flytekit/types/pickle/__init__.py | 2 +- flytekit/types/pickle/pickle.py | 13 --- .../unit/core/test_array_node_map_task.py | 8 -- tests/flytekit/unit/core/test_promise.py | 3 +- tests/flytekit/unit/core/test_type_engine.py | 84 +------------------ .../unit/types/pickle/test_flyte_pickle.py | 7 +- 8 files changed, 13 insertions(+), 178 deletions(-) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index dbab3545a9..5fd184e7fd 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -7,7 +7,6 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast -import typing_extensions from flyteidl.core import tasks_pb2 from flytekit.configuration import SerializationSettings @@ -18,7 +17,7 @@ from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.launch_plan import LaunchPlan from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask -from flytekit.core.type_engine import TypeEngine, TypeTransformer, is_annotated +from flytekit.core.type_engine import TypeEngine from flytekit.core.utils import timeit from flytekit.loggers import logger from flytekit.models import literals as _literal_models @@ -27,8 +26,6 @@ from flytekit.models.interface import Variable from flytekit.models.task import Container, K8sPod, Sql, Task from flytekit.tools.module_loader import load_object_from_module -from flytekit.types.pickle import pickle -from flytekit.types.pickle.pickle import FlytePickleTransformer from flytekit.utils.asyn import loop_manager if TYPE_CHECKING: @@ -77,16 +74,6 @@ def __init__( "Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks." ) - for k, v in actual_task.python_interface.inputs.items(): - if bound_inputs and k in bound_inputs: - continue - transformer: TypeTransformer = TypeEngine.get_transformer(v) - if isinstance(transformer, FlytePickleTransformer): - if is_annotated(v): - for annotation in typing_extensions.get_args(v)[1:]: - if isinstance(annotation, pickle.BatchSize): - raise ValueError("Choosing a BatchSize for map tasks inputs is not supported.") - n_outputs = len(actual_task.python_interface.outputs) if n_outputs > 1: raise ValueError("Only tasks with a single output are supported in map tasks.") diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 900afa8562..cb16127de7 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1520,49 +1520,15 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") - @staticmethod - def is_batchable(t: Type): - """ - This function evaluates whether the provided type is batchable or not. - It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle. - """ - from flytekit.types.pickle import FlytePickle - - if is_annotated(t): - return ListTransformer.is_batchable(get_args(t)[0]) - if get_origin(t) is list: - subtype = get_args(t)[0] - if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle): - return True - return False - async def async_to_literal( self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType ) -> Literal: if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") - if ListTransformer.is_batchable(python_type): - from flytekit.types.pickle.pickle import BatchSize, FlytePickle - - batch_size = len(python_val) # default batch size - # parse annotated to get the number of items saved in a pickle file. - if is_annotated(python_type): - for annotation in get_args(python_type)[1:]: - if isinstance(annotation, BatchSize): - batch_size = annotation.val - break - if batch_size > 0: - lit_list = [ - TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) - for i in range(0, len(python_val), batch_size) - ] # type: ignore - else: - lit_list = [] - else: - t = self.get_sub_type(python_type) - lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val] - lit_list = await asyncio.gather(*lit_list) + t = self.get_sub_type(python_type) + lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val] + lit_list = await asyncio.gather(*lit_list) return Literal(collection=LiteralCollection(literals=lit_list)) @@ -1581,20 +1547,11 @@ async def async_to_python_value( # type: ignore f"is not a collection (Flyte's representation of Python lists)." ) ) - if self.is_batchable(expected_python_type): - from flytekit.types.pickle import FlytePickle - - batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] - if len(batch_list) > 0 and type(batch_list[0]) is list: - # Make it have backward compatibility. The upstream task may use old version of Flytekit that won't - # merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first. - return [item for batch in batch_list for item in batch] - return batch_list - else: - st = self.get_sub_type(expected_python_type) - result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits] - result = await asyncio.gather(*result) - return result # type: ignore # should be a list, thinks its a tuple + + st = self.get_sub_type(expected_python_type) + result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits] + result = await asyncio.gather(*result) + return result # type: ignore # should be a list, thinks its a tuple def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: diff --git a/flytekit/types/pickle/__init__.py b/flytekit/types/pickle/__init__.py index 44c16b25cd..59833bdc84 100644 --- a/flytekit/types/pickle/__init__.py +++ b/flytekit/types/pickle/__init__.py @@ -10,4 +10,4 @@ FlytePickle """ -from .pickle import BatchSize, FlytePickle +from .pickle import FlytePickle diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 7b4c99cae6..b49b03205a 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -13,19 +13,6 @@ T = typing.TypeVar("T") -class BatchSize: - """ - Flyte-specific object used to wrap the hash function for a specific type - """ - - def __init__(self, val: int): - self._val = val - - @property - def val(self) -> int: - return self._val - - class FlytePickle(typing.Generic[T]): """ This type is only used by flytekit internally. User should not use this type. diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 74ea5ac3b4..7621de3076 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -25,7 +25,6 @@ LiteralOffloadedMetadata, ) from flytekit.tools.translator import get_serializable, Options -from flytekit.types.pickle import BatchSize @pytest.fixture @@ -102,13 +101,6 @@ def say_hello(name: str) -> str: def test_map_task_with_pickle(): - @task - def say_hello(name: Annotated[typing.Any, BatchSize(10)]) -> str: - return f"hello {name}!" - - with pytest.raises(ValueError, match="Choosing a BatchSize for map tasks inputs is not supported."): - map_task(say_hello)(name=["abc", "def"]) - @task def say_hello(name: typing.Any) -> str: return f"hello {name}!" diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index a8a6e3444d..615cc16991 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -23,7 +23,6 @@ from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException from flytekit.models import literals as literal_models from flytekit.models.types import LiteralType, SimpleType, TypeStructure -from flytekit.types.pickle.pickle import BatchSize def test_create_and_link_node(): @@ -141,7 +140,7 @@ class MyDataclass(DataClassJsonMixin): ) def test_translate_inputs_to_literals(input): @task - def t1(a: typing.Union[float, MyDataclass, Annotated[typing.List[typing.Any], BatchSize(2)]]): + def t1(a: typing.Union[float, MyDataclass, typing.List[typing.Any]]): print(a) ctx = context_manager.FlyteContext.current_context() diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 8721a8d4db..7966b00f2c 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -77,7 +77,7 @@ from flytekit.types.iterator.iterator import IteratorTransformer from flytekit.types.iterator.json_iterator import JSONIterator, JSONIteratorTransformer, JSON from flytekit.types.pickle import FlytePickle -from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer +from flytekit.types.pickle.pickle import FlytePickleTransformer from flytekit.types.schema import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine, PARQUET @@ -443,16 +443,6 @@ def test_dir_no_downloader_default(): assert pv.download() == local_dir -def test_dir_with_batch_size(): - flyte_dir = Annotated[FlyteDirectory, BatchSize(100)] - val = flyte_dir("s3://bucket/key") - transformer = TypeEngine.get_transformer(flyte_dir) - ctx = FlyteContext.current_context() - lt = transformer.get_literal_type(flyte_dir) - lv = transformer.to_literal(ctx, val, flyte_dir, lt) - assert val.path == transformer.to_python_value(ctx, lv, flyte_dir).remote_source - - def test_dict_transformer(): d = DictTransformer() @@ -2725,78 +2715,6 @@ def test_file_ext_with_flyte_file_wrong_type(): assert str(e.value) == "Underlying type of File Extension must be of type " -def test_is_batchable(): - assert ListTransformer.is_batchable(typing.List[int]) is False - assert ListTransformer.is_batchable(typing.List[str]) is False - assert ListTransformer.is_batchable(typing.List[typing.Dict]) is False - assert ListTransformer.is_batchable(typing.List[typing.Dict[str, FlytePickle]]) is False - assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False - - assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True - assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], BatchSize(3)]) is True - assert ( - ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)]) - is True - ) - - -@pytest.mark.parametrize( - "python_val, python_type, expected_list_length", - [ - # Case 1: List of FlytePickle objects with default batch size. - # (By default, the batch_size is set to the length of the whole list.) - # After converting to literal, the result will be [batched_FlytePickle(5 items)]. - # Therefore, the expected list length is [1]. - ([{"foo"}] * 5, typing.List[FlytePickle], [1]), - # Case 2: List of FlytePickle objects with batch size 2. - # After converting to literal, the result will be - # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. - # Therefore, the expected list length is [3]. - ( - ["foo"] * 5, - Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], - [3], - ), - # Case 3: Nested list of FlytePickle objects with batch size 2. - # After converting to literal, the result will be - # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] - # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). - ( - [["foo", "foo", "foo"]] * 2, - typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], - [2, 1], - ), - # Case 4: Empty list - ([[], typing.List[FlytePickle], []]), - ], -) -def test_batch_pickle_list(python_val, python_type, expected_list_length): - ctx = FlyteContext.current_context() - expected = TypeEngine.to_literal_type(python_type) - lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) - - tmp_lv = lv - for length in expected_list_length: - # Check that after converting to literal, the length of the literal list is equal to: - # - the length of the original list divided by the batch size if not nested - # - the length of the original list if it contains a nested list - assert len(tmp_lv.collection.literals) == length - tmp_lv = tmp_lv.collection.literals[0] - - pv = TypeEngine.to_python_value(ctx, lv, python_type) - # Check that after converting literal to Python value, the result is equal to the original python values. - assert pv == python_val - if get_origin(python_type) is Annotated: - pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) - # Remove the annotation and check that after converting to Python value, the result is equal - # to the original input values. This is used to simulate the following case: - # @workflow - # def wf(): - # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] - # task1(data=data) # task1(data: typing.List[FlytePickle]) - assert pv == python_val - - @pytest.mark.parametrize( "t,expected", [ diff --git a/tests/flytekit/unit/types/pickle/test_flyte_pickle.py b/tests/flytekit/unit/types/pickle/test_flyte_pickle.py index 48c0770593..51f119affc 100644 --- a/tests/flytekit/unit/types/pickle/test_flyte_pickle.py +++ b/tests/flytekit/unit/types/pickle/test_flyte_pickle.py @@ -16,7 +16,7 @@ from flytekit.models.literals import BlobMetadata from flytekit.models.types import LiteralType from flytekit.tools.translator import get_serializable -from flytekit.types.pickle.pickle import BatchSize, FlytePickle, FlytePickleTransformer +from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -59,11 +59,6 @@ def test_get_literal_type(): assert lt == expected_lt -def test_batch_size(): - bs = BatchSize(5) - assert bs.val == 5 - - def test_nested(): class Foo(object): def __init__(self, number: int):