diff --git a/test/test_env.py b/test/test_env.py index ec44b0653e4..418b65580d3 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2539,7 +2539,9 @@ def make_env(seed, device=device): return env serial_env = SerialEnv( - 2, [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], device=device + 2, + [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], + device=device, ) r_serial = serial_env.rollout(10, policy) @@ -2549,7 +2551,9 @@ def make_env(seed, device=device): torch.testing.assert_close(g, g_serial) p_env = ParallelEnv( - 2, [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], device=device, + 2, + [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], + device=device, ) try: r_parallel = p_env.rollout(10, policy)