Add datasets filter for datachain ds ls #898

Merged (2 commits) on Feb 7, 2025
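
This PR adds an optional positional name argument to `datachain dataset ls` (the `ds ls` form in the title), so a single dataset can be listed with all of its versions instead of only the latest version of every dataset. A sketch of the new invocation, with illustrative output (the exact table rendering comes from tabulate and may differ):

    $ datachain dataset ls dogs --studio
    Name    Latest Version
    ------  ----------------
    dogs    v1
    dogs    v2

When a name is given, the latest-only filter is switched off, so every version of that dataset is shown.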
1 change: 1 addition & 0 deletions src/datachain/cli/__init__.py
@@ -161,6 +161,7 @@ def handle_dataset_command(args, catalog):
             all=args.all,
             team=args.team,
             latest_only=not args.versions,
+            name=args.name,
         ),
         "rm": lambda: rm_dataset(
             catalog,
39 changes: 27 additions & 12 deletions src/datachain/cli/commands/datasets.py
@@ -33,13 +33,18 @@
     all: bool = True,
     team: Optional[str] = None,
     latest_only: bool = True,
+    name: Optional[str] = None,
 ):
     token = Config().read().get("studio", {}).get("token")
     all, local, studio = determine_flavors(studio, local, all, token)
+    if name:
+        latest_only = False
 
-    local_datasets = set(list_datasets_local(catalog)) if all or local else set()
+    local_datasets = set(list_datasets_local(catalog, name)) if all or local else set()
     studio_datasets = (
-        set(list_datasets_studio(team=team)) if (all or studio) and token else set()
+        set(list_datasets_studio(team=team, name=name))
+        if (all or studio) and token
+        else set()
     )
 
     # Group the datasets for both local and studio sources.
@@ -52,23 +57,23 @@
     datasets = []
     if latest_only:
         # For each dataset name, get the latest version from each source (if available).
-        for name in all_dataset_names:
-            datasets.append((name, (local_grouped.get(name), studio_grouped.get(name))))
+        for n in all_dataset_names:
+            datasets.append((n, (local_grouped.get(n), studio_grouped.get(n))))
     else:
         # For each dataset name, merge all versions from both sources.
-        for name in all_dataset_names:
-            local_versions = local_grouped.get(name, [])
-            studio_versions = studio_grouped.get(name, [])
+        for n in all_dataset_names:
+            local_versions = local_grouped.get(n, [])
+            studio_versions = studio_grouped.get(n, [])
 
             # If neither source has any versions, record it as (None, None).
             if not local_versions and not studio_versions:
-                datasets.append((name, (None, None)))
+                datasets.append((n, (None, None)))
[Codecov: added line #L70 not covered by tests]
             else:
                 # For each unique version from either source, record its presence.
                 for version in sorted(set(local_versions) | set(studio_versions)):
                     datasets.append(
                         (
-                            name,
+                            n,
                             (
                                 version if version in local_versions else None,
                                 version if version in studio_versions else None,
@@ -78,23 +83,33 @@
 
     rows = [
         _datasets_tabulate_row(
-            name=name,
+            name=n,
             both=(all or (local and studio)) and token,
             local_version=local_version,
             studio_version=studio_version,
         )
-        for name, (local_version, studio_version) in datasets
+        for n, (local_version, studio_version) in datasets
     ]
 
     print(tabulate(rows, headers="keys"))
 
 
-def list_datasets_local(catalog: "Catalog"):
+def list_datasets_local(catalog: "Catalog", name: Optional[str] = None):
+    if name:
+        yield from list_datasets_local_versions(catalog, name)
+        return
+
[Codecov: added lines #L99-L100 not covered by tests]
     for d in catalog.ls_datasets():
         for v in d.versions:
             yield (d.name, v.version)
 
 
+def list_datasets_local_versions(catalog: "Catalog", name: str):
+    ds = catalog.get_dataset(name)
[Codecov: added line #L108 not covered by tests]
+    for v in ds.versions:
+        yield (name, v.version)
[Codecov: added line #L110 not covered by tests]
+
+
 def _datasets_tabulate_row(name, both, local_version, studio_version):
     row = {
         "Name": name,
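
For context, a minimal standalone sketch of the version-merging rule the hunk above implements (merge_versions is a hypothetical helper written for illustration, not datachain code): every version present in either source appears once, with None marking the side that lacks it.

    def merge_versions(local_versions, studio_versions):
        # Emit one (local, studio) pair per distinct version; None marks
        # the source that does not have that version.
        for version in sorted(set(local_versions) | set(studio_versions)):
            yield (
                version if version in local_versions else None,
                version if version in studio_versions else None,
            )

    # Example: local has versions 1 and 2, studio has 2 and 3.
    print(list(merge_versions([1, 2], [2, 3])))
    # [(1, None), (2, 2), (None, 3)]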
3 changes: 3 additions & 0 deletions src/datachain/cli/parser/__init__.py
@@ -254,6 +254,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         description="List datasets.",
         formatter_class=CustomHelpFormatter,
     )
+    datasets_ls_parser.add_argument(
+        "name", action="store", help="Name of the dataset to list", nargs="?"
+    )
     datasets_ls_parser.add_argument(
         "--versions",
         action="store_true",
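
A standalone sketch of the parser behavior this change introduces (the prog name is illustrative; the real parser defines many more options): with nargs="?" the positional is optional, so both the bare listing and the name-filtered form parse.

    from argparse import ArgumentParser

    p = ArgumentParser(prog="datachain dataset ls")
    p.add_argument("name", action="store", help="Name of the dataset to list", nargs="?")
    p.add_argument("--versions", action="store_true")

    print(p.parse_args([]))        # Namespace(name=None, versions=False)
    print(p.parse_args(["dogs"]))  # Namespace(name='dogs', versions=False)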
25 changes: 24 additions & 1 deletion src/datachain/studio.py
@@ -140,11 +140,18 @@
     print(token)
 
 
-def list_datasets(team: Optional[str] = None):
+def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
+    if name:
+        yield from list_dataset_versions(team, name)
+        return
+
     client = StudioClient(team=team)
 
     response = client.ls_datasets()
+
+    if not response.ok:
+        raise_remote_error(response.message)
 
     if not response.data:
         return
@@ -158,6 +165,22 @@
     yield (name, version)
 
 
+def list_dataset_versions(team: Optional[str] = None, name: str = ""):
+    client = StudioClient(team=team)
+
+    response = client.dataset_info(name)
+
+    if not response.ok:
+        raise_remote_error(response.message)
[Codecov: added line #L174 not covered by tests]
+
+    if not response.data:
+        return
[Codecov: added line #L177 not covered by tests]
+
+    for v in response.data.get("versions", []):
+        version = v.get("version")
+        yield (name, version)
+
+
 def edit_studio_dataset(
     team_name: Optional[str],
     name: str,
16 changes: 11 additions & 5 deletions tests/conftest.py
@@ -631,12 +631,14 @@ def studio_datasets(requests_mock):
     with Config(ConfigLevel.GLOBAL).edit() as conf:
         conf["studio"] = {"token": "isat_access_token", "team": "team_name"}
 
+    dogs_dataset = {
+        "id": 1,
+        "name": "dogs",
+        "versions": [{"version": 1}, {"version": 2}],
+    }
+
     datasets = [
-        {
-            "id": 1,
-            "name": "dogs",
-            "versions": [{"version": 1}, {"version": 2}],
-        },
+        dogs_dataset,
         {
             "id": 2,
             "name": "cats",
@@ -650,6 +652,10 @@
     ]
 
     requests_mock.get(f"{STUDIO_URL}/api/datachain/datasets", json=datasets)
+    requests_mock.get(
+        f"{STUDIO_URL}/api/datachain/datasets/info?dataset_name=dogs&team_name=team_name",
+        json=dogs_dataset,
+    )
 
 
 @pytest.fixture
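
A side note on the stub above: requests_mock matches a query string registered as part of the URL, so this mock only answers requests that carry dataset_name=dogs and team_name=team_name. A minimal standalone illustration (the host and endpoint here are made up, not the real Studio URL):

    import requests
    import requests_mock

    with requests_mock.Mocker() as m:
        m.get(
            "https://studio.example.com/api/datachain/datasets/info"
            "?dataset_name=dogs&team_name=team_name",
            json={"name": "dogs", "versions": [{"version": 1}, {"version": 2}]},
        )
        resp = requests.get(
            "https://studio.example.com/api/datachain/datasets/info",
            params={"dataset_name": "dogs", "team_name": "team_name"},
        )
        print(resp.json()["versions"])  # [{'version': 1}, {'version': 2}]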
12 changes: 11 additions & 1 deletion tests/test_cli_studio.py
@@ -120,7 +120,7 @@ def test_studio_team_global():
 
 
 def test_studio_datasets(capsys, studio_datasets, mocker):
-    def list_datasets_local(_):
+    def list_datasets_local(_, __):
         yield "local", 1
         yield "both", 1
 
@@ -161,6 +161,12 @@ def list_datasets_local(_):
     ]
     both_output_versions = tabulate(both_rows_versions, headers="keys")
 
+    dogs_rows = [
+        {"Name": "dogs", "Latest Version": "v1"},
+        {"Name": "dogs", "Latest Version": "v2"},
+    ]
+    dogs_output = tabulate(dogs_rows, headers="keys")
+
     assert main(["dataset", "ls", "--local"]) == 0
     out = capsys.readouterr().out
     assert sorted(out.splitlines()) == sorted(local_output.splitlines())
@@ -185,6 +191,10 @@ def list_datasets_local(_):
     out = capsys.readouterr().out
     assert sorted(out.splitlines()) == sorted(both_output_versions.splitlines())
 
+    assert main(["dataset", "ls", "dogs", "--studio"]) == 0
+    out = capsys.readouterr().out
+    assert sorted(out.splitlines()) == sorted(dogs_output.splitlines())
 
 
 def test_studio_edit_dataset(capsys, mocker):
     with requests_mock.mock() as m: