Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 8, 2024
1 parent 44d1f9a commit b14edf7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 18 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _main(argv):
"offline-data": [
"huggingface_hub", # for roboset
"minari",
"requests",
"tqdm",
"scikit-learn",
"pandas",
Expand Down
24 changes: 20 additions & 4 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,12 +1833,28 @@ def test_grouping(self, n_agents, scenario_name="dispersion", n_envs=2):

@pytest.mark.slow
class TestGenDGRL:
    """Smoke tests for building and iterating a ``GenDGRLExperienceReplay``."""

    @staticmethod
    @pytest.fixture
    def _patch_traj_len(monkeypatch):
        """Make every category report a length of 100.

        This avoids processing the entire dataset during the test.
        """

        def new_get_category_len(cls, category_name):
            return 100

        # monkeypatch records the raw class attribute (the original descriptor,
        # not a bound method) and puts it back at teardown, so no manual
        # save/restore is needed here.
        monkeypatch.setattr(
            GenDGRLExperienceReplay,
            "_get_category_len",
            classmethod(new_get_category_len),
        )

    @pytest.mark.parametrize("dataset_num", [0, 4, 8])
    def test_gen_dgrl(self, dataset_num, tmpdir, _patch_traj_len):
        """Load one dataset and check the shape of the first sampled batch."""
        dataset_id = GenDGRLExperienceReplay.available_datasets[dataset_num]
        print("dataset_id", dataset_id)
        dataset = GenDGRLExperienceReplay(dataset_id, batch_size=32, root=tmpdir)
        # Only the first batch is needed; an empty dataset fails right here
        # instead of surfacing later as an unbound ``batch``.
        batch = next(iter(dataset))
        assert batch.get(("next", "observation")).shape[-3] == 3


@pytest.mark.skipif(not _has_d4rl, reason="D4RL not found")
Expand Down
60 changes: 46 additions & 14 deletions torchrl/data/datasets/gen_dgrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
import os
import tarfile
import tempfile
import typing as tp
from pathlib import Path

import numpy as np
import requests
import torch

from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.envs.utils import _classproperty

_has_tqdm = importlib.util.find_spec("tqdm", None) is not None
_has_requests = importlib.util.find_spec("requests", None) is not None


class GenDGRLExperienceReplay(TensorDictReplayBuffer):
"""Gen-DGRL Experience Replay dataset.
Expand Down Expand Up @@ -233,16 +236,19 @@ def _unpack_category_file(
idx = 0
td_memmap = None
dataset_len = self._get_category_len(category_name)
try:
if _has_tqdm:
from tqdm import tqdm

pbar = tqdm(total=dataset_len)
except ImportError:
else:
pbar = None
mode = "r:xz" if str(file_path).endswith("xz") else "r"
full = False
with tarfile.open(file_path, mode) as tar:
members = list(tar.getmembers())
for i in range(0, len(members), batch):
if full:
break
submembers = [
member for member in members[i : i + batch] if member.isfile()
]
Expand All @@ -265,7 +271,6 @@ def _unpack_category_file(
td.rename_key_("dones", ("next", "done"))
td.rename_key_("actions", "action")
td.rename_key_("rewards", ("next", "reward"))
td.set(("next", "reward"), td.get(("next", "reward")).unsqueeze(-1))
td.set(
("next", "done"), td.get(("next", "done")).bool().unsqueeze(-1)
)
Expand All @@ -274,6 +279,13 @@ def _unpack_category_file(
torch.zeros_like(td.get(("next", "done"))),
)
td.set(("next", "terminated"), td.get(("next", "done")))

td.set(
"terminated", torch.zeros_like(td.get(("next", "terminated")))
)
td.set("done", torch.zeros_like(td.get(("next", "done"))))
td.set("truncated", torch.zeros_like(td.get(("next", "truncated"))))

td.batch_size = td.get("observation").shape[:1]
if td_memmap is None:
td_memmap = (
Expand All @@ -285,7 +297,14 @@ def _unpack_category_file(
idx_end = min(idx_end, td_memmap.shape[0])
if pbar is not None:
pbar.update(td.shape[0])
td_memmap[idx:idx_end] = td
length = idx_end - idx
if length > 0:
if length != td.shape[0]:
td_memmap[idx:idx_end] = td[:length]
else:
td_memmap[idx:idx_end] = td
else:
full = True
idx = idx_end
os.remove(npyfile)

Expand Down Expand Up @@ -323,15 +342,28 @@ def _download_category_file(
@classmethod
def _download_with_progress_bar(cls, url: str, file_path: str):
    """Stream the content found at ``url`` into ``file_path``.

    A tqdm progress bar is displayed when tqdm is installed; otherwise the
    download proceeds silently.

    Args:
        url: address of the file to download.
        file_path: destination path, opened in binary write mode.

    Raises:
        ImportError: if the ``requests`` package is not installed.
        requests.HTTPError: if the server replies with an error status.
    """
    # taken from https://stackoverflow.com/a/62113293/986477
    if not _has_requests:
        raise ImportError(
            "The requests package is required for Gen-DGRL dataset download."
        )
    import requests

    # The context manager releases the streamed connection even on failure.
    with requests.get(url, stream=True) as resp:
        # Fail early rather than writing an HTTP error page to disk.
        resp.raise_for_status()
        total = int(resp.headers.get("content-length", 0))
        if _has_tqdm:
            from tqdm import tqdm

            pbar = tqdm(
                desc=file_path,
                total=total,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            )
        else:
            pbar = None
        try:
            with open(file_path, "wb") as file:
                for data in resp.iter_content(chunk_size=1024):
                    size = file.write(data)
                    if pbar is not None:
                        pbar.update(size)
        finally:
            # Always close the bar so the terminal line is finalised.
            if pbar is not None:
                pbar.close()

0 comments on commit b14edf7

Please sign in to comment.