diff --git a/src/orion/executor/multiprocess_backend.py b/src/orion/executor/multiprocess_backend.py
index 03cb557e7..93046d017 100644
--- a/src/orion/executor/multiprocess_backend.py
+++ b/src/orion/executor/multiprocess_backend.py
@@ -83,10 +83,10 @@ def Process(*args, **kwds):
         if v.major == 3 and v.minor >= 8:
             args = args[1:]
 
-        if Pool.ALLOW_DAEMON:
-            return Process(*args, **kwds)
+        if not Pool.ALLOW_DAEMON:
+            return PyPool.Process(*args, **kwds)
 
-        return _Process(*args, **kwds)
+        return _Process(*args, **kwds, daemon=False)
 
     def shutdown(self):
         # NB: https://pytest-cov.readthedocs.io/en/latest/subprocess-support.html
@@ -167,13 +167,18 @@ def __init__(self, n_workers=-1, backend="multiprocess", **kwargs):
         if n_workers <= 0:
             n_workers = multiprocessing.cpu_count()
 
+        self.pool_config = {"n_workers": n_workers, "backend": backend}
         self.pool = PoolExecutor.BACKENDS.get(backend, ThreadPool)(n_workers)
 
     def __setstate__(self, state):
-        self.pool = state["pool"]
+        log.warning("Nesting multiprocess executor")
+        self.pool_config = state["pool_config"]
+        backend = self.pool_config.get("backend")
+        n_workers = self.pool_config.get("n_workers", -1)
+        self.pool = PoolExecutor.BACKENDS.get(backend, ThreadPool)(n_workers)
 
     def __getstate__(self):
-        return dict(pool=self.pool)
+        return {"pool_config": self.pool_config}
 
     def __enter__(self):
         return self
diff --git a/tests/unittests/client/test_experiment_client.py b/tests/unittests/client/test_experiment_client.py
index 61282ba7f..cfa4264e6 100644
--- a/tests/unittests/client/test_experiment_client.py
+++ b/tests/unittests/client/test_experiment_client.py
@@ -1000,7 +1000,7 @@ def main(*args, **kwargs):
 
 
 def test_run_experiment_twice():
-    """"""
+    """Makes sure the executor is not freed after workon"""
     with create_experiment(config, base_trial) as (cfg, experiment, client):
         client.workon(main, max_trials=10)
 
diff --git a/tests/unittests/executor/test_executor.py b/tests/unittests/executor/test_executor.py
index 4223d69e3..720fceb6d 100644
--- a/tests/unittests/executor/test_executor.py
+++ b/tests/unittests/executor/test_executor.py
@@ -1,3 +1,5 @@
+import multiprocessing
+import multiprocessing.process as proc
 import os
 import time
 
@@ -9,6 +11,14 @@
 from orion.executor.ray_backend import HAS_RAY, Ray
 from orion.executor.single_backend import SingleExecutor
 
+try:
+    import torch
+    from torchvision import datasets, transforms
+
+    HAS_PYTORCH = True
+except ImportError:
+    HAS_PYTORCH = False
+
 
 def multiprocess(n):
     return PoolExecutor(n, "multiprocess")
@@ -265,7 +275,7 @@ def nested(executor):
     return sum(f.get() for f in futures)
 
 
-@pytest.mark.parametrize("backend", [xfail_dask_if_not_installed(Dask), SingleExecutor])
+@pytest.mark.parametrize("backend", backends)
 def test_nested_submit(backend):
     with backend(5) as executor:
         futures = [executor.submit(nested, executor) for i in range(5)]
@@ -276,17 +286,36 @@
         assert r.value == 35
 
 
-@pytest.mark.parametrize("backend", [multiprocess, thread])
-def test_nested_submit_failure(backend):
+def inc(a):
+    return a + 1
+
+
+def nested_pool():
+    import multiprocessing.process as proc
+
+    assert not proc._current_process._config.get("daemon")
+
+    data = [1, 2, 3, 4, 5, 6]
+    with multiprocessing.Pool(5) as p:
+        result = p.map_async(inc, data)
+        result.wait()
+        data = result.get()
+
+    return sum(data)
+
+
+@pytest.mark.parametrize("backend", backends)
+def test_nested_submit_pool(backend):
+    if backend is Dask:
+        pytest.xfail("Dask does not support nesting")
+
     with backend(5) as executor:
+        futures = [executor.submit(nested_pool) for i in range(5)]
 
-        if backend == multiprocess:
-            exception = NotImplementedError
-        elif backend == thread:
-            exception = TypeError
+        results = executor.async_get(futures, timeout=2)
 
-        with pytest.raises(exception):
-            [executor.submit(nested, executor) for i in range(5)]
+        for r in results:
+            assert r.value == 27
 
 
 @pytest.mark.parametrize("executor", executors)
@@ -310,3 +339,40 @@ def test_executors_del_does_not_raise(backend):
         del executor.client
 
     del executor
+
+
+def pytorch_workon(pid):
+    assert not proc._current_process._config.get("daemon")
+
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+        ]
+    )
+
+    dataset = datasets.FakeData(128, transform=transform)
+
+    loader = torch.utils.data.DataLoader(dataset, num_workers=2, batch_size=64)
+
+    for i, _ in enumerate(loader):
+        pass
+
+    return i
+
+
+@pytest.mark.parametrize("backend", backends)
+def test_pytorch_dataloader(backend):
+    if backend is Dask:
+        pytest.xfail("Dask does not support nesting")
+
+    if not HAS_PYTORCH:
+        pytest.skip("PyTorch is not installed, skipping")
+        return
+
+    with backend(2) as executor:
+        futures = [executor.submit(pytorch_workon, i) for i in range(2)]
+
+        results = executor.async_get(futures, timeout=2)
+
+        for r in results:
+            assert r.value == 1
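Note on the fix: multiprocessing marks pool workers as daemonic by default, and daemonic processes are not allowed to spawn children. That is what previously broke nested pools and PyTorch DataLoader workers inside the multiprocess backend, and why the patch forces daemon=False and why the tests assert on _current_process._config.get("daemon"). A minimal standalone sketch of the restriction (independent of Orion; all names here are illustrative):

    import multiprocessing


    def grandchild():
        print("grandchild ran")


    def child():
        # If this process were started with daemon=True, the start() below
        # would raise: "daemonic processes are not allowed to have children".
        p = multiprocessing.Process(target=grandchild)
        p.start()
        p.join()


    if __name__ == "__main__":
        # daemon=False lets the worker spawn its own subprocesses,
        # mirroring what _Process(*args, **kwds, daemon=False) does above.
        worker = multiprocessing.Process(target=child, daemon=False)
        worker.start()
        worker.join()

The __getstate__/__setstate__ change follows the same logic: a live pool handle cannot be pickled into a subprocess, so the executor now serializes only its configuration and rebuilds a fresh pool on unpickling.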