Skip to content

Commit

Permalink
improve minari list remote performance (Farama-Foundation#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
younik authored Oct 12, 2024
1 parent 287b345 commit bb94e1c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 32 deletions.
58 changes: 30 additions & 28 deletions minari/storage/hosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import json
import os
import warnings
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List

from minari.dataset.minari_dataset import gen_dataset_id, parse_dataset_id
from minari.dataset.minari_storage import METADATA_FILE_NAME, MinariStorage
from minari.storage.datasets_root_dir import get_dataset_path
from minari.storage.local import dataset_id_sort_key, load_dataset
from minari.storage.local import load_dataset
from minari.storage.remotes import get_cloud_storage


Expand Down Expand Up @@ -200,44 +202,44 @@ def list_remote_datasets(
"""
from minari import supported_dataset_versions

cloud_storage = get_cloud_storage()
blobs = cloud_storage.list_blobs()

# Generate dict = {'dataset_id': (version, metadata)}
remote_datasets = {}
for blob in blobs:
def blob_to_metadata(blob):
try:
if os.path.basename(blob.name) == METADATA_FILE_NAME:
metadata = json.loads(blob.download_as_bytes(client=None))
if (
compatible_minari_version
and metadata["minari_version"] not in supported_dataset_versions
):
continue
dataset_id = metadata["dataset_id"]
namespace, dataset_name, version = parse_dataset_id(dataset_id)
dataset = gen_dataset_id(namespace, dataset_name)

if latest_version:
if (
dataset not in remote_datasets
or version > remote_datasets[dataset][0]
):
remote_datasets[dataset] = (version, metadata)
else:
remote_datasets[dataset_id] = metadata
return
return metadata
except Exception:
warnings.warn(f"Misconfigured dataset named {blob.name} on remote")

if latest_version:
# Convert to dict = {'dataset_id': metadata}
remote_datasets = dict(
map(lambda x: (f"{x[0]}-v{x[1][0]}", x[1][1]), remote_datasets.items())
)
cloud_storage = get_cloud_storage()
blobs = cloud_storage.list_blobs()
with ThreadPoolExecutor(max_workers=32) as executor:
remote_metadatas = executor.map(blob_to_metadata, blobs)

remote_datasets = {}
max_version = defaultdict(dict)
for metadata in remote_metadatas:
if metadata is None:
continue

dataset_id = metadata["dataset_id"]
remote_datasets[dataset_id] = metadata

if latest_version:
namespace, dataset_name, version = parse_dataset_id(dataset_id)
old_version = max_version[namespace].get(dataset_name, version)
max_version[namespace][dataset_name] = max(old_version, version)
if old_version != max_version[namespace][dataset_name]:
min_id = gen_dataset_id(
namespace, dataset_name, min(old_version, version)
)
del remote_datasets[min_id]

return {
k: remote_datasets[k] for k in sorted(remote_datasets, key=dataset_id_sort_key)
}
return remote_datasets


def get_remote_dataset_versions(
Expand Down
4 changes: 2 additions & 2 deletions minari/storage/remotes/cloud_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional
from typing import Any, Iterable, Optional


class CloudStorage(ABC):
Expand All @@ -14,7 +14,7 @@ def upload_directory(self, path: Path, remote_dir_path: str) -> None: ...
def upload_file(self, local_path: Path, remote_path: str) -> None: ...

@abstractmethod
def list_blobs(self, prefix: Optional[str] = None) -> list: ...
def list_blobs(self, prefix: Optional[str] = None) -> Iterable: ...

@abstractmethod
def download_blob(self, blob: Any, file_path: str) -> None: ...
4 changes: 2 additions & 2 deletions minari/storage/remotes/gcp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Any, Optional
from typing import Any, Iterable, Optional

from minari.storage.remotes.cloud_storage import CloudStorage

Expand Down Expand Up @@ -39,7 +39,7 @@ def upload_file(self, local_path: Path, remote_path: str) -> None:
blob = self.bucket.blob(remote_path)
blob.upload_from_filename(local_path)

def list_blobs(self, prefix: Optional[str] = None) -> list:
def list_blobs(self, prefix: Optional[str] = None) -> Iterable:
return self.bucket.list_blobs(prefix=prefix)

def download_blob(self, blob: Any, file_path: Path) -> None:
Expand Down

0 comments on commit bb94e1c

Please sign in to comment.