Merge pull request #57 from LSSTDESC/dask
Enable dask in pipeline stages
joezuntz authored Aug 6, 2021
2 parents 8fa0121 + 60ab3ed commit 616d068
Showing 3 changed files with 133 additions and 7 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
@@ -30,7 +30,9 @@ jobs:

- name: Install
run: |
pip install --upgrade pytest pytest-mock codecov pytest-cov h5py
pip install --upgrade pytest pytest-mock codecov pytest-cov h5py pyyaml mockmpi
pip install dask[distributed]
pip install git+https://github.com/joezuntz/dask-mpi
pip install .[test,cwl,parsl]
- name: Tests
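A note on the new dependencies: the install step now adds pyyaml and mockmpi for the tests, plus the dask stack (dask[distributed] and the joezuntz fork of dask-mpi). Below is a minimal sanity check mirroring the import guard added to start_dask later in this commit; the helper name check_dask_stack is illustrative only and not part of the commit.

# Hypothetical helper, not part of this commit: verify that the dask stack
# installed by the CI job is importable before running dask-enabled stages.
def check_dask_stack():
    try:
        import dask
        import dask.distributed
        import dask_mpi
    except ImportError as err:
        raise RuntimeError(
            "dask support requires dask[distributed] and dask_mpi"
        ) from err
    return dask.__version__

if __name__ == "__main__":
    print("dask stack available, dask version", check_dask_stack())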
101 changes: 95 additions & 6 deletions ceci/stage.py
@@ -9,6 +9,7 @@

SERIAL = "serial"
MPI_PARALLEL = "mpi"
DASK_PARALLEL = "dask"

IN_PROGRESS_PREFIX = "inprogress_"

@@ -29,6 +30,7 @@ class PipelineStage:
"""

parallel = True
dask_parallel = False
config_options = {}
doc = ""

@@ -141,6 +143,12 @@ def __init__(self, args, comm=None):
self._size = 1
self._rank = 0

# If we are running under MPI but this subclass has enabled dask
# then we note that here. It stops various MPI-specific things happening
# later
if (self._parallel == MPI_PARALLEL) and self.dask_parallel:
self._parallel = DASK_PARALLEL

pipeline_stages = {}
incomplete_pipeline_stages = {}

@@ -372,7 +380,7 @@ def _parse_command_line(cls, cmd=None):
return args

@classmethod
def execute(cls, args):
def execute(cls, args, comm=None):
"""
Create an instance of this stage and run it
with the specified inputs and outputs.
@@ -386,11 +394,22 @@
"""
import pdb

stage = cls(args)
# Create the stage instance. Under dask this only really needs to
# happen for one process, but it is not a major overhead and it lets
# us reuse all of the other setup done above
stage = cls(args, comm=comm)

# This happens before dask is initialized
if stage.rank == 0:
print(f"Executing stage: {cls.name}")

if stage.is_dask():
is_client = stage.start_dask()
# Under dask the worker and scheduler processes do not
# execute the run method
if not is_client:
return

if args.cprofile:
profile = cProfile.Profile()
profile.enable()
@@ -412,6 +431,8 @@ def execute(cls, args):
finally:
if args.memmon:
monitor.stop()
if stage.is_dask():
stage.stop_dask()

# The default finalization renames any output files to their
# final location, but subclasses can override to do other things too
@@ -431,7 +452,11 @@ def execute(cls, args):
profile.dump_stats(args.cprofile)
profile.print_stats("cumtime")

if stage.rank == 0:
# Under dask the root process has gone off to become the scheduler,
# and process 1 becomes the client, which runs this code
# and gets to this point
if stage.rank == 0 or stage.is_dask():
print(f"Stage complete: {cls.name}")

def finalize(self):
@@ -441,8 +466,10 @@ def finalize(self):
self.comm.Barrier()

# Move files to their final path
# only the master process moves things
if self.rank == 0:
# Only the root process moves things, except that under dask it is
# process 1, which is the only process to reach this point
# (as noted above)
if (self.rank == 0) or self.is_dask():
for tag in self.output_tags():
# find the old and new names
temp_name = self.get_output(tag)
@@ -463,7 +490,6 @@ def finalize(self):
#############################################
# Parallelism-related methods and properties.
#############################################

@property
def rank(self):
"""The rank of this process under MPI (0 if not running under MPI)"""
@@ -494,6 +520,69 @@ def is_mpi(self):
"""
return self._parallel == MPI_PARALLEL

def is_dask(self):
"""
Returns True if the stage is being run in parallel with Dask.
"""
return self._parallel == DASK_PARALLEL

def start_dask(self):
"""
Prepare dask to run under MPI. After calling this method
only a single process, MPI rank 1, will continue to execute code
"""

# Using the programmatic dask configuration system
# does not seem to work, presumably because the loggers have already
# been created by the time we modify the config. Doing it with
# env vars does work. If the user has already set this variable
# we use their value; otherwise we only want error logs
import os

key = "DASK_LOGGING__DISTRIBUTED"
os.environ[key] = os.environ.get(key, "error")
try:
import dask
import dask_mpi
import dask.distributed
except ImportError:
print(
"ERROR: Using --mpi option on stages that use dask requires "
"dask[distributed] and dask_mpi to be installed."
)
raise

if self.size < 3:
raise ValueError(
"Dask requires at least three processes. One becomes a scheduler "
"process, one is a client that runs the code, and more are required "
"as worker processes."
)

# This requires my fork until/unless they merge the PR, to allow
# us to pass in these two arguments. In vanilla dask-mpi, sys.exit
# is called at the end of the event loop without returning to us.
# After this point only a single process, MPI rank 1,
# should continue to execute code. The others enter an event
# loop and return with is_client=False, which we return here
# to tell the caller that they should not run the rest of the stage.
is_client = dask_mpi.initialize(comm=self.comm, exit=False)

if is_client:
# Connect this local process to remote workers.
self.dask_client = dask.distributed.Client()
# I don't yet know how to view this dashboard link at NERSC
print(f"Started dask. Diagnostics at {self.dask_client.dashboard_link}")

return is_client

def stop_dask(self):
"""
End the dask event loop
"""
from dask_mpi import send_close_signal
send_close_signal()

def split_tasks_by_rank(self, tasks):
"""Iterate through a list of items, yielding ones this process is responsible for/
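Taken together, the new methods give a dask-enabled stage this lifecycle under MPI: dask_mpi.initialize() turns rank 0 into the scheduler and ranks 2 and above into workers, while rank 1 becomes the client and is the only process that runs the stage's run method and finalize. Below is a hedged sketch of a subclass using the client created in start_dask; the stage name and the computation are illustrative, not taken from this commit.

import dask.array as da
from ceci.stage import PipelineStage

class ExampleDaskStage(PipelineStage):
    """Hypothetical stage showing the dask_parallel switch."""
    name = "example_dask_stage"
    # When run under MPI, execute() calls start_dask() instead of the
    # plain MPI setup because of this flag.
    dask_parallel = True
    inputs = []
    outputs = []
    config_options = {}

    def run(self):
        # Only the dask client process (MPI rank 1) reaches this method;
        # the scheduler and worker ranks were absorbed by dask_mpi.initialize().
        # Work submitted through the client runs on the worker processes.
        futures = self.dask_client.map(lambda x: x * x, range(10))
        assert sum(self.dask_client.gather(futures)) == 285

        # dask collections use the same client implicitly
        assert da.arange(100).sum().compute() == 4950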
35 changes: 35 additions & 0 deletions tests/test_dask.py
@@ -0,0 +1,35 @@
from ceci.stage import PipelineStage
import mockmpi

def core_dask(comm):
class DaskTestStage(PipelineStage):
name = "dasktest"
dask_parallel = True
inputs = []
outputs = []
config_options = {}

def run(self):
import dask.array as da
arr = da.arange(100)
x = arr.sum()
x = x.compute()
assert x == 4950


args = DaskTestStage._parse_command_line(["dasktest", "--config", "tests/config.yml"])
DaskTestStage.execute(args, comm=comm)

# check that all procs get here
if comm is not None:
comm.Barrier()


def test_dask():
core_dask(None)
mockmpi.mock_mpiexec(3, core_dask)
mockmpi.mock_mpiexec(5, core_dask)


if __name__ == '__main__':
test_dask()
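The test above drives the stage through mockmpi with three and five mock processes, matching the minimum of three ranks enforced in start_dask. Here is a hedged sketch of the equivalent launch with a real communicator, assuming mpi4py is available and that passing comm to execute selects the MPI code path as it does in the test; ExampleDaskStage is the hypothetical stage sketched after the stage.py diff, and config.yml is a placeholder path. Under real MPI this script would be started with something like mpirun -n 3.

from mpi4py import MPI

if __name__ == "__main__":
    args = ExampleDaskStage._parse_command_line(
        ["example_dask_stage", "--config", "config.yml"]
    )
    # Rank 0 becomes the dask scheduler, rank 1 the client that runs run()
    # and finalize(), and any remaining ranks become dask workers.
    ExampleDaskStage.execute(args, comm=MPI.COMM_WORLD)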
