From 825a50e9e92c359264fb5cb610b1e2a4c5678009 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 2 Jul 2024 09:54:58 -0700 Subject: [PATCH] Support s3 path for click Dir param type (#2547) Signed-off-by: Future-Outlier --- flytekit/interaction/click_types.py | 10 ++++++---- tests/flytekit/unit/interaction/test_click_types.py | 8 ++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 5728dce3d0..f0426ac7f2 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -80,13 +80,15 @@ def convert( ) -> typing.Any: if isinstance(value, ArtifactQuery): return value - p = pathlib.Path(value) + # set remote_directory to false if running pyflyte run locally. This makes sure that the original # directory is used and not a random one. remote_directory = None if getattr(ctx.obj, "is_remote", False) else False - if p.exists() and p.is_dir(): - return FlyteDirectory(path=value, remote_directory=remote_directory) - raise click.BadParameter(f"parameter should be a valid directory path, {value}") + if not FileAccessProvider.is_remote(value): + p = pathlib.Path(value) + if not p.exists() or not p.is_dir(): + raise click.BadParameter(f"parameter should be a valid flytedirectory path, {value}") + return FlyteDirectory(path=value, remote_directory=remote_directory) class StructuredDatasetParamType(click.ParamType): diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index cb32982916..e21a283271 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -28,6 +28,14 @@ dummy_param = click.Option(["--dummy"], type=click.STRING, default="dummy") +def test_dir_param(): + import os + m = mock.MagicMock() + current_file_directory = os.path.dirname(os.path.abspath(__file__)) + l = DirParamType().convert(current_file_directory, m, m) + assert l.path == current_file_directory + r = DirParamType().convert("https://tmp/dir", m, m) + assert r.path == "https://tmp/dir" def test_file_param(): m = mock.MagicMock()