Skip to content

Commit

Permalink
Fix batch add for async
Browse files Browse the repository at this point in the history
  • Loading branch information
BFAnas committed Aug 9, 2023
1 parent 3510e64 commit c685498
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
19 changes: 12 additions & 7 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,26 +1368,31 @@ def test_custom_key():
'done':
np.array([False]),
'returns':
np.array([74.70343082])
np.array([74.70343082]),
'info':
Batch(),
'policy':
Batch(),
}
)
buffer_size = len(batch.rew)
buffer = ReplayBuffer(buffer_size)
buffer.add(batch)
sampled_batch, _ = buffer.sample(1)
# Check if they have the same keys
assert set(batch.__dict__.keys()) == set(sampled_batch.__dict__.keys()), \
assert set(batch.keys()) == set(sampled_batch.keys()), \
"Batches have different keys: {} and {}".format(
set(batch.__dict__.keys()), set(sampled_batch.__dict__.keys()))
set(batch.keys()), set(sampled_batch.keys()))
# Compare the values for each key
for key in batch.__dict__.keys():
for key in batch.keys():
if isinstance(batch.__dict__[key], np.ndarray
) and isinstance(sampled_batch.__dict__[key], np.ndarray):
assert np.allclose(batch.__dict__[key], sampled_batch.__dict__[key]), \
"Value mismatch for key: {}".format(key)
else:
assert batch.__dict__[key] == sampled_batch.__dict__[key], \
"Value mismatch for key: {}".format(key)
if isinstance(batch.__dict__[key], Batch
) and isinstance(sampled_batch.__dict__[key], Batch):
assert batch.__dict__[key].is_empty()
assert sampled_batch.__dict__[key].is_empty()


if __name__ == '__main__':
Expand Down
4 changes: 4 additions & 0 deletions tianshou/data/buffer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def add(
episode_reward is 0.
"""
# preprocess batch
new_batch = Batch()
for key in batch.keys():
new_batch.__dict__[key] = batch[key]
batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert set(["obs", "act", "rew", "terminated", "truncated",
"done"]).issubset(batch.keys())
Expand Down

0 comments on commit c685498

Please sign in to comment.