diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index f56b48994c..122a739265 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -7,11 +7,12 @@ import tempfile import typing from dataclasses import dataclass, field, fields -from typing import get_args +from typing import Iterator, get_args import rich_click as click from mashumaro.codecs.json import JSONEncoder from rich.progress import Progress +from typing_extensions import get_origin from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal from flytekit.clis.sdk_in_container.helpers import patch_image_config @@ -538,10 +539,21 @@ def _run(*args, **kwargs): for input_name, v in entity.python_interface.inputs_with_defaults.items(): processed_click_value = kwargs.get(input_name) optional_v = False + + skip_default_value_selection = False if processed_click_value is None and isinstance(v, typing.Tuple): - optional_v = is_optional(v[0]) - if len(v) == 2: - processed_click_value = v[1] + if entity_type == "workflow" and hasattr(v[0], "__args__"): + origin_base_type = get_origin(v[0]) + if inspect.isclass(origin_base_type) and issubclass(origin_base_type, Iterator): # Iterator + args = getattr(v[0], "__args__") + if isinstance(args, tuple) and get_origin(args[0]) is typing.Union: # Iterator[JSON] + logger.debug(f"Detected Iterator[JSON] in {entity.name} input annotations...") + skip_default_value_selection = True + + if not skip_default_value_selection: + optional_v = is_optional(v[0]) + if len(v) == 2: + processed_click_value = v[1] if isinstance(processed_click_value, ArtifactQuery): if run_level_params.is_remote: click.secho( diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 3bb7697d47..ad85d588af 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -9,19 +9,34 @@ import pytest import yaml from click.testing import CliRunner +from flytekit.loggers import logging, logger from flytekit.clis.sdk_in_container import pyflyte -from flytekit.clis.sdk_in_container.run import RunLevelParams, get_entities_in_file, run_command +from flytekit.clis.sdk_in_container.run import ( + RunLevelParams, + get_entities_in_file, + run_command, +) from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, calculate_hash_from_image_spec +from flytekit.image_spec.image_spec import ( + ImageBuildEngine, + ImageSpec, + calculate_hash_from_image_spec, +) from flytekit.interaction.click_types import DirParamType, FileParamType from flytekit.remote import FlyteRemote +from typing import Iterator +from flytekit.types.iterator import JSON +from flytekit import workflow + pytest.importorskip("pandas") REMOTE_WORKFLOW_FILE = "https://raw.githubusercontent.com/flyteorg/flytesnacks/8337b64b33df046b2f6e4cba03c74b7bdc0c4fb1/cookbook/core/flyte_basics/basic_workflow.py" -IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py") +IMPERATIVE_WORKFLOW_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py" +) DIR_NAME = os.path.dirname(os.path.realpath(__file__)) @@ -46,7 +61,9 @@ def workflow_file(request, tmp_path_factory): @pytest.fixture def remote(): with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client: - flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote = FlyteRemote( + config=Config.auto(), default_project="p1", default_domain="d1" + ) flyte_remote._client = mock_client return flyte_remote @@ -70,7 +87,9 @@ def test_pyflyte_run_wf(remote, remote_flag, workflow_file): with mock.patch("flytekit.configuration.plugin.FlyteRemote"): runner = CliRunner() result = runner.invoke( - pyflyte.main, ["run", remote_flag, workflow_file, "my_wf", "--help"], catch_exceptions=False + pyflyte.main, + ["run", remote_flag, workflow_file, "my_wf", "--help"], + catch_exceptions=False, ) assert result.exit_code == 0 @@ -81,7 +100,9 @@ def test_pyflyte_run_with_labels(): with mock.patch("flytekit.configuration.plugin.FlyteRemote"): runner = CliRunner() result = runner.invoke( - pyflyte.main, ["run", "--remote", str(workflow_file), "my_wf", "--help"], catch_exceptions=False + pyflyte.main, + ["run", "--remote", str(workflow_file), "my_wf", "--help"], + catch_exceptions=False, ) assert result.exit_code == 0 @@ -100,7 +121,16 @@ def test_copy_all_files(): runner = CliRunner() result = runner.invoke( pyflyte.main, - ["run", "--copy-all", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"], + [ + "run", + "--copy-all", + IMPERATIVE_WORKFLOW_FILE, + "wf", + "--in1", + "hello", + "--in2", + "world", + ], catch_exceptions=False, ) assert result.exit_code == 0 @@ -176,7 +206,13 @@ def test_pyflyte_run_cli(workflow_file): @pytest.mark.parametrize( "input", - ["1", os.path.join(DIR_NAME, "testdata/df.parquet"), '{"x":1.0, "y":2.0}', "2020-05-01", "RED"], + [ + "1", + os.path.join(DIR_NAME, "testdata/df.parquet"), + '{"x":1.0, "y":2.0}', + "2020-05-01", + "RED", + ], ) def test_union_type1(input): runner = CliRunner() @@ -300,7 +336,10 @@ def test_nested_workflow(working_dir, wf_path, monkeypatch: pytest.MonkeyPatch): ], catch_exceptions=False, ) - assert result.stdout.strip() == "Running Execution on local.\nRunning Execution on local." + assert ( + result.stdout.strip() + == "Running Execution on local.\nRunning Execution on local." + ) assert result.exit_code == 0 @@ -325,12 +364,18 @@ def test_list_default_arguments(wf_path): # default case, what comes from click if no image is specified, the click param is configured to use the default. ic_result_1 = ImageConfig( - default_image=Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest"), - images=[Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest")], + default_image=Image( + name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest" + ), + images=[ + Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest") + ], ) # test that command line args are merged with the file ic_result_2 = ImageConfig( - default_image=Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"), + default_image=Image( + name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest" + ), images=[ Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"), Image(name="asdf", fqn="ghcr.io/asdf/asdf", tag="latest"), @@ -345,7 +390,9 @@ def test_list_default_arguments(wf_path): ) # test that command line args override the file ic_result_3 = ImageConfig( - default_image=Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"), + default_image=Image( + name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest" + ), images=[ Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"), Image(name="xyz", fqn="ghcr.io/asdf/asdf", tag="latest"), @@ -395,21 +442,29 @@ def test_list_default_arguments(wf_path): reason="Github macos-latest image does not have docker installed as per https://github.com/orgs/community/discussions/25777", ) def test_pyflyte_run_run( - mock_image, image_string, leaf_configuration_file_name, final_image_config, mock_image_spec_builder + mock_image, + image_string, + leaf_configuration_file_name, + final_image_config, + mock_image_spec_builder, ): mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest" ImageBuildEngine.register("test", mock_image_spec_builder) @task - def tk(): - ... + def tk(): ... mock_click_ctx = mock.MagicMock() mock_remote = mock.MagicMock() image_tuple = (image_string,) image_config = ImageConfig.validate_image(None, "", image_tuple) - pp = pathlib.Path(__file__).parent.parent.parent / "configuration" / "configs" / leaf_configuration_file_name + pp = ( + pathlib.Path(__file__).parent.parent.parent + / "configuration" + / "configs" + / leaf_configuration_file_name + ) obj = RunLevelParams( project="p", @@ -429,6 +484,125 @@ def check_image(*args, **kwargs): run_command(mock_click_ctx, tk)() +def jsons(): + for x in [ + { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + ], + }, + }, + ]: + yield x + + +@mock.patch("flytekit.configuration.default_images.DefaultImages.default_image") +def test_pyflyte_run_with_iterator_json_type( + mock_image, mock_image_spec_builder, caplog +): + mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest" + ImageBuildEngine.register( + "test", + mock_image_spec_builder, + ) + + @task + def t1(x: Iterator[JSON]) -> Iterator[JSON]: + return x + + @workflow + def tk(x: Iterator[JSON] = jsons()) -> Iterator[JSON]: + return t1(x=x) + + @task + def t2(x: list[int]) -> list[int]: + return x + + @workflow + def tk_list(x: list[int] = [1, 2, 3]) -> list[int]: + return t2(x=x) + + @task + def t3(x: Iterator[int]) -> Iterator[int]: + return x + + @workflow + def tk_simple_iterator(x: Iterator[int] = iter([1, 2, 3])) -> Iterator[int]: + return t3(x=x) + + mock_click_ctx = mock.MagicMock() + mock_remote = mock.MagicMock() + image_tuple = ("ghcr.io/flyteorg/mydefault:py3.9-latest",) + image_config = ImageConfig.validate_image(None, "", image_tuple) + + pp = ( + pathlib.Path(__file__).parent.parent.parent + / "configuration" + / "configs" + / "no_images.yaml" + ) + + obj = RunLevelParams( + project="p", + domain="d", + image_config=image_config, + remote=True, + config_file=str(pp), + ) + obj._remote = mock_remote + mock_click_ctx.obj = obj + + def check_image(*args, **kwargs): + assert kwargs["image_config"] == ic_result_1 + + mock_remote.register_script.side_effect = check_image + + logger.propagate = True + with caplog.at_level(logging.DEBUG, logger="flytekit"): + run_command(mock_click_ctx, tk)() + assert any( + "Detected Iterator[JSON] in pyflyte.test_run.tk input annotations..." + in message[2] + for message in caplog.record_tuples + ) + + caplog.clear() + + with caplog.at_level(logging.DEBUG, logger="flytekit"): + run_command(mock_click_ctx, tk_list)() + assert not any( + "Detected Iterator[JSON] in pyflyte.test_run.tk_list input annotations..." + in message[2] + for message in caplog.record_tuples + ) + + caplog.clear() + + with caplog.at_level(logging.DEBUG, logger="flytekit"): + run_command(mock_click_ctx, t1)() + assert not any( + "Detected Iterator[JSON] in pyflyte.test_run.t1 input annotations..." + in message[2] + for message in caplog.record_tuples + ) + + caplog.clear() + + with caplog.at_level(logging.DEBUG, logger="flytekit"): + run_command(mock_click_ctx, tk_simple_iterator)() + assert not any( + "Detected Iterator[JSON] in pyflyte.test_run.tk_simple_iterator input annotations..." + in message[2] + for message in caplog.record_tuples + ) + + def test_file_param(): m = mock.MagicMock() flyte_file = FileParamType().convert(__file__, m, m) @@ -484,7 +658,11 @@ def test_pyflyte_run_with_none(a_val, workflow_file): "envs, envs_argument, expected_output", [ (["--env", "MY_ENV_VAR=hello"], '["MY_ENV_VAR"]', "hello"), - (["--env", "MY_ENV_VAR=hello", "--env", "ABC=42"], '["MY_ENV_VAR","ABC"]', "hello,42"), + ( + ["--env", "MY_ENV_VAR=hello", "--env", "ABC=42"], + '["MY_ENV_VAR","ABC"]', + "hello,42", + ), ], ) @pytest.mark.parametrize(