diff --git a/tensordict/base.py b/tensordict/base.py index 82b87d9eb..f72578bb3 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3041,6 +3041,7 @@ def map( while len(out) < len(self_split): print('out', out) out.append(imap.next(timeout=5.0)) + print(out[-1]['c']) # out = pool.map(fn, self_split, chunksize) out = torch.cat(out, dim) return out diff --git a/test/test_tensordict.py b/test/test_tensordict.py index e0cae98a5..f8eb548e3 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -6568,6 +6568,7 @@ def test_map_seed(self): { "r": torch.zeros(20, dtype=torch.int), "s": torch.zeros(20, dtype=torch.int), + "c": torch.arange(20), }, batch_size=[20], )