From 04bcce0f29ebf0d60c3f7293a36525385845ab5a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Apr 2024 16:06:13 +0100 Subject: [PATCH] amend --- torchrl/data/replay_buffers/utils.py | 45 +++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index 45032c4399f..f38c0e2f9f0 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -125,6 +125,23 @@ class TED2Flat: ... # load the data to represent it ... td = TensorDict.load(tmpdir + "/storage/") ... print(td) + TensorDict( + fields={ + action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=True), + collector: TensorDict( + fields={ + traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=True)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True), + observation: MemoryMappedTensor(shape=torch.Size([220, 4]), device=cpu, dtype=torch.float32, is_shared=True), + reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=True), + terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True), + truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) """ @@ -196,13 +213,39 @@ class Flat2TED: ... rb.dumps(tmpdir) ... # load the data to represent it ... td = TensorDict.load(tmpdir + "/storage/") - ... print(td) ... ... rb_load = ReplayBuffer(storage=LazyMemmapStorage(200)) ... rb_load.register_load_hook(Flat2TED()) ... rb_load.load(tmpdir) ... print("storage after loading", rb_load[:]) ... assert (rb[:] == rb_load[:]).all() + storage after loading TensorDict( + fields={ + action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False), + reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + """