Add datasets filter for datachain ds ls (#898)
amritghimire authored Feb 7, 2025
1 parent 86fc806 commit 79f6cf9
Showing 6 changed files with 77 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/datachain/cli/__init__.py
@@ -161,6 +161,7 @@ def handle_dataset_command(args, catalog):
             all=args.all,
             team=args.team,
             latest_only=not args.versions,
+            name=args.name,
         ),
         "rm": lambda: rm_dataset(
             catalog,
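
The `ls` handler forwards the new `args.name` through the existing dict-of-lambdas dispatch. A minimal sketch of that pattern, with a stubbed `list_datasets` standing in for the real command (the names here are illustrative only):

```python
from argparse import Namespace

def list_datasets(name=None, latest_only=True):
    # Stub for the real command; prints what it was dispatched with.
    print(f"name={name!r}, latest_only={latest_only}")

args = Namespace(name="dogs", versions=False)
commands = {
    "ls": lambda: list_datasets(
        name=args.name,  # None when the positional argument is omitted
        latest_only=not args.versions,
    ),
}
commands["ls"]()  # -> name='dogs', latest_only=True
```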
39 changes: 27 additions & 12 deletions src/datachain/cli/commands/datasets.py
@@ -33,13 +33,18 @@ def list_datasets(
     all: bool = True,
     team: Optional[str] = None,
     latest_only: bool = True,
+    name: Optional[str] = None,
 ):
     token = Config().read().get("studio", {}).get("token")
     all, local, studio = determine_flavors(studio, local, all, token)
+    if name:
+        latest_only = False
 
-    local_datasets = set(list_datasets_local(catalog)) if all or local else set()
+    local_datasets = set(list_datasets_local(catalog, name)) if all or local else set()
     studio_datasets = (
-        set(list_datasets_studio(team=team)) if (all or studio) and token else set()
+        set(list_datasets_studio(team=team, name=name))
+        if (all or studio) and token
+        else set()
     )
 
     # Group the datasets for both local and studio sources.
@@ -52,23 +57,23 @@
     datasets = []
     if latest_only:
         # For each dataset name, get the latest version from each source (if available).
-        for name in all_dataset_names:
-            datasets.append((name, (local_grouped.get(name), studio_grouped.get(name))))
+        for n in all_dataset_names:
+            datasets.append((n, (local_grouped.get(n), studio_grouped.get(n))))
     else:
         # For each dataset name, merge all versions from both sources.
-        for name in all_dataset_names:
-            local_versions = local_grouped.get(name, [])
-            studio_versions = studio_grouped.get(name, [])
+        for n in all_dataset_names:
+            local_versions = local_grouped.get(n, [])
+            studio_versions = studio_grouped.get(n, [])
 
             # If neither source has any versions, record it as (None, None).
             if not local_versions and not studio_versions:
-                datasets.append((name, (None, None)))
+                datasets.append((n, (None, None)))
             else:
                 # For each unique version from either source, record its presence.
                 for version in sorted(set(local_versions) | set(studio_versions)):
                     datasets.append(
                         (
-                            name,
+                            n,
                             (
                                 version if version in local_versions else None,
                                 version if version in studio_versions else None,
@@ -78,23 +83,33 @@
 
     rows = [
         _datasets_tabulate_row(
-            name=name,
+            name=n,
             both=(all or (local and studio)) and token,
             local_version=local_version,
             studio_version=studio_version,
         )
-        for name, (local_version, studio_version) in datasets
+        for n, (local_version, studio_version) in datasets
     ]
 
     print(tabulate(rows, headers="keys"))
 
 
-def list_datasets_local(catalog: "Catalog"):
+def list_datasets_local(catalog: "Catalog", name: Optional[str] = None):
+    if name:
+        yield from list_datasets_local_versions(catalog, name)
+        return
+
     for d in catalog.ls_datasets():
         for v in d.versions:
             yield (d.name, v.version)
 
 
+def list_datasets_local_versions(catalog: "Catalog", name: str):
+    ds = catalog.get_dataset(name)
+    for v in ds.versions:
+        yield (name, v.version)
+
+
 def _datasets_tabulate_row(name, both, local_version, studio_version):
     row = {
         "Name": name,
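
When a dataset name is supplied, `latest_only` is forced off, so every version is listed. The merge loop then records, for each version seen in either source, whether it exists locally, in Studio, or both. A standalone sketch of that step with made-up version data:

```python
# Made-up groupings; in the real code these come from the local and Studio listings.
local_grouped = {"dogs": [1, 2]}
studio_grouped = {"dogs": [2, 3]}

datasets = []
for n in sorted(set(local_grouped) | set(studio_grouped)):
    local_versions = local_grouped.get(n, [])
    studio_versions = studio_grouped.get(n, [])
    for version in sorted(set(local_versions) | set(studio_versions)):
        # Record the version under each source where it is present, else None.
        datasets.append(
            (n, (version if version in local_versions else None,
                 version if version in studio_versions else None))
        )

print(datasets)
# [('dogs', (1, None)), ('dogs', (2, 2)), ('dogs', (None, 3))]
```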
3 changes: 3 additions & 0 deletions src/datachain/cli/parser/__init__.py
@@ -254,6 +254,9 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         description="List datasets.",
         formatter_class=CustomHelpFormatter,
     )
+    datasets_ls_parser.add_argument(
+        "name", action="store", help="Name of the dataset to list", nargs="?"
+    )
     datasets_ls_parser.add_argument(
         "--versions",
         action="store_true",
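
Because the positional `name` is registered with `nargs="?"`, it defaults to `None` when omitted, so a plain `datachain dataset ls` keeps its old behaviour. A throwaway parser demonstrating the same setup:

```python
from argparse import ArgumentParser

p = ArgumentParser(prog="datachain dataset ls")
p.add_argument("name", action="store", help="Name of the dataset to list", nargs="?")
p.add_argument("--versions", action="store_true")

print(p.parse_args([]))        # Namespace(name=None, versions=False)
print(p.parse_args(["dogs"]))  # Namespace(name='dogs', versions=False)
```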
25 changes: 24 additions & 1 deletion src/datachain/studio.py
@@ -140,11 +140,18 @@ def token():
     print(token)
 
 
-def list_datasets(team: Optional[str] = None):
+def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
+    if name:
+        yield from list_dataset_versions(team, name)
+        return
+
     client = StudioClient(team=team)
 
     response = client.ls_datasets()
 
+    if not response.ok:
+        raise_remote_error(response.message)
+
     if not response.data:
         return
 
@@ -158,6 +165,22 @@ def list_datasets(team: Optional[str] = None):
         yield (name, version)
 
 
+def list_dataset_versions(team: Optional[str] = None, name: str = ""):
+    client = StudioClient(team=team)
+
+    response = client.dataset_info(name)
+
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    if not response.data:
+        return
+
+    for v in response.data.get("versions", []):
+        version = v.get("version")
+        yield (name, version)
+
+
 def edit_studio_dataset(
     team_name: Optional[str],
     name: str,
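
Since `list_datasets` contains `yield from`, it is a generator function: the early `return` after delegating to `list_dataset_versions` simply ends iteration rather than returning a value. A sketch of the delegation, assuming a response payload shaped like the `dogs_dataset` fixture below (the helper signatures here are simplified stand-ins):

```python
def list_dataset_versions(name, data):
    # `data` stands in for response.data from the Studio API.
    for v in data.get("versions", []):
        yield (name, v.get("version"))

def list_datasets(name=None, data=None):
    if name:
        yield from list_dataset_versions(name, data)
        return
    # ... otherwise list every dataset (elided) ...

payload = {"name": "dogs", "versions": [{"version": 1}, {"version": 2}]}
print(list(list_datasets("dogs", payload)))  # [('dogs', 1), ('dogs', 2)]
```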
16 changes: 11 additions & 5 deletions tests/conftest.py
@@ -631,12 +631,14 @@ def studio_datasets(requests_mock):
     with Config(ConfigLevel.GLOBAL).edit() as conf:
         conf["studio"] = {"token": "isat_access_token", "team": "team_name"}
 
+    dogs_dataset = {
+        "id": 1,
+        "name": "dogs",
+        "versions": [{"version": 1}, {"version": 2}],
+    }
+
     datasets = [
-        {
-            "id": 1,
-            "name": "dogs",
-            "versions": [{"version": 1}, {"version": 2}],
-        },
+        dogs_dataset,
         {
             "id": 2,
             "name": "cats",
@@ -650,6 +652,10 @@
     ]
 
     requests_mock.get(f"{STUDIO_URL}/api/datachain/datasets", json=datasets)
+    requests_mock.get(
+        f"{STUDIO_URL}/api/datachain/datasets/info?dataset_name=dogs&team_name=team_name",
+        json=dogs_dataset,
+    )
 
 
 @pytest.fixture
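
The second `requests_mock.get` registers a URL with an embedded query string, which requests-mock matches against the request's query parameters. A self-contained sketch of that matching (the host is hypothetical; requires the `requests` and `requests-mock` packages):

```python
import requests
import requests_mock

with requests_mock.Mocker() as m:
    # Register with a query string; requests-mock requires these params to match.
    m.get(
        "https://studio.example.com/api/datachain/datasets/info"
        "?dataset_name=dogs&team_name=team_name",
        json={"name": "dogs", "versions": [{"version": 1}]},
    )
    r = requests.get(
        "https://studio.example.com/api/datachain/datasets/info",
        params={"dataset_name": "dogs", "team_name": "team_name"},
    )
    print(r.json()["versions"])  # [{'version': 1}]
```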
12 changes: 11 additions & 1 deletion tests/test_cli_studio.py
@@ -120,7 +120,7 @@ def test_studio_team_global():
 
 
 def test_studio_datasets(capsys, studio_datasets, mocker):
-    def list_datasets_local(_):
+    def list_datasets_local(_, __):
         yield "local", 1
         yield "both", 1
 
@@ -161,6 +161,12 @@ def list_datasets_local(_):
     ]
     both_output_versions = tabulate(both_rows_versions, headers="keys")
 
+    dogs_rows = [
+        {"Name": "dogs", "Latest Version": "v1"},
+        {"Name": "dogs", "Latest Version": "v2"},
+    ]
+    dogs_output = tabulate(dogs_rows, headers="keys")
+
     assert main(["dataset", "ls", "--local"]) == 0
     out = capsys.readouterr().out
     assert sorted(out.splitlines()) == sorted(local_output.splitlines())
@@ -185,6 +191,10 @@ def list_datasets_local(_):
     out = capsys.readouterr().out
     assert sorted(out.splitlines()) == sorted(both_output_versions.splitlines())
 
+    assert main(["dataset", "ls", "dogs", "--studio"]) == 0
+    out = capsys.readouterr().out
+    assert sorted(out.splitlines()) == sorted(dogs_output.splitlines())
+
 
 def test_studio_edit_dataset(capsys, mocker):
     with requests_mock.mock() as m:
