diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index d51f71d837..426e0f9d98 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata @@ -38,7 +38,7 @@ def __init__( name: str, image: str, command: List[str], - inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, + inputs: Optional[Union[Dict[str, Tuple[Type, Any]], OrderedDict[str, Type]]] = None, metadata: Optional[TaskMetadata] = None, arguments: Optional[List[str]] = None, outputs: Optional[Dict[str, Type]] = None, diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 6a11c9dc50..c8895f2d91 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -10,6 +10,7 @@ from contextlib import contextmanager from typing import Any, Dict, List, Optional, Set +from flytekit import ContainerTask from flytekit.configuration import SerializationSettings from flytekit.core import tracker from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin @@ -33,7 +34,7 @@ class MapPythonTask(PythonTask): def __init__( self, - python_function_task: typing.Union[PythonFunctionTask, functools.partial], + python_function_task: typing.Union[PythonFunctionTask, ContainerTask, functools.partial], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, bound_inputs: Optional[Set[str]] = None, @@ -63,8 +64,8 @@ def __init__( else: actual_task = python_function_task - if not isinstance(actual_task, PythonFunctionTask): - raise ValueError("Map tasks can only compose of Python Functon Tasks currently") + if not isinstance(actual_task, (PythonFunctionTask, ContainerTask)): + raise ValueError("Map tasks can only compose of Python Function or Container Tasks currently") if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") @@ -75,9 +76,13 @@ def __init__( collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) self._run_task: PythonFunctionTask = actual_task - _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() - name = f"{mod}.map_{f}_{h}" + + if isinstance(actual_task, ContainerTask): + name = f"raw_container_task.mapper_{actual_task.name}_{h}" + else: + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + name = f"{mod}.map_{f}_{h}" self._cmd_prefix: typing.Optional[typing.List[str]] = None self._max_concurrency: typing.Optional[int] = concurrency @@ -142,14 +147,20 @@ def prepare_target(self): self._run_task.reset_command_fn() def get_container(self, settings: SerializationSettings) -> Container: + if isinstance(self._run_task, ContainerTask): + return self._run_task.get_container(settings) with self.prepare_target(): return self._run_task.get_container(settings) def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod: + if isinstance(self._run_task, ContainerTask): + return self._run_task.get_k8s_pod(settings) with self.prepare_target(): return self._run_task.get_k8s_pod(settings) def get_sql(self, settings: SerializationSettings) -> Sql: + if isinstance(self._run_task, ContainerTask): + return self._run_task.get_sql(settings) with self.prepare_target(): return self._run_task.get_sql(settings) @@ -270,7 +281,7 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( - task_function: typing.Union[PythonFunctionTask, functools.partial], + task_function: typing.Union[PythonFunctionTask, functools.partial, ContainerTask], concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs, diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 8b30fc4d36..9e13ed34bb 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -180,7 +180,10 @@ def get_serializable_task( if settings.should_fast_serialize(): # This handles container tasks. - if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask)): + if container and ( + isinstance(entity, PythonAutoContainerTask) + or (isinstance(entity, MapPythonTask) and isinstance(entity.run_task, PythonAutoContainerTask)) + ): # For fast registration, we'll need to muck with the command, but on # ly for certain kinds of tasks. Specifically, # tasks that rely on user code defined in the container. This should be encapsulated by the auto container diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index d032aca2d1..f7e15a04d6 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -5,7 +5,7 @@ import pytest import flytekit.configuration -from flytekit import LaunchPlan, map_task +from flytekit import ContainerTask, LaunchPlan, kwtypes, map_task from flytekit.configuration import Image, ImageConfig from flytekit.core.map_task import MapPythonTask, MapTaskResolver from flytekit.core.task import TaskMetadata, task @@ -25,6 +25,22 @@ def serialization_settings(): ) +raw_container = ContainerTask( + name="ellipse-area-metadata-python", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + inputs=kwtypes(a=int), + outputs=kwtypes(area=float), + image="flyte/raw-container:v1", + command=[ + "python", + "test.py", + "{{.inputs.a}}", + "/var/outputs", + ], +) + + @task def t1(a: int) -> str: b = a + 2 @@ -106,6 +122,23 @@ def test_serialization(serialization_settings): ] +def test_serialization_with_raw_container(serialization_settings): + maptask = map_task(raw_container, metadata=TaskMetadata(retries=1)) + task_spec = get_serializable(OrderedDict(), serialization_settings, maptask) + + # By default all map_task tasks will have their custom fields set. + assert task_spec.template.custom["minSuccessRatio"] == 1.0 + assert task_spec.template.type == "container_array" + assert task_spec.template.task_type_version == 1 + assert task_spec.template.container.args is None + assert task_spec.template.container.command == [ + "python", + "test.py", + "{{.inputs.a}}", + "/var/outputs", + ] + + @pytest.mark.parametrize( "custom_fields_dict, expected_custom_fields", [