diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 4a61383c6e1..1b7bff82a58 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -2455,9 +2455,6 @@ def _step( try: history = torch.stack(list(history.unbind(0)) + [local_history]) except Exception: - print(history) - print(history.unbind(0)) - print(local_history) raise assert isinstance(history, History) next_tensordict["history"] = history diff --git a/test/test_env.py b/test/test_env.py index 2636a071ab5..69c63e567a6 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -15,7 +15,6 @@ from collections import defaultdict from functools import partial from sys import platform -from tokenize import maybe from typing import Optional import numpy as np @@ -4406,20 +4405,24 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi class TestEnvWithHistory: @pytest.fixture(autouse=True, scope="class") def set_capture(self): - with set_capture_non_tensor_stack(False): + with set_capture_non_tensor_stack(False), set_auto_unwrap_transformed_env( + False + ): yield return - def _make_env(self, device): - return CountingEnv(device=device).append_transform(HistoryTransform()) + def _make_env(self, device, max_steps=10): + return CountingEnv(device=device, max_steps=max_steps).append_transform( + HistoryTransform() + ) - def _make_skipping_env(self, device): - env = self._make_env(device=device) - env = env.append_transform(StepCounter()) + def _make_skipping_env(self, device, max_steps=10): + env = self._make_env(device=device, max_steps=max_steps) # skip every 3 steps env = env.append_transform( - ConditionalSkip(lambda td: td["step_count"] % 3 == 0) + ConditionalSkip(lambda td: ((td["step_count"] % 3) == 2)) ) + env = TransformedEnv(env, StepCounter()) return env @pytest.mark.parametrize("device", [None, "cpu"]) @@ -4482,17 +4485,39 @@ def test_env_history_base_collector(self, device_env, collector_cls): env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5 ) for d in collector: - print(d) + for i in range(d.shape[0] - 1): + assert ( + d[i + 1]["history"].content[0] == d[i]["next", "history"].content[0] + ) @pytest.mark.parametrize("device_env", [None, "cpu"]) @pytest.mark.parametrize("collector_cls", [SyncDataCollector]) def test_skipping_history_env_collector(self, device_env, collector_cls): - env = self._make_skipping_env(device_env) + env = self._make_skipping_env(device_env, max_steps=10) collector = collector_cls( - env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5 + env, + lambda td: td.update(env.full_action_spec.one()), + total_frames=35, + frames_per_batch=5, ) + length = None + count = 1 for d in collector: - print(d) + for k in range(1, 5): + if len(d[k]["history"].content) == 2: + count = 1 + continue + if count % 3 == 2: + assert ( + d[k]["next", "history"].content + == d[k - 1]["next", "history"].content + ), (d["next", "history"].content, k, count) + else: + assert d[k]["next", "history"].content[-1] == str( + int(d[k - 1]["next", "history"].content[-1]) + 1 + ), (d["next", "history"].content, k, count) + count += 1 + count += 1 if __name__ == "__main__": diff --git a/test/test_transforms.py b/test/test_transforms.py index 5f98f4dc96c..39a66eb6abb 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -13496,7 +13496,7 @@ def check_non_tensor_match(self, td): class ToString(Transform): def _apply_transform(self, obs: torch.Tensor) -> None: - return NonTensorData(str(obs), device=obs.device) + return NonTensorData(str(obs), device=self.parent.device) def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 425b2035228..346efb4e117 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2525,7 +2525,6 @@ def __init__( if isinstance(shape, int): shape = _size([shape]) - # _, device = _default_dtype_and_device(None, device) domain = None super().__init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f96b197caac..bcb3aade632 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1142,10 +1142,12 @@ def _step( out_td = self._envs[i]._step(_data_in) next_td[i].update_( out_td, + # _env_output_keys exclude non-tensor data keys_to_update=list(self._env_output_keys), non_blocking=self.non_blocking, ) if out_tds is not None: + # we store the non-tensor data here out_tds.append(out_td) # We must pass a clone of the tensordict, as the values of this tensordict @@ -1989,7 +1991,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: next_td_passthrough = None data = [{} for _ in range(self.num_workers)] - assert self._non_tensor_keys if self._non_tensor_keys: for i, td in zip( workers_range, @@ -2012,7 +2013,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for i in workers_range: msg, non_tensor_td = self.parent_channels[i].recv() non_tensor_tds.append(non_tensor_td) - print("non_tensor_td", non_tensor_td) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps @@ -2718,9 +2718,6 @@ def _run_worker_pipe_direct( i += 1 # data, idx = data # data = data[idx] - print("device received", data["history"].device) - print('data["history"]', data["history"]) - print('data["history"][0]', data["history"][0]) next_td = env._step(data) if event is not None: event.record() @@ -2733,7 +2730,6 @@ def _run_worker_pipe_direct( ) except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err - print(f"next_td in worker {pid} and consolidate {consolidate}", next_td) child_pipe.send(next_td) del next_td diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index f674aee095e..6fd17334654 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1979,7 +1979,6 @@ def select_and_clone(name, x, y): result = tensordict._fast_apply( select_and_clone, next_tensordict, - # device=next_tensordict.device, default=None, filter_empty=True, is_leaf=_is_leaf_nontensor, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0bf924f9512..adf741839bc 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -10592,7 +10592,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: cond = self.cond(tensordict) # Write result in step tensordict["_step"] = tensordict.get("_step", True) & ~cond - if not tensordict["_step"].shape == tensordict.batch_size: + if tensordict["_step"].shape != tensordict.batch_size: tensordict["_step"] = tensordict["_step"].view(tensordict.batch_size) return tensordict diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index be53bd311d9..176ca5600a7 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -757,8 +757,6 @@ def check_env_specs( auto_reset=tensordict is None, break_when_any_done=break_when_any_done, ) - print(real_tensordict) - print(fake_tensordict) if return_contiguous: fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)