diff --git a/minari/storage/local.py b/minari/storage/local.py
index dd5bfdec..efa16447 100644
--- a/minari/storage/local.py
+++ b/minari/storage/local.py
@@ -7,6 +7,7 @@
 from packaging.specifiers import SpecifierSet
 
 from minari.dataset.minari_dataset import MinariDataset, parse_dataset_id
+from minari.storage import hosting
 from minari.storage.datasets_root_dir import get_dataset_path
 
 
@@ -14,17 +15,37 @@
 __version__ = importlib.metadata.version("minari")
 
 
-def load_dataset(dataset_id: str):
+def load_dataset(dataset_id: str, download: bool = False):
     """Retrieve Minari dataset from local database.
 
     Args:
         dataset_id (str): name id of Minari dataset
+        download (bool): if `True`, download the dataset when it is not found locally. Defaults to `False`.
 
     Returns:
         MinariDataset
+
+    Raises:
+        FileNotFoundError: if the dataset is not found locally and `download` is `False`,
+            or if its data file is still missing after the download.
     """
     file_path = get_dataset_path(dataset_id)
     data_path = os.path.join(file_path, "data", "main_data.hdf5")
+
+    if not os.path.exists(data_path):
+        if not download:
+            raise FileNotFoundError(
+                f"Dataset {dataset_id} not found locally at {file_path}. Use download=True to download the dataset."
+            )
+
+        hosting.download_dataset(dataset_id)
+
+        # Fail with a clear message if the download did not produce the data file.
+        if not os.path.exists(data_path):
+            raise FileNotFoundError(
+                f"Dataset {dataset_id} is still missing at {file_path} after download."
+            )
+
     return MinariDataset(data_path)
 
 
diff --git a/tests/dataset/test_dataset_download.py b/tests/dataset/test_dataset_download.py
index d203c199..fe969bed 100644
--- a/tests/dataset/test_dataset_download.py
+++ b/tests/dataset/test_dataset_download.py
@@ -61,3 +61,27 @@ def test_download_dataset_from_farama_server(dataset_id: str):
     minari.delete_dataset(dataset_id)
     local_datasets = minari.list_local_datasets()
     assert dataset_id not in local_datasets
+
+
+@pytest.mark.parametrize(
+    "dataset_id",
+    [
+        get_latest_compatible_dataset_id(env_name=env_name, dataset_name="human")
+        for env_name in env_names
+    ],
+)
+def test_load_dataset_with_download(dataset_id: str):
+    """Test load dataset with and without download."""
+    # Assumes the dataset is not stored locally at this point, so a plain load must fail.
+    with pytest.raises(FileNotFoundError):
+        minari.load_dataset(dataset_id)
+
+    dataset = minari.load_dataset(dataset_id, download=True)
+    try:
+        assert isinstance(dataset, MinariDataset)
+    finally:
+        # Remove the downloaded files even when the assertion fails.
+        minari.delete_dataset(dataset_id)
+
+    local_datasets = minari.list_local_datasets()
+    assert dataset_id not in local_datasets