Skip to content

Commit

Permalink
Explicitly disallow iteration on Promises (#2337)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Apr 10, 2024
1 parent 11cfddc commit 177571b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
11 changes: 10 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,18 @@ def wf():
The attribute keys are appended on the promise and a new promise is returned with the updated attribute path.
We don't modify the original promise because it might be used in other places as well.
"""

return self._append_attr(key)

def __iter__(self):
"""
Flyte/kit (as of https://github.com/flyteorg/flyte/issues/3864) supports indexing into a list of promises.
But it still doesn't make sense to
"""
raise ValueError(
"Promise objects are not iterable - can't range() over a promise."
" But you can use [index] or the still stabilizing @eager"
)

def __getattr__(self, key) -> Promise:
"""
When we use . to access the attribute on the promise, for example
Expand Down
36 changes: 36 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing
from collections import OrderedDict
from typing import List

import pytest

Expand All @@ -15,6 +16,7 @@
from flytekit.core.workflow import workflow
from flytekit.models.literals import LiteralMap
from flytekit.tools.translator import get_serializable_task
from flytekit.types.file import FlyteFile

settings = flytekit.configuration.SerializationSettings(
project="test_proj",
Expand Down Expand Up @@ -290,3 +292,37 @@ def dt(mode: int) -> int:
serialised_entities_iterator = iter(entity_mapping.values())
assert "t1" in next(serialised_entities_iterator).template.id.name
assert "t2" in next(serialised_entities_iterator).template.id.name


def test_iter():
@task(requests=Resources(mem="5Gi"))
def ff_list_task() -> List[FlyteFile]:
return [FlyteFile(path=__file__, remote_path=False), FlyteFile(path=__file__, remote_path=False)]

@workflow
def sub_wf(input_file: FlyteFile) -> FlyteFile:
return input_file

@dynamic(requests=Resources(mem="5Gi"))
def dynamic_task() -> List[FlyteFile]:
batched_input_files = ff_list_task()
result_files: List[FlyteFile] = []

for _ in batched_input_files:
...

return result_files

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
) as ctx:
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {})
with pytest.raises(ValueError):
dynamic_task.dispatch_execute(ctx, input_literal_map)

0 comments on commit 177571b

Please sign in to comment.