diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
index 832ab17a1c..cfe2be1ad8 100644
--- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
+++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -15,7 +15,7 @@
 import flytekit
 from flytekit import PythonFunctionTask, Resources, lazy_module
 from flytekit.configuration import SerializationSettings
-from flytekit.core.context_manager import OutputMetadata
+from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
 from flytekit.core.pod_template import PodTemplate
 from flytekit.core.resources import convert_resources_to_resource_model
 from flytekit.exceptions.user import FlyteRecoverableException
@@ -429,13 +429,18 @@ def fn_partial():
             """Closure of the task function with kwargs already bound."""
             try:
                 return_val = self._task_function(**kwargs)
+                core_context = FlyteContextManager.current_context()
+                omt = core_context.output_metadata_tracker
+                om = None
+                if omt:
+                    om = omt.get(return_val)
             except Exception as e:
                 # See explanation in `create_recoverable_error_file` why we check
                 # for recoverable errors here in the worker processes.
                 if isinstance(e, FlyteRecoverableException):
                     create_recoverable_error_file()
                 raise
-            return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=None)
+            return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om)
 
         launcher_target_func = fn_partial
         launcher_args = ()
@@ -470,7 +475,8 @@ def fn_partial():
                 if not isinstance(deck, flytekit.deck.deck.TimeLineDeck):
                     ctx.decks.append(deck)
             if out[0].om:
-                ctx.output_metadata_tracker.add(out[0].return_value, out[0].om)
+                core_context = FlyteContextManager.current_context()
+                core_context.output_metadata_tracker.add(out[0].return_value, out[0].om)
 
             return out[0].return_value
         else:
diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
index b56fc0aa08..39f1e0bb80 100644
--- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
+++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
@@ -2,6 +2,10 @@
 import typing
 from dataclasses import dataclass
 from unittest import mock
+from typing_extensions import Annotated, cast
+from flytekitplugins.kfpytorch.task import Elastic
+
+from flytekit import Artifact
 
 import pytest
 import torch
@@ -11,6 +15,7 @@
 import flytekit
 from flytekit import task, workflow
+from flytekit.core.context_manager import FlyteContext, FlyteContextManager, ExecutionState, ExecutionParameters, OutputMetadataTracker
 from flytekit.configuration import SerializationSettings
 from flytekit.exceptions.user import FlyteRecoverableException
 
 
@@ -159,6 +164,41 @@ def wf():
     assert "Hello Flyte Deck viewer from worker process 0" in test_deck.html
 
 
+class Card(object):
+    def __init__(self, text: str):
+        self.text = text
+
+    def serialize_to_string(self, ctx: FlyteContext, variable_name: str):
+        print(f"In serialize_to_string: {id(ctx)}")
+        return "card", "card"
+
+
+@pytest.mark.parametrize("start_method", ["spawn", "fork"])
+def test_output_metadata_passing(start_method: str) -> None:
+    ea = Artifact(name="elastic-artf")
+
+    @task(
+        task_config=Elastic(start_method=start_method),
+    )
+    def train2() -> Annotated[str, ea]:
+        return ea.create_from("hello flyte", Card("## card"))
+
+    @workflow
+    def wf():
+        train2()
+
+    ctx = FlyteContext.current_context()
+    omt = OutputMetadataTracker()
+    with FlyteContextManager.with_context(
+        ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION)).with_output_metadata_tracker(omt)
+    ) as child_ctx:
+        cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
+        # call execute directly so as to be able to get at the same FlyteContext object.
+        res = train2.execute()
+        om = child_ctx.output_metadata_tracker.get(res)
+        assert len(om.additional_items) == 1
+
+
 @pytest.mark.parametrize(
     "recoverable,start_method",
     [