diff --git a/src/datachain/cli/__init__.py b/src/datachain/cli/__init__.py index 05fb67c2c..97f2920c7 100644 --- a/src/datachain/cli/__init__.py +++ b/src/datachain/cli/__init__.py @@ -161,6 +161,7 @@ def handle_dataset_command(args, catalog): all=args.all, team=args.team, latest_only=not args.versions, + name=args.name, ), "rm": lambda: rm_dataset( catalog, diff --git a/src/datachain/cli/commands/datasets.py b/src/datachain/cli/commands/datasets.py index 403ccca0f..f4f6d20d4 100644 --- a/src/datachain/cli/commands/datasets.py +++ b/src/datachain/cli/commands/datasets.py @@ -33,13 +33,18 @@ def list_datasets( all: bool = True, team: Optional[str] = None, latest_only: bool = True, + name: Optional[str] = None, ): token = Config().read().get("studio", {}).get("token") all, local, studio = determine_flavors(studio, local, all, token) + if name: + latest_only = False - local_datasets = set(list_datasets_local(catalog)) if all or local else set() + local_datasets = set(list_datasets_local(catalog, name)) if all or local else set() studio_datasets = ( - set(list_datasets_studio(team=team)) if (all or studio) and token else set() + set(list_datasets_studio(team=team, name=name)) + if (all or studio) and token + else set() ) # Group the datasets for both local and studio sources. @@ -52,23 +57,23 @@ def list_datasets( datasets = [] if latest_only: # For each dataset name, get the latest version from each source (if available). - for name in all_dataset_names: - datasets.append((name, (local_grouped.get(name), studio_grouped.get(name)))) + for n in all_dataset_names: + datasets.append((n, (local_grouped.get(n), studio_grouped.get(n)))) else: # For each dataset name, merge all versions from both sources. - for name in all_dataset_names: - local_versions = local_grouped.get(name, []) - studio_versions = studio_grouped.get(name, []) + for n in all_dataset_names: + local_versions = local_grouped.get(n, []) + studio_versions = studio_grouped.get(n, []) # If neither source has any versions, record it as (None, None). if not local_versions and not studio_versions: - datasets.append((name, (None, None))) + datasets.append((n, (None, None))) else: # For each unique version from either source, record its presence. for version in sorted(set(local_versions) | set(studio_versions)): datasets.append( ( - name, + n, ( version if version in local_versions else None, version if version in studio_versions else None, @@ -78,23 +83,33 @@ def list_datasets( rows = [ _datasets_tabulate_row( - name=name, + name=n, both=(all or (local and studio)) and token, local_version=local_version, studio_version=studio_version, ) - for name, (local_version, studio_version) in datasets + for n, (local_version, studio_version) in datasets ] print(tabulate(rows, headers="keys")) -def list_datasets_local(catalog: "Catalog"): +def list_datasets_local(catalog: "Catalog", name: Optional[str] = None): + if name: + yield from list_datasets_local_versions(catalog, name) + return + for d in catalog.ls_datasets(): for v in d.versions: yield (d.name, v.version) +def list_datasets_local_versions(catalog: "Catalog", name: str): + ds = catalog.get_dataset(name) + for v in ds.versions: + yield (name, v.version) + + def _datasets_tabulate_row(name, both, local_version, studio_version): row = { "Name": name, diff --git a/src/datachain/cli/parser/__init__.py b/src/datachain/cli/parser/__init__.py index 6b6380443..22c6ad5a3 100644 --- a/src/datachain/cli/parser/__init__.py +++ b/src/datachain/cli/parser/__init__.py @@ -254,6 +254,9 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 description="List datasets.", formatter_class=CustomHelpFormatter, ) + datasets_ls_parser.add_argument( + "name", action="store", help="Name of the dataset to list", nargs="?" + ) datasets_ls_parser.add_argument( "--versions", action="store_true", diff --git a/src/datachain/studio.py b/src/datachain/studio.py index f2e564e10..caa5ed1e8 100644 --- a/src/datachain/studio.py +++ b/src/datachain/studio.py @@ -140,11 +140,18 @@ def token(): print(token) -def list_datasets(team: Optional[str] = None): +def list_datasets(team: Optional[str] = None, name: Optional[str] = None): + if name: + yield from list_dataset_versions(team, name) + return + client = StudioClient(team=team) + response = client.ls_datasets() + if not response.ok: raise_remote_error(response.message) + if not response.data: return @@ -158,6 +165,22 @@ def list_datasets(team: Optional[str] = None): yield (name, version) +def list_dataset_versions(team: Optional[str] = None, name: str = ""): + client = StudioClient(team=team) + + response = client.dataset_info(name) + + if not response.ok: + raise_remote_error(response.message) + + if not response.data: + return + + for v in response.data.get("versions", []): + version = v.get("version") + yield (name, version) + + def edit_studio_dataset( team_name: Optional[str], name: str, diff --git a/tests/conftest.py b/tests/conftest.py index 6e920e584..e6868ed1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -631,12 +631,14 @@ def studio_datasets(requests_mock): with Config(ConfigLevel.GLOBAL).edit() as conf: conf["studio"] = {"token": "isat_access_token", "team": "team_name"} + dogs_dataset = { + "id": 1, + "name": "dogs", + "versions": [{"version": 1}, {"version": 2}], + } + datasets = [ - { - "id": 1, - "name": "dogs", - "versions": [{"version": 1}, {"version": 2}], - }, + dogs_dataset, { "id": 2, "name": "cats", @@ -650,6 +652,10 @@ def studio_datasets(requests_mock): ] requests_mock.get(f"{STUDIO_URL}/api/datachain/datasets", json=datasets) + requests_mock.get( + f"{STUDIO_URL}/api/datachain/datasets/info?dataset_name=dogs&team_name=team_name", + json=dogs_dataset, + ) @pytest.fixture diff --git a/tests/test_cli_studio.py b/tests/test_cli_studio.py index 8376588ef..bec8c1a2d 100644 --- a/tests/test_cli_studio.py +++ b/tests/test_cli_studio.py @@ -120,7 +120,7 @@ def test_studio_team_global(): def test_studio_datasets(capsys, studio_datasets, mocker): - def list_datasets_local(_): + def list_datasets_local(_, __): yield "local", 1 yield "both", 1 @@ -161,6 +161,12 @@ def list_datasets_local(_): ] both_output_versions = tabulate(both_rows_versions, headers="keys") + dogs_rows = [ + {"Name": "dogs", "Latest Version": "v1"}, + {"Name": "dogs", "Latest Version": "v2"}, + ] + dogs_output = tabulate(dogs_rows, headers="keys") + assert main(["dataset", "ls", "--local"]) == 0 out = capsys.readouterr().out assert sorted(out.splitlines()) == sorted(local_output.splitlines()) @@ -185,6 +191,10 @@ def list_datasets_local(_): out = capsys.readouterr().out assert sorted(out.splitlines()) == sorted(both_output_versions.splitlines()) + assert main(["dataset", "ls", "dogs", "--studio"]) == 0 + out = capsys.readouterr().out + assert sorted(out.splitlines()) == sorted(dogs_output.splitlines()) + def test_studio_edit_dataset(capsys, mocker): with requests_mock.mock() as m: