From b0fdbce021f834165c5ce15c2c9d336b97cd07fd Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 15 Feb 2024 17:34:16 +0000 Subject: [PATCH] amend --- torchrl/envs/batched_envs.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index ae313ce5f19..0fd3671d337 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -29,7 +29,7 @@ VERBOSE, ) from torchrl.data.tensor_specs import CompositeSpec -from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING +from torchrl.data.utils import contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import _EnvPostInit, EnvBase from torchrl.envs.env_creator import get_env_metadata @@ -308,6 +308,13 @@ def __init__( self.policy_proof = policy_proof self.num_workers = num_workers self.create_env_fn = create_env_fn + + from torchrl.envs.env_creator import EnvCreator + + for i, env_fun in enumerate(self.create_env_fn): + if not isinstance(env_fun, EnvCreator) and not isinstance(env_fun, EnvBase): + self.create_env_fn[i] = EnvCreator(env_fun) + self.create_env_kwargs = create_env_kwargs self.pin_memory = pin_memory if pin_memory: @@ -1050,7 +1057,6 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): """ def _start_workers(self) -> None: - from torchrl.envs.env_creator import EnvCreator if self.num_threads is None: self.num_threads = max( @@ -1089,8 +1095,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda): # No certainty which module multiprocessing_context is parent_pipe, child_pipe = ctx.Pipe() env_fun = self.create_env_fn[idx] - if not isinstance(env_fun, EnvCreator): - env_fun = CloudpickleWrapper(env_fun) kwargs[idx].update( { "parent_pipe": parent_pipe,