Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 26, 2025
1 parent 22f7687 commit 77d7edd
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 27 deletions.
3 changes: 0 additions & 3 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2455,9 +2455,6 @@ def _step(
try:
history = torch.stack(list(history.unbind(0)) + [local_history])
except Exception:
print(history)
print(history.unbind(0))
print(local_history)
raise
assert isinstance(history, History)
next_tensordict["history"] = history
Expand Down
49 changes: 37 additions & 12 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from collections import defaultdict
from functools import partial
from sys import platform
from tokenize import maybe
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -4406,20 +4405,24 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
class TestEnvWithHistory:
@pytest.fixture(autouse=True, scope="class")
def set_capture(self):
with set_capture_non_tensor_stack(False):
with set_capture_non_tensor_stack(False), set_auto_unwrap_transformed_env(
False
):
yield
return

def _make_env(self, device):
return CountingEnv(device=device).append_transform(HistoryTransform())
def _make_env(self, device, max_steps=10):
return CountingEnv(device=device, max_steps=max_steps).append_transform(
HistoryTransform()
)

def _make_skipping_env(self, device):
env = self._make_env(device=device)
env = env.append_transform(StepCounter())
def _make_skipping_env(self, device, max_steps=10):
env = self._make_env(device=device, max_steps=max_steps)
# skip every 3 steps
env = env.append_transform(
ConditionalSkip(lambda td: td["step_count"] % 3 == 0)
ConditionalSkip(lambda td: ((td["step_count"] % 3) == 2))
)
env = TransformedEnv(env, StepCounter())
return env

@pytest.mark.parametrize("device", [None, "cpu"])
Expand Down Expand Up @@ -4482,17 +4485,39 @@ def test_env_history_base_collector(self, device_env, collector_cls):
env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5
)
for d in collector:
print(d)
for i in range(d.shape[0] - 1):
assert (
d[i + 1]["history"].content[0] == d[i]["next", "history"].content[0]
)

@pytest.mark.parametrize("device_env", [None, "cpu"])
@pytest.mark.parametrize("collector_cls", [SyncDataCollector])
def test_skipping_history_env_collector(self, device_env, collector_cls):
env = self._make_skipping_env(device_env)
env = self._make_skipping_env(device_env, max_steps=10)
collector = collector_cls(
env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5
env,
lambda td: td.update(env.full_action_spec.one()),
total_frames=35,
frames_per_batch=5,
)
length = None
count = 1
for d in collector:
print(d)
for k in range(1, 5):
if len(d[k]["history"].content) == 2:
count = 1
continue
if count % 3 == 2:
assert (
d[k]["next", "history"].content
== d[k - 1]["next", "history"].content
), (d["next", "history"].content, k, count)
else:
assert d[k]["next", "history"].content[-1] == str(
int(d[k - 1]["next", "history"].content[-1]) + 1
), (d["next", "history"].content, k, count)
count += 1
count += 1


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13496,7 +13496,7 @@ def check_non_tensor_match(self, td):

class ToString(Transform):
def _apply_transform(self, obs: torch.Tensor) -> None:
return NonTensorData(str(obs), device=obs.device)
return NonTensorData(str(obs), device=self.parent.device)

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
Expand Down
1 change: 0 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2525,7 +2525,6 @@ def __init__(
if isinstance(shape, int):
shape = _size([shape])

# _, device = _default_dtype_and_device(None, device)
domain = None
super().__init__(
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
Expand Down
8 changes: 2 additions & 6 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,10 +1142,12 @@ def _step(
out_td = self._envs[i]._step(_data_in)
next_td[i].update_(
out_td,
# _env_output_keys exclude non-tensor data
keys_to_update=list(self._env_output_keys),
non_blocking=self.non_blocking,
)
if out_tds is not None:
# we store the non-tensor data here
out_tds.append(out_td)

# We must pass a clone of the tensordict, as the values of this tensordict
Expand Down Expand Up @@ -1989,7 +1991,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
next_td_passthrough = None
data = [{} for _ in range(self.num_workers)]

assert self._non_tensor_keys
if self._non_tensor_keys:
for i, td in zip(
workers_range,
Expand All @@ -2012,7 +2013,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
for i in workers_range:
msg, non_tensor_td = self.parent_channels[i].recv()
non_tensor_tds.append(non_tensor_td)
print("non_tensor_td", non_tensor_td)

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
Expand Down Expand Up @@ -2718,9 +2718,6 @@ def _run_worker_pipe_direct(
i += 1
# data, idx = data
# data = data[idx]
print("device received", data["history"].device)
print('data["history"]', data["history"])
print('data["history"][0]', data["history"][0])
next_td = env._step(data)
if event is not None:
event.record()
Expand All @@ -2733,7 +2730,6 @@ def _run_worker_pipe_direct(
)
except Exception as err:
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
print(f"next_td in worker {pid} and consolidate {consolidate}", next_td)
child_pipe.send(next_td)

del next_td
Expand Down
1 change: 0 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1979,7 +1979,6 @@ def select_and_clone(name, x, y):
result = tensordict._fast_apply(
select_and_clone,
next_tensordict,
# device=next_tensordict.device,
default=None,
filter_empty=True,
is_leaf=_is_leaf_nontensor,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10592,7 +10592,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
cond = self.cond(tensordict)
# Write result in step
tensordict["_step"] = tensordict.get("_step", True) & ~cond
if not tensordict["_step"].shape == tensordict.batch_size:
if tensordict["_step"].shape != tensordict.batch_size:
tensordict["_step"] = tensordict["_step"].view(tensordict.batch_size)
return tensordict

Expand Down
2 changes: 0 additions & 2 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,6 @@ def check_env_specs(
auto_reset=tensordict is None,
break_when_any_done=break_when_any_done,
)
print(real_tensordict)
print(fake_tensordict)

if return_contiguous:
fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
Expand Down

0 comments on commit 77d7edd

Please sign in to comment.