From 40a67d7ccb6a3813d262c1b27059c03db41a5b48 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 26 Feb 2025 14:20:53 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_env.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 69c63e567a6..de1dd751a67 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3638,8 +3638,11 @@ def test_serial(self, bwad, use_buffers): def test_parallel(self, bwad, use_buffers): N = 50 env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers) - r = env.rollout(N, break_when_any_done=bwad) - assert r.get("non_tensor").tolist() == [list(range(N))] * 2 + try: + r = env.rollout(N, break_when_any_done=bwad) + assert r.get("non_tensor").tolist() == [list(range(N))] * 2 + finally: + env.close(raise_if_closed=False) class AddString(Transform): def __init__(self): @@ -3671,19 +3674,22 @@ def test_partial_reset(self, batched): env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx) else: env = SerialEnv(2, [env0, env1]) - s = env.reset() - i = 0 - for i in range(10): # noqa: B007 - s, s_ = env.step_and_maybe_reset( - s.set("action", torch.ones(2, 1, dtype=torch.int)) - ) - if s.get(("next", "done")).any(): - break - s = s_ - assert i == 5 - assert (s["next", "done"] == torch.tensor([[True], [False]])).all() - assert s_["string"] == ["0", "6"] - assert s["next", "string"] == ["6", "6"] + try: + s = env.reset() + i = 0 + for i in range(10): # noqa: B007 + s, s_ = env.step_and_maybe_reset( + s.set("action", torch.ones(2, 1, dtype=torch.int)) + ) + if s.get(("next", "done")).any(): + break + s = s_ + assert i == 5 + assert (s["next", "done"] == torch.tensor([[True], [False]])).all() + assert s_["string"] == ["0", "6"] + assert s["next", "string"] == ["6", "6"] + finally: + env.close(raise_if_closed=False) @pytest.mark.skipif(not _has_transformers, reason="transformers required") def test_str2str_env_tokenizer(self):