From 636c4aa089fcdc48d70c411e7c4a297799fbfc90 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Thu, 24 Oct 2024 23:57:50 -0700 Subject: [PATCH] clean up setting passed in metadata for array node Signed-off-by: Paul Dittamo --- flytekit/core/array_node.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/flytekit/core/array_node.py b/flytekit/core/array_node.py index 9a6f08689a..c83da43f87 100644 --- a/flytekit/core/array_node.py +++ b/flytekit/core/array_node.py @@ -19,7 +19,6 @@ flyte_entity_call_handler, translate_inputs_to_literals, ) -from flytekit.core.task import TaskMetadata from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models @@ -41,7 +40,7 @@ def __init__( concurrency: Optional[int] = None, min_successes: Optional[int] = None, min_success_ratio: Optional[float] = None, - metadata: Optional[Union[_workflow_model.NodeMetadata, TaskMetadata]] = None, + metadata: Optional[_workflow_model.NodeMetadata] = None, ): """ :param target: The target Flyte entity to map over @@ -53,7 +52,7 @@ def __init__( min_success_ratio :param min_success_ratio: The minimum ratio of successful executions. :param execution_mode: The execution mode for propeller to use when handling ArrayNode - :param metadata: The metadata for the underlying entity + :param metadata: The metadata for the underlying node """ from flytekit.remote import FlyteLaunchPlan @@ -62,6 +61,7 @@ def __init__( self._execution_mode = execution_mode self.id = target.name self._bindings = bindings or [] + self.metadata = metadata if min_successes is not None: self._min_successes = min_successes @@ -92,22 +92,15 @@ def __init__( else: raise ValueError("No interface found for the target entity.") - self.metadata = None if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan): if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE: raise ValueError("Only execution version 1 is supported for LaunchPlans.") - if metadata: - if isinstance(metadata, _workflow_model.NodeMetadata): - self.metadata = metadata - else: - raise TypeError("Invalid metadata for LaunchPlan. Should be NodeMetadata.") else: raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}") def construct_node_metadata(self) -> _workflow_model.NodeMetadata: # Part of SupportsNodeCreation interface - # TODO - include passed in metadata - return _workflow_model.NodeMetadata(name=self.target.name) + return self.metadata or _workflow_model.NodeMetadata(name=self.target.name) @property def name(self) -> str: