Skip to content

Commit

Permalink
Ruggedizes list_local_datasets and adds basic tests for CLI (#126)
Browse files Browse the repository at this point in the history
* tests+fix bug from old datasets w/ no metadata

* small text assert

* fix pre-commit issue
  • Loading branch information
grahamannett authored Jul 30, 2023
1 parent da8578c commit 0320c4c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 6 deletions.
11 changes: 7 additions & 4 deletions minari/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,19 @@ def _show_dataset_table(datasets, table_title):
table.add_column("Email", justify="left", style="magenta")

for dst_metadata in datasets.values():
author = dst_metadata.get("author", "Unknown")
author_email = dst_metadata.get("author_email", "Unknown")

assert isinstance(dst_metadata["dataset_id"], str)
assert isinstance(dst_metadata["author"], str)
assert isinstance(dst_metadata["author_email"], str)
assert isinstance(author, str)
assert isinstance(author_email, str)
table.add_row(
dst_metadata["dataset_id"],
str(dst_metadata["total_episodes"]),
str(dst_metadata["total_steps"]),
"Coming soon ...",
dst_metadata["author"],
dst_metadata["author_email"],
author,
author_email,
)

print(table)
Expand Down
5 changes: 3 additions & 2 deletions minari/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def list_local_datasets(
main_file_path = os.path.join(datasets_path, dst_id, "data/main_data.hdf5")
with h5py.File(main_file_path, "r") as f:
metadata = dict(f.attrs.items())
if compatible_minari_version and __version__ not in SpecifierSet(
metadata["minari_version"]
if ("minari_version" not in metadata) or (
compatible_minari_version
and __version__ not in SpecifierSet(metadata["minari_version"])
):
continue
env_name, dataset_name, version = parse_dataset_id(dst_id)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from typer.testing import CliRunner

from minari.cli import app
from minari.storage.local import delete_dataset, list_local_datasets
from tests.dataset.test_dataset_download import get_latest_compatible_dataset_id


runner = CliRunner()


def test_list_app():
result = runner.invoke(app, ["list", "local", "--all"])
assert result.exit_code == 0
# some of the other columns may be cut off by Rich
assert "Name" in result.stdout

result = runner.invoke(app, ["list", "remote"])
assert result.exit_code == 0
assert "Minari datasets in Farama server" in result.stdout


@pytest.mark.parametrize(
"dataset_id",
[get_latest_compatible_dataset_id(env_name="pen", dataset_name="human")],
)
def test_dataset_download_then_delete(dataset_id: str):
"""Test download dataset invocation from CLI.
the downloading functionality is already tested in test_dataset_download.py so this is primarily to assert that the CLI is working as expected.
"""
# might have to clear up the local dataset first.
# ideally this seems like it could just be handled by the tests
if dataset_id in list_local_datasets():
delete_dataset(dataset_id)

result = runner.invoke(app, ["download", dataset_id])

assert result.exit_code == 0
assert f"Downloading {dataset_id} from Farama servers..." in result.stdout
assert f"Dataset {dataset_id} downloaded to" in result.stdout

result = runner.invoke(app, ["delete", dataset_id], input="n")
assert result.exit_code == 1 # aborted but no error
assert "Aborted" in result.stdout

result = runner.invoke(app, ["delete", dataset_id], input="😳")
assert result.exit_code == 1
assert "Error: invalid input" in result.stdout

result = runner.invoke(app, ["delete", dataset_id], input="y")
assert result.exit_code == 0
assert f"Dataset {dataset_id} deleted!" in result.stdout

0 comments on commit 0320c4c

Please sign in to comment.