Allow explicit setting of multiprocessing context for SubprocEnvWorker (#1072)

Running multiple training runs in parallel (e.g. with joblib) fails on
macOS because the default multiprocessing start method changed from
`fork` to `spawn` (see
[here](https://stackoverflow.com/questions/65098398/why-using-fork-works-but-using-spawn-fails-in-python3-8-multiprocessing)
or
[here](https://www.reddit.com/r/learnpython/comments/g5372v/multiprocessing_with_fork_on_macos/)).
This PR adds the ability to explicitly set a multiprocessing context for
the `SubprocEnvWorker` (similar to Gymnasium's
[AsyncVectorEnv](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/vector/async_vector_env.py)).
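
As an illustration, a minimal sketch of how the new parameter can be used once this change is in place (the environment id `CartPole-v1` and the worker count are arbitrary placeholders, not part of this PR):

```python
import gymnasium as gym

from tianshou.env import SubprocVectorEnv

# Placeholder environment and worker count, for illustration only.
env_fns = [lambda: gym.make("CartPole-v1") for _ in range(4)]

# Pin the start method to "fork" instead of macOS's default "spawn",
# e.g. when the vectorized env is created inside another parallel
# framework such as joblib.
venv = SubprocVectorEnv(env_fns, context="fork")
venv.reset()
venv.close()
```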
---------

Co-authored-by: Maximilian Huettenrauch <[email protected]>
Co-authored-by: Michael Panchenko <[email protected]>
3 people authored Mar 14, 2024
1 parent 1714c7f commit e82379c
Showing 5 changed files with 123 additions and 34 deletions.
3 changes: 3 additions & 0 deletions docs/spelling_wordlist.txt
@@ -253,4 +253,7 @@ Dominik
Tsinghua
Tianshou
appliedAI
macOS
joblib
master
Panchenko
10 changes: 8 additions & 2 deletions examples/mujoco/mujoco_env.py
@@ -62,11 +62,17 @@ def restore(self, event: RestoreEvent, world: World):


class MujocoEnvFactory(EnvFactoryRegistered):
def __init__(self, task: str, seed: int, obs_norm=True) -> None:
def __init__(
self,
task: str,
seed: int,
obs_norm: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
) -> None:
super().__init__(
task=task,
seed=seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
venv_type=venv_type,
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
)
self.obs_norm = obs_norm
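
As a usage sketch for the factory change above (task name and seed are placeholder values, and the import path assumes the example module is run from the repository root):

```python
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.env import VectorEnvType

# Placeholder task and seed; the relevant part is the now-configurable venv_type.
env_factory = MujocoEnvFactory(
    task="HalfCheetah-v4",
    seed=42,
    obs_norm=True,
    venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT,
)
```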
58 changes: 48 additions & 10 deletions tianshou/env/venvs.py
@@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, Literal

import gymnasium as gym
import numpy as np
@@ -371,8 +371,13 @@ class DummyVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""

def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
super().__init__(env_fns, DummyEnvWorker, **kwargs)
def __init__(
self,
env_fns: Sequence[Callable[[], ENV_TYPE]],
wait_num: int | None = None,
timeout: float | None = None,
) -> None:
super().__init__(env_fns, DummyEnvWorker, wait_num, timeout)


class SubprocVectorEnv(BaseVectorEnv):
@@ -381,13 +386,36 @@ class SubprocVectorEnv(BaseVectorEnv):
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
Additional arguments are:
:param share_memory: whether to share memory between the main process and the worker processes, which allows
    shared buffers to be used for exchanging observations.
:param context: the multiprocessing context to use. The default context is usually fine, but both
    `spawn` and `fork` can have non-obvious side effects; see, for example,
    https://github.com/google-deepmind/mujoco/issues/742 or
    https://github.com/Farama-Foundation/Gymnasium/issues/222.
    Consider using 'fork' on macOS when combining this with additional parallelization, e.g. via joblib.
    Defaults to None, which uses the system's default context.
"""

def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
def __init__(
self,
env_fns: Sequence[Callable[[], ENV_TYPE]],
wait_num: int | None = None,
timeout: float | None = None,
share_memory: bool = False,
context: Literal["fork", "spawn"] | None = None,
) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False)
return SubprocEnvWorker(fn, share_memory=share_memory, context=context)

super().__init__(env_fns, worker_fn, **kwargs)
super().__init__(
env_fns,
worker_fn,
wait_num,
timeout,
)


class ShmemVectorEnv(BaseVectorEnv):
@@ -400,11 +428,16 @@ class ShmemVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""

def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
def __init__(
self,
env_fns: Sequence[Callable[[], ENV_TYPE]],
wait_num: int | None = None,
timeout: float | None = None,
) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True)

super().__init__(env_fns, worker_fn, **kwargs)
super().__init__(env_fns, worker_fn, wait_num, timeout)


class RayVectorEnv(BaseVectorEnv):
@@ -417,7 +450,12 @@ class RayVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""

def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
def __init__(
self,
env_fns: Sequence[Callable[[], ENV_TYPE]],
wait_num: int | None = None,
timeout: float | None = None,
) -> None:
try:
import ray
except ImportError as exception:
@@ -426,4 +464,4 @@ def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) ->
) from exception
if not ray.is_initialized():
ray.init()
super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), **kwargs)
super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), wait_num, timeout)
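
To make the relationship between the classes above concrete, a small sketch (environment id and worker count are illustrative): `ShmemVectorEnv` still uses shared-memory workers with the system's default start method, while `SubprocVectorEnv` can now combine shared memory with an explicit start method:

```python
import gymnasium as gym

from tianshou.env import ShmemVectorEnv, SubprocVectorEnv

env_fns = [lambda: gym.make("CartPole-v1") for _ in range(2)]  # illustrative

# Shared-memory workers with the system's default start method (unchanged behaviour).
venv_default = ShmemVectorEnv(env_fns)

# Shared-memory workers with an explicitly chosen start method (new in this PR);
# note that "fork" is not available on Windows.
venv_fork = SubprocVectorEnv(env_fns, share_memory=True, context="fork")
```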
51 changes: 38 additions & 13 deletions tianshou/env/worker/subproc.py
@@ -1,10 +1,11 @@
import ctypes
import multiprocessing
import time
from collections import OrderedDict
from collections.abc import Callable
from multiprocessing import Array, Pipe, connection
from multiprocessing.context import Process
from typing import Any
from multiprocessing import Pipe, connection
from multiprocessing.context import BaseContext
from typing import Any, Literal

import gymnasium as gym
import numpy as np
@@ -31,10 +32,26 @@


class ShArray:
"""Wrapper of multiprocessing Array."""
"""Wrapper of multiprocessing Array.
def __init__(self, dtype: np.generic, shape: tuple[int]) -> None:
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
Example usage:
::
import numpy as np
import multiprocessing as mp
from tianshou.env.worker.subproc import ShArray
ctx = mp.get_context('fork') # set an explicit context
arr = ShArray(np.dtype(np.float32), (2, 3), ctx)
arr.save(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
print(arr.get())
"""

def __init__(self, dtype: np.generic, shape: tuple[int], ctx: BaseContext | None) -> None:
if ctx is None:
ctx = multiprocessing.get_context()
self.arr = ctx.Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
self.dtype = dtype
self.shape = shape

@@ -49,14 +66,14 @@ def get(self) -> np.ndarray:
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore


def _setup_buf(space: gym.Space) -> dict | tuple | ShArray:
def _setup_buf(space: gym.Space, ctx: BaseContext) -> dict | tuple | ShArray:
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
return {k: _setup_buf(v) for k, v in space.spaces.items()}
return {k: _setup_buf(v, ctx) for k, v in space.spaces.items()}
if isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
return tuple([_setup_buf(t) for t in space.spaces])
return ShArray(space.dtype, space.shape) # type: ignore
return tuple([_setup_buf(t, ctx) for t in space.spaces])
return ShArray(space.dtype, space.shape, ctx) # type: ignore


def _worker(
@@ -125,23 +142,31 @@ def _encode_obs(
class SubprocEnvWorker(EnvWorker):
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""

def __init__(self, env_fn: Callable[[], gym.Env], share_memory: bool = False) -> None:
def __init__(
self,
env_fn: Callable[[], gym.Env],
share_memory: bool = False,
context: BaseContext | Literal["fork", "spawn"] | None = None,
) -> None:
self.parent_remote, self.child_remote = Pipe()
self.share_memory = share_memory
self.buffer: dict | tuple | ShArray | None = None
if not isinstance(context, BaseContext):
context = multiprocessing.get_context(context)
assert hasattr(context, "Process") # for mypy
if self.share_memory:
dummy = env_fn()
obs_space = dummy.observation_space
dummy.close()
del dummy
self.buffer = _setup_buf(obs_space)
self.buffer = _setup_buf(obs_space, context)
args = (
self.parent_remote,
self.child_remote,
CloudpickleWrapper(env_fn),
self.buffer,
)
self.process = Process(target=_worker, args=args, daemon=True)
self.process = context.Process(target=_worker, args=args, daemon=True)
self.process.start()
self.child_remote.close()
super().__init__(env_fn)
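
At the worker level, the `context` argument accepts either a start-method name or an already-constructed context object; a brief sketch of the latter (the environment id is a placeholder):

```python
import multiprocessing

import gymnasium as gym

from tianshou.env.worker.subproc import SubprocEnvWorker

# An explicit context object; passing the string "spawn"/"fork" or leaving
# context=None (system default) works as well.
ctx = multiprocessing.get_context("spawn")
worker = SubprocEnvWorker(lambda: gym.make("CartPole-v1"), share_memory=False, context=ctx)
```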
35 changes: 26 additions & 9 deletions tianshou/highlevel/env.py
@@ -12,7 +12,6 @@
BaseVectorEnv,
DummyVectorEnv,
RayVectorEnv,
ShmemVectorEnv,
SubprocVectorEnv,
)
from tianshou.highlevel.persistence import Persistence
@@ -69,17 +68,25 @@ class VectorEnvType(Enum):
"""Parallelization based on `subprocess`"""
SUBPROC_SHARED_MEM = "shmem"
"""Parallelization based on `subprocess` with shared memory"""
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
RAY = "ray"
"""Parallelization based on the `ray` library"""

def create_venv(self, factories: Sequence[Callable[[], gym.Env]]) -> BaseVectorEnv:
def create_venv(
self,
factories: Sequence[Callable[[], gym.Env]],
) -> BaseVectorEnv:
match self:
case VectorEnvType.DUMMY:
return DummyVectorEnv(factories)
case VectorEnvType.SUBPROC:
return SubprocVectorEnv(factories)
case VectorEnvType.SUBPROC_SHARED_MEM:
return ShmemVectorEnv(factories)
return SubprocVectorEnv(factories, share_memory=True)
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
return SubprocVectorEnv(factories, share_memory=True, context="fork")
case VectorEnvType.RAY:
return RayVectorEnv(factories)
case _:
@@ -121,10 +128,14 @@ def from_factory_and_type(
:param create_watch_env: whether to create an environment for watching the agent
:return: the instance
"""
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
train_envs = venv_type.create_venv(
[lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs,
)
test_envs = venv_type.create_venv(
[lambda: factory_fn(EnvMode.TEST)] * num_test_envs,
)
if create_watch_env:
watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)])
watch_env = VectorEnvType.DUMMY.create_venv([lambda: factory_fn(EnvMode.WATCH)])
else:
watch_env = None
env = factory_fn(EnvMode.TRAIN)
@@ -344,7 +355,9 @@ class EnvFactory(ToStringMixin, ABC):
"""Main interface for the creation of environments (in various forms)."""

def __init__(self, venv_type: VectorEnvType):
""":param venv_type: the type of vectorized environment to use"""
""":param venv_type: the type of vectorized environment to use for train and test environments.
Watch environments are always created as dummy environments.
"""
self.venv_type = venv_type

@abstractmethod
@@ -355,10 +368,14 @@ def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
"""Create vectorized environments.
:param num_envs: the number of environments
:param mode: the mode for which to create
:param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
:return: the vectorized environments
"""
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
if mode == EnvMode.WATCH:
return VectorEnvType.DUMMY.create_venv([lambda: self.create_env(mode)])
else:
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)

def create_envs(
self,
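
Finally, a sketch of how the high-level enum maps onto the new constructor arguments (environment id and worker count are placeholders):

```python
import gymnasium as gym

from tianshou.highlevel.env import VectorEnvType

factories = [lambda: gym.make("CartPole-v1") for _ in range(4)]  # placeholders

# Per the match statement above, this is equivalent to
# SubprocVectorEnv(factories, share_memory=True, context="fork").
venv = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT.create_venv(factories)
```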
