Skip to content

Commit

Permalink
[Bug] Set bindings for ArrayNode (flyteorg#2742)
Browse files Browse the repository at this point in the history
* wip/hack set bindings

Signed-off-by: Paul Dittamo <[email protected]>

* don't link node when getting bindings from array node subnode

Signed-off-by: Paul Dittamo <[email protected]>

* update param description

Signed-off-by: Paul Dittamo <[email protected]>

* only create node when compiling while setting bindings/calling an ArrayNode

Signed-off-by: Paul Dittamo <[email protected]>

* utilize all inputs when getting input bindings for a subnode

Signed-off-by: Paul Dittamo <[email protected]>

* update create_and_link_node_from_remote

Signed-off-by: Paul Dittamo <[email protected]>

* update create_and_link_node_from_remote

Signed-off-by: Paul Dittamo <[email protected]>

* undo linking node changes to create_and_link_node_from_remote

Signed-off-by: Paul Dittamo <[email protected]>

* undo linking node changes to create_and_link_node_from_remote

Signed-off-by: Paul Dittamo <[email protected]>

* set type to List instead of optional

Signed-off-by: Paul Dittamo <[email protected]>

* cleanup

Signed-off-by: Paul Dittamo <[email protected]>

* utilize input bindings for array node instead of undering subnode interface for local execute

Signed-off-by: Paul Dittamo <[email protected]>

* cleanup

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* cleanup

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
  • Loading branch information
pvditt authored Sep 19, 2024
1 parent 570de08 commit 9bce7c3
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 29 deletions.
36 changes: 30 additions & 6 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from flytekit.core.promise import (
Promise,
VoidPromise,
create_and_link_node,
flyte_entity_call_handler,
translate_inputs_to_literals,
)
Expand All @@ -20,16 +21,18 @@
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Literal, LiteralCollection, Scalar

ARRAY_NODE_SUBNODE_NAME = "array_node_subnode"


class ArrayNode:
def __init__(
self,
target: LaunchPlan,
execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
bindings: Optional[List[_literal_models.Binding]] = None,
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
bound_inputs: Optional[Set[str]] = None,
metadata: Optional[Union[_workflow_model.NodeMetadata, TaskMetadata]] = None,
):
"""
Expand All @@ -41,14 +44,14 @@ def __init__(
:param min_successes: The minimum number of successful executions. If set, this takes precedence over
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions.
:param bound_inputs: The set of inputs that should be bound to the map task
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying entity
"""
self.target = target
self._concurrency = concurrency
self._execution_mode = execution_mode
self.id = target.name
self._bindings = bindings or []

if min_successes is not None:
self._min_successes = min_successes
Expand All @@ -61,7 +64,8 @@ def __init__(
if n_outputs > 1:
raise ValueError("Only tasks with a single output are supported in map tasks.")

self._bound_inputs: Set[str] = bound_inputs or set(bound_inputs) if bound_inputs else set()
# TODO - bound inputs are not supported at the moment
self._bound_inputs: Set[str] = set()

output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1
collection_interface = transform_interface_to_list_interface(
Expand Down Expand Up @@ -99,7 +103,7 @@ def python_interface(self) -> flyte_interface.Interface:
@property
def bindings(self) -> List[_literal_models.Binding]:
# Required in get_serializable_node
return []
return self._bindings

@property
def upstream_nodes(self) -> List[Node]:
Expand All @@ -116,7 +120,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
outputs_expected = False

mapped_entity_count = 0
for k in self.python_interface.inputs.keys():
for binding in self.bindings:
k = binding.var
if k not in self._bound_inputs:
v = kwargs[k]
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], self.target.python_interface.inputs[k]):
Expand All @@ -137,7 +142,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
literals = []
for i in range(mapped_entity_count):
single_instance_inputs = {}
for k in self.python_interface.inputs.keys():
for binding in self.bindings:
k = binding.var
if k not in self._bound_inputs:
single_instance_inputs[k] = kwargs[k][i]
else:
Expand Down Expand Up @@ -190,6 +196,24 @@ def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode:
return self._execution_mode

def __call__(self, *args, **kwargs):
if not self._bindings:
ctx = FlyteContext.current_context()
# since a new entity with an updated list interface is not created, we have to work around the mismatch
# between the interface and the inputs
collection_interface = transform_interface_to_list_interface(
self.flyte_entity.python_interface, self._bound_inputs
)
# don't link the node to the compilation state, since we don't want to add the subnode to the
# workflow as a node
bound_subnode = create_and_link_node(
ctx,
entity=self.flyte_entity,
add_node_to_compilation_state=False,
overridden_interface=collection_interface,
node_id=ARRAY_NODE_SUBNODE_NAME,
**kwargs,
)
self._bindings = bound_subnode.ref.node.bindings
return flyte_entity_call_handler(self, *args, **kwargs)


Expand Down
27 changes: 22 additions & 5 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,9 @@ def create_and_link_node_from_remote(
def create_and_link_node(
ctx: FlyteContext,
entity: SupportsNodeCreation,
overridden_interface: Optional[Interface] = None,
add_node_to_compilation_state: bool = True,
node_id: str = "",
**kwargs,
) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]:
"""
Expand All @@ -1140,17 +1143,22 @@ def create_and_link_node(
:param ctx: FlyteContext
:param entity: RemoteEntity
:param add_node_to_compilation_state: bool that enables for nodes to be created but not linked to the workflow. This
is useful when creating nodes nested under other nodes such as ArrayNode
:param overridden_interface: utilize this interface instead of the one provided by the entity. This is useful for
ArrayNode as there's a mismatch between the underlying interface and inputs
:param node_id: str if provided, this will be used as the node id.
:param kwargs: Dict[str, Any] default inputs passed from the user to this entity. Can be promises.
:return: Optional[Union[Tuple[Promise], Promise, VoidPromise]]
"""
if ctx.compilation_state is None:
if ctx.compilation_state is None and add_node_to_compilation_state:
raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

used_inputs = set()
bindings = []
nodes = []

interface = entity.python_interface
interface = overridden_interface or entity.python_interface
typed_interface = flyte_interface.transform_interface_to_typed_interface(
interface, allow_partial_artifact_id_binding=True
)
Expand Down Expand Up @@ -1214,15 +1222,24 @@ def create_and_link_node(
# These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
upstream_nodes = list(set([n for n in nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID]))

# TODO: Better naming, probably a derivative of the function name.
# if not adding to compilation state, we don't need to generate a unique node id
node_id = node_id or (
f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}"
if add_node_to_compilation_state and ctx.compilation_state
else node_id
)

flytekit_node = Node(
# TODO: Better naming, probably a derivative of the function name.
id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
id=node_id,
metadata=entity.construct_node_metadata(),
bindings=sorted(bindings, key=lambda b: b.var),
upstream_nodes=upstream_nodes,
flyte_entity=entity,
)
ctx.compilation_state.add_node(flytekit_node)

if add_node_to_compilation_state and ctx.compilation_state:
ctx.compilation_state.add_node(flytekit_node)

if len(typed_interface.outputs) == 0:
return VoidPromise(entity.name, NodeOutput(node=flytekit_node, var="placeholder"))
Expand Down
56 changes: 38 additions & 18 deletions tests/flytekit/unit/core/test_array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,57 @@ def serialization_settings():


@task
def multiply(val: int, val1: int) -> int:
return val * val1
def multiply(val: int, val1: typing.Union[int, str], val2: int) -> int:
if type(val1) is str:
return val * val2
return val * int(val1) * val2


@workflow
def parent_wf(a: int, b: int) -> int:
return multiply(val=a, val1=b)
def parent_wf(a: int, b: typing.Union[int, str], c: int = 2) -> int:
return multiply(val=a, val1=b, val2=c)


lp = LaunchPlan.get_default_launch_plan(current_context(), parent_wf)


@workflow
def grandparent_wf() -> typing.List[int]:
return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=[2, 4, 6])
return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])


def test_lp_serialization(serialization_settings):

wf_spec = get_serializable(OrderedDict(), serialization_settings, grandparent_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].array_node is not None
assert wf_spec.template.nodes[0].array_node.node is not None
assert wf_spec.template.nodes[0].array_node.node.workflow_node is not None

top_level = wf_spec.template.nodes[0]
assert top_level.inputs[0].var == "a"
assert len(top_level.inputs[0].binding.collection.bindings) == 3
for binding in top_level.inputs[0].binding.collection.bindings:
assert binding.scalar.primitive.integer is not None
assert top_level.inputs[1].var == "b"
for binding in top_level.inputs[1].binding.collection.bindings:
assert binding.scalar.union is not None
assert len(top_level.inputs[1].binding.collection.bindings) == 3
assert top_level.inputs[2].var == "c"
assert len(top_level.inputs[2].binding.collection.bindings) == 3
for binding in top_level.inputs[2].binding.collection.bindings:
assert binding.scalar.primitive.integer is not None

serialized_array_node = top_level.array_node
assert (
wf_spec.template.nodes[0].array_node.node.workflow_node.launchplan_ref.resource_type
== identifier_models.ResourceType.LAUNCH_PLAN
serialized_array_node.node.workflow_node.launchplan_ref.resource_type
== identifier_models.ResourceType.LAUNCH_PLAN
)
assert wf_spec.template.nodes[0].array_node.node.workflow_node.launchplan_ref.name == "tests.flytekit.unit.core.test_array_node.parent_wf"
assert wf_spec.template.nodes[0].array_node._min_success_ratio == 0.9
assert wf_spec.template.nodes[0].array_node._parallelism == 10
assert (
serialized_array_node.node.workflow_node.launchplan_ref.name
== "tests.flytekit.unit.core.test_array_node.parent_wf"
)
assert serialized_array_node._min_success_ratio == 0.9
assert serialized_array_node._parallelism == 10

subnode = serialized_array_node.node
assert subnode.inputs == top_level.inputs


@pytest.mark.parametrize(
Expand Down Expand Up @@ -97,8 +117,8 @@ def grandparent_ex_wf() -> typing.List[typing.Optional[int]]:


def test_map_task_wrapper():
mapped_task = map_task(multiply)(val=[1, 3, 5], val1=[2, 4, 6])
assert mapped_task == [2, 12, 30]
mapped_task = map_task(multiply)(val=[1, 3, 5], val1=[2, 4, 6], val2=[7, 8, 9])
assert mapped_task == [14, 96, 270]

mapped_lp = map_task(lp)(a=[1, 3, 5], b=[2, 4, 6])
assert mapped_lp == [2, 12, 30]
mapped_lp = map_task(lp)(a=[1, 3, 5], b=[2, 4, 6], c=[7, 8, 9])
assert mapped_lp == [14, 96, 270]
9 changes: 9 additions & 0 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,23 @@ def t1(a: typing.Union[int, typing.List[int]]) -> typing.Union[int, typing.List[
assert p.ref.node_id == "n0"
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 1
assert len(ctx.compilation_state.nodes) == 1

@task
def t2(a: typing.Optional[int] = None) -> typing.Optional[int]:
return a

ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix=""))
p = create_and_link_node(ctx, t2)
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 1
assert len(ctx.compilation_state.nodes) == 1

ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix=""))
p = create_and_link_node(ctx, t2, add_node_to_compilation_state=False)
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 1
assert len(ctx.compilation_state.nodes) == 0


def test_create_and_link_node_from_remote():
Expand Down

0 comments on commit 9bce7c3

Please sign in to comment.