Skip to content

Commit

Permalink
load dataset with download option (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
grahamannett authored Aug 7, 2023
1 parent 0db0da5 commit de583f8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
13 changes: 12 additions & 1 deletion minari/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,35 @@
from packaging.specifiers import SpecifierSet

from minari.dataset.minari_dataset import MinariDataset, parse_dataset_id
from minari.storage import hosting
from minari.storage.datasets_root_dir import get_dataset_path


# Use importlib due to circular import when: "from minari import __version__"
__version__ = importlib.metadata.version("minari")


def load_dataset(dataset_id: str):
def load_dataset(dataset_id: str, download: bool = False):
"""Retrieve Minari dataset from local database.
Args:
dataset_id (str): name id of Minari dataset
download (bool): if `True` download the dataset if it is not found locally. Default to `False`.
Returns:
MinariDataset
"""
file_path = get_dataset_path(dataset_id)
data_path = os.path.join(file_path, "data", "main_data.hdf5")

if not os.path.exists(data_path):
if not download:
raise FileNotFoundError(
f"Dataset {dataset_id} not found locally at {file_path}. Use download=True to download the dataset."
)

hosting.download_dataset(dataset_id)

return MinariDataset(data_path)


Expand Down
18 changes: 18 additions & 0 deletions tests/dataset/test_dataset_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,21 @@ def test_download_dataset_from_farama_server(dataset_id: str):
minari.delete_dataset(dataset_id)
local_datasets = minari.list_local_datasets()
assert dataset_id not in local_datasets


@pytest.mark.parametrize(
"dataset_id",
[
get_latest_compatible_dataset_id(env_name=env_name, dataset_name="human")
for env_name in env_names
],
)
def test_load_dataset_with_download(dataset_id: str):
"""Test load dataset with and without download."""
with pytest.raises(FileNotFoundError):
dataset = minari.load_dataset(dataset_id)

dataset = minari.load_dataset(dataset_id, download=True)
assert isinstance(dataset, MinariDataset)

minari.delete_dataset(dataset_id)

0 comments on commit de583f8

Please sign in to comment.