Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 26, 2025
1 parent 33fedaa commit 40a67d7
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3638,8 +3638,11 @@ def test_serial(self, bwad, use_buffers):
def test_parallel(self, bwad, use_buffers):
N = 50
env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
r = env.rollout(N, break_when_any_done=bwad)
assert r.get("non_tensor").tolist() == [list(range(N))] * 2
try:
r = env.rollout(N, break_when_any_done=bwad)
assert r.get("non_tensor").tolist() == [list(range(N))] * 2
finally:
env.close(raise_if_closed=False)

class AddString(Transform):
def __init__(self):
Expand Down Expand Up @@ -3671,19 +3674,22 @@ def test_partial_reset(self, batched):
env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
else:
env = SerialEnv(2, [env0, env1])
s = env.reset()
i = 0
for i in range(10): # noqa: B007
s, s_ = env.step_and_maybe_reset(
s.set("action", torch.ones(2, 1, dtype=torch.int))
)
if s.get(("next", "done")).any():
break
s = s_
assert i == 5
assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
assert s_["string"] == ["0", "6"]
assert s["next", "string"] == ["6", "6"]
try:
s = env.reset()
i = 0
for i in range(10): # noqa: B007
s, s_ = env.step_and_maybe_reset(
s.set("action", torch.ones(2, 1, dtype=torch.int))
)
if s.get(("next", "done")).any():
break
s = s_
assert i == 5
assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
assert s_["string"] == ["0", "6"]
assert s["next", "string"] == ["6", "6"]
finally:
env.close(raise_if_closed=False)

@pytest.mark.skipif(not _has_transformers, reason="transformers required")
def test_str2str_env_tokenizer(self):
Expand Down

0 comments on commit 40a67d7

Please sign in to comment.