diff --git a/test/test_env.py b/test/test_env.py index 8925e252e99..54b3b484083 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2551,8 +2551,11 @@ def make_env(seed, device=device): p_env = ParallelEnv( 2, [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)] ) - r_parallel = p_env.rollout(10, policy) - assert not r_parallel.exclude("action").requires_grad + try: + r_parallel = p_env.rollout(10, policy) + assert not r_parallel.exclude("action").requires_grad + finally: + p_env.close() if __name__ == "__main__":