Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 29, 2024
1 parent 6f9c672 commit 04bcce0
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion torchrl/data/replay_buffers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ class TED2Flat:
... # load the data to represent it
... td = TensorDict.load(tmpdir + "/storage/")
... print(td)
TensorDict(
fields={
action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=True),
collector: TensorDict(
fields={
traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=True)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False),
done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
observation: MemoryMappedTensor(shape=torch.Size([220, 4]), device=cpu, dtype=torch.float32, is_shared=True),
reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=True),
terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False)
"""

Expand Down Expand Up @@ -196,13 +213,39 @@ class Flat2TED:
... rb.dumps(tmpdir)
... # load the data to represent it
... td = TensorDict.load(tmpdir + "/storage/")
... print(td)
...
... rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
... rb_load.register_load_hook(Flat2TED())
... rb_load.load(tmpdir)
... print("storage after loading", rb_load[:])
... assert (rb[:] == rb_load[:]).all()
storage after loading TensorDict(
fields={
action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
collector: TensorDict(
fields={
traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([200]),
device=cpu,
is_shared=False),
done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([200]),
device=cpu,
is_shared=False),
observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([200]),
device=cpu,
is_shared=False)
"""

Expand Down

0 comments on commit 04bcce0

Please sign in to comment.