From e94461b048ca0f593adf3393dfbb618f035da346 Mon Sep 17 00:00:00 2001
From: Amrit Ghimire
Date: Wed, 5 Feb 2025 21:52:20 +0545
Subject: [PATCH 1/2] Add datasets filter for `datachain ds ls`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

With this change, an optional positional argument is added to
`datachain ds ls` to filter the listing to a single dataset.

Example output:

```
datachain ds ls temp
Name    Studio    Local
------  --------  -------
temp    v1        ✖
temp    v2        ✖
temp    v3        ✖
temp    v4        ✖
temp    v5        ✖
temp    v6        ✖
temp    v7        ✖
temp    v8        ✖
temp    v9        ✖
temp    v10       ✖
temp    v11       ✖
temp    v12       ✖
temp    v13       v13
```
---
 src/datachain/cli/__init__.py          |  1 +
 src/datachain/cli/commands/datasets.py | 23 ++++++++++++++++++++---
 src/datachain/cli/parser/__init__.py   |  3 +++
 src/datachain/studio.py                | 25 ++++++++++++++++++++++++-
 tests/conftest.py                      | 16 +++++++++++-----
 tests/test_cli_studio.py               | 12 +++++++++++-
 6 files changed, 70 insertions(+), 10 deletions(-)

diff --git a/src/datachain/cli/__init__.py b/src/datachain/cli/__init__.py
index 05fb67c2c..913c9f48c 100644
--- a/src/datachain/cli/__init__.py
+++ b/src/datachain/cli/__init__.py
@@ -161,6 +161,7 @@ def handle_dataset_command(args, catalog):
             all=args.all,
             team=args.team,
             latest_only=not args.versions,
+            dataset_name=args.dataset_name,
         ),
         "rm": lambda: rm_dataset(
             catalog,
diff --git a/src/datachain/cli/commands/datasets.py b/src/datachain/cli/commands/datasets.py
index 403ccca0f..dd9f3cb90 100644
--- a/src/datachain/cli/commands/datasets.py
+++ b/src/datachain/cli/commands/datasets.py
@@ -33,13 +33,20 @@ def list_datasets(
     all: bool = True,
     team: Optional[str] = None,
     latest_only: bool = True,
+    dataset_name: Optional[str] = None,
 ):
     token = Config().read().get("studio", {}).get("token")
     all, local, studio = determine_flavors(studio, local, all, token)
+    if dataset_name:
+        latest_only = False
 
-    local_datasets = set(list_datasets_local(catalog)) if all or local else set()
+    local_datasets = (
+        set(list_datasets_local(catalog, dataset_name)) if all or local else set()
+    )
     studio_datasets = (
-        set(list_datasets_studio(team=team)) if (all or studio) and token else set()
+        set(list_datasets_studio(team=team, dataset_name=dataset_name))
+        if (all or studio) and token
+        else set()
     )
 
     # Group the datasets for both local and studio sources.
@@ -89,12 +96,22 @@ def list_datasets(
     print(tabulate(rows, headers="keys"))
 
 
-def list_datasets_local(catalog: "Catalog"):
+def list_datasets_local(catalog: "Catalog", dataset_name: Optional[str] = None):
+    if dataset_name:
+        yield from list_datasets_local_versions(catalog, dataset_name)
+        return
+
     for d in catalog.ls_datasets():
         for v in d.versions:
             yield (d.name, v.version)
 
 
+def list_datasets_local_versions(catalog: "Catalog", dataset_name: str):
+    ds = catalog.get_dataset(dataset_name)
+    for v in ds.versions:
+        yield (dataset_name, v.version)
+
+
 def _datasets_tabulate_row(name, both, local_version, studio_version):
     row = {
         "Name": name,
diff --git a/src/datachain/cli/parser/__init__.py b/src/datachain/cli/parser/__init__.py
index 6b6380443..8319da52d 100644
--- a/src/datachain/cli/parser/__init__.py
+++ b/src/datachain/cli/parser/__init__.py
@@ -254,6 +254,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         description="List datasets.",
         formatter_class=CustomHelpFormatter,
     )
+    datasets_ls_parser.add_argument(
+        "dataset_name", action="store", help="Name of the dataset to list", nargs="?"
+ ) datasets_ls_parser.add_argument( "--versions", action="store_true", diff --git a/src/datachain/studio.py b/src/datachain/studio.py index f2e564e10..94d758c30 100644 --- a/src/datachain/studio.py +++ b/src/datachain/studio.py @@ -140,11 +140,18 @@ def token(): print(token) -def list_datasets(team: Optional[str] = None): +def list_datasets(team: Optional[str] = None, dataset_name: Optional[str] = None): + if dataset_name: + yield from list_dataset_versions(team, dataset_name) + return + client = StudioClient(team=team) + response = client.ls_datasets() + if not response.ok: raise_remote_error(response.message) + if not response.data: return @@ -158,6 +165,22 @@ def list_datasets(team: Optional[str] = None): yield (name, version) +def list_dataset_versions(team: Optional[str] = None, dataset_name: str = ""): + client = StudioClient(team=team) + + response = client.dataset_info(dataset_name) + + if not response.ok: + raise_remote_error(response.message) + + if not response.data: + return + + for v in response.data.get("versions", []): + version = v.get("version") + yield (dataset_name, version) + + def edit_studio_dataset( team_name: Optional[str], name: str, diff --git a/tests/conftest.py b/tests/conftest.py index 6e920e584..e6868ed1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -631,12 +631,14 @@ def studio_datasets(requests_mock): with Config(ConfigLevel.GLOBAL).edit() as conf: conf["studio"] = {"token": "isat_access_token", "team": "team_name"} + dogs_dataset = { + "id": 1, + "name": "dogs", + "versions": [{"version": 1}, {"version": 2}], + } + datasets = [ - { - "id": 1, - "name": "dogs", - "versions": [{"version": 1}, {"version": 2}], - }, + dogs_dataset, { "id": 2, "name": "cats", @@ -650,6 +652,10 @@ def studio_datasets(requests_mock): ] requests_mock.get(f"{STUDIO_URL}/api/datachain/datasets", json=datasets) + requests_mock.get( + f"{STUDIO_URL}/api/datachain/datasets/info?dataset_name=dogs&team_name=team_name", + json=dogs_dataset, + ) @pytest.fixture diff --git a/tests/test_cli_studio.py b/tests/test_cli_studio.py index 8376588ef..bec8c1a2d 100644 --- a/tests/test_cli_studio.py +++ b/tests/test_cli_studio.py @@ -120,7 +120,7 @@ def test_studio_team_global(): def test_studio_datasets(capsys, studio_datasets, mocker): - def list_datasets_local(_): + def list_datasets_local(_, __): yield "local", 1 yield "both", 1 @@ -161,6 +161,12 @@ def list_datasets_local(_): ] both_output_versions = tabulate(both_rows_versions, headers="keys") + dogs_rows = [ + {"Name": "dogs", "Latest Version": "v1"}, + {"Name": "dogs", "Latest Version": "v2"}, + ] + dogs_output = tabulate(dogs_rows, headers="keys") + assert main(["dataset", "ls", "--local"]) == 0 out = capsys.readouterr().out assert sorted(out.splitlines()) == sorted(local_output.splitlines()) @@ -185,6 +191,10 @@ def list_datasets_local(_): out = capsys.readouterr().out assert sorted(out.splitlines()) == sorted(both_output_versions.splitlines()) + assert main(["dataset", "ls", "dogs", "--studio"]) == 0 + out = capsys.readouterr().out + assert sorted(out.splitlines()) == sorted(dogs_output.splitlines()) + def test_studio_edit_dataset(capsys, mocker): with requests_mock.mock() as m: From 0952dcb7028189c410218c753e6b4cba40c97dd1 Mon Sep 17 00:00:00 2001 From: Amrit Ghimire Date: Thu, 6 Feb 2025 21:05:25 +0545 Subject: [PATCH 2/2] Rename dataset_name to name --- src/datachain/cli/__init__.py | 2 +- src/datachain/cli/commands/datasets.py | 40 ++++++++++++-------------- src/datachain/cli/parser/__init__.py | 2 +- 
src/datachain/studio.py | 12 ++++---- 4 files changed, 27 insertions(+), 29 deletions(-) diff --git a/src/datachain/cli/__init__.py b/src/datachain/cli/__init__.py index 913c9f48c..97f2920c7 100644 --- a/src/datachain/cli/__init__.py +++ b/src/datachain/cli/__init__.py @@ -161,7 +161,7 @@ def handle_dataset_command(args, catalog): all=args.all, team=args.team, latest_only=not args.versions, - dataset_name=args.dataset_name, + name=args.name, ), "rm": lambda: rm_dataset( catalog, diff --git a/src/datachain/cli/commands/datasets.py b/src/datachain/cli/commands/datasets.py index dd9f3cb90..f4f6d20d4 100644 --- a/src/datachain/cli/commands/datasets.py +++ b/src/datachain/cli/commands/datasets.py @@ -33,18 +33,16 @@ def list_datasets( all: bool = True, team: Optional[str] = None, latest_only: bool = True, - dataset_name: Optional[str] = None, + name: Optional[str] = None, ): token = Config().read().get("studio", {}).get("token") all, local, studio = determine_flavors(studio, local, all, token) - if dataset_name: + if name: latest_only = False - local_datasets = ( - set(list_datasets_local(catalog, dataset_name)) if all or local else set() - ) + local_datasets = set(list_datasets_local(catalog, name)) if all or local else set() studio_datasets = ( - set(list_datasets_studio(team=team, dataset_name=dataset_name)) + set(list_datasets_studio(team=team, name=name)) if (all or studio) and token else set() ) @@ -59,23 +57,23 @@ def list_datasets( datasets = [] if latest_only: # For each dataset name, get the latest version from each source (if available). - for name in all_dataset_names: - datasets.append((name, (local_grouped.get(name), studio_grouped.get(name)))) + for n in all_dataset_names: + datasets.append((n, (local_grouped.get(n), studio_grouped.get(n)))) else: # For each dataset name, merge all versions from both sources. - for name in all_dataset_names: - local_versions = local_grouped.get(name, []) - studio_versions = studio_grouped.get(name, []) + for n in all_dataset_names: + local_versions = local_grouped.get(n, []) + studio_versions = studio_grouped.get(n, []) # If neither source has any versions, record it as (None, None). if not local_versions and not studio_versions: - datasets.append((name, (None, None))) + datasets.append((n, (None, None))) else: # For each unique version from either source, record its presence. 
for version in sorted(set(local_versions) | set(studio_versions)): datasets.append( ( - name, + n, ( version if version in local_versions else None, version if version in studio_versions else None, @@ -85,20 +83,20 @@ def list_datasets( rows = [ _datasets_tabulate_row( - name=name, + name=n, both=(all or (local and studio)) and token, local_version=local_version, studio_version=studio_version, ) - for name, (local_version, studio_version) in datasets + for n, (local_version, studio_version) in datasets ] print(tabulate(rows, headers="keys")) -def list_datasets_local(catalog: "Catalog", dataset_name: Optional[str] = None): - if dataset_name: - yield from list_datasets_local_versions(catalog, dataset_name) +def list_datasets_local(catalog: "Catalog", name: Optional[str] = None): + if name: + yield from list_datasets_local_versions(catalog, name) return for d in catalog.ls_datasets(): @@ -106,10 +104,10 @@ def list_datasets_local(catalog: "Catalog", dataset_name: Optional[str] = None): yield (d.name, v.version) -def list_datasets_local_versions(catalog: "Catalog", dataset_name: str): - ds = catalog.get_dataset(dataset_name) +def list_datasets_local_versions(catalog: "Catalog", name: str): + ds = catalog.get_dataset(name) for v in ds.versions: - yield (dataset_name, v.version) + yield (name, v.version) def _datasets_tabulate_row(name, both, local_version, studio_version): diff --git a/src/datachain/cli/parser/__init__.py b/src/datachain/cli/parser/__init__.py index 8319da52d..22c6ad5a3 100644 --- a/src/datachain/cli/parser/__init__.py +++ b/src/datachain/cli/parser/__init__.py @@ -255,7 +255,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 formatter_class=CustomHelpFormatter, ) datasets_ls_parser.add_argument( - "dataset_name", action="store", help="Name of the dataset to list", nargs="?" + "name", action="store", help="Name of the dataset to list", nargs="?" ) datasets_ls_parser.add_argument( "--versions", diff --git a/src/datachain/studio.py b/src/datachain/studio.py index 94d758c30..caa5ed1e8 100644 --- a/src/datachain/studio.py +++ b/src/datachain/studio.py @@ -140,9 +140,9 @@ def token(): print(token) -def list_datasets(team: Optional[str] = None, dataset_name: Optional[str] = None): - if dataset_name: - yield from list_dataset_versions(team, dataset_name) +def list_datasets(team: Optional[str] = None, name: Optional[str] = None): + if name: + yield from list_dataset_versions(team, name) return client = StudioClient(team=team) @@ -165,10 +165,10 @@ def list_datasets(team: Optional[str] = None, dataset_name: Optional[str] = None yield (name, version) -def list_dataset_versions(team: Optional[str] = None, dataset_name: str = ""): +def list_dataset_versions(team: Optional[str] = None, name: str = ""): client = StudioClient(team=team) - response = client.dataset_info(dataset_name) + response = client.dataset_info(name) if not response.ok: raise_remote_error(response.message) @@ -178,7 +178,7 @@ def list_dataset_versions(team: Optional[str] = None, dataset_name: str = ""): for v in response.data.get("versions", []): version = v.get("version") - yield (dataset_name, version) + yield (name, version) def edit_studio_dataset(
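As a reviewer aid, here is a minimal, self-contained sketch of the version-merge
logic that `list_datasets` applies above when `latest_only` is false (i.e. with
`--versions`, or when a dataset name is given): every unique version from either
source becomes one row, with each source slot filled only if that source has the
version. The function name `merge_versions` and the sample data are hypothetical
illustrations, not code from the patch.

```python
from typing import Optional


def merge_versions(
    local_versions: set[int], studio_versions: set[int]
) -> list[tuple[Optional[int], Optional[int]]]:
    """Illustrative mirror of the per-name merge in list_datasets (not the patch's code)."""
    # Neither source has any versions: record a single (None, None) row,
    # matching the "record it as (None, None)" branch in the patch.
    if not local_versions and not studio_versions:
        return [(None, None)]
    # One row per unique version from either source; a slot holds the
    # version number only if that source actually has it.
    return [
        (
            version if version in local_versions else None,
            version if version in studio_versions else None,
        )
        for version in sorted(local_versions | studio_versions)
    ]


# Hypothetical sample: local has v1 only; Studio has v1 and v2.
print(merge_versions({1}, {1, 2}))  # -> [(1, 1), (None, 2)]
```

Each resulting (local_version, studio_version) pair is what
`_datasets_tabulate_row` receives, which is why a version present in both
sources renders as a single row with both columns filled rather than as two
separate rows.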