Skip to content

Commit

Permalink
make sure profiling information is passed on to MFuture
Browse files Browse the repository at this point in the history
  • Loading branch information
ungarj committed Nov 8, 2023
1 parent d3a5b2c commit c70c81a
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 61 deletions.
4 changes: 2 additions & 2 deletions mapchete/commands/_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ def _copy_tiles(
# check which source tiles exist
logger.debug("looking for existing source tiles...")
src_tiles_exist = dict(
tiles_exist(config=src_mp.config, output_tiles=tiles, multi=workers)
tiles_exist(config=src_mp.config, output_tiles=tiles, workers=workers)
)

# check which destination tiles exist
logger.debug("looking for existing destination tiles...")
dst_tiles_exist = dict(
tiles_exist(config=dst_mp.config, output_tiles=tiles, multi=workers)
tiles_exist(config=dst_mp.config, output_tiles=tiles, workers=workers)
)

# copy
Expand Down
2 changes: 1 addition & 1 deletion mapchete/commands/_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ def _empty_callback(_):
dask_client=dask_client,
multiprocessing_start_method=multiprocessing_start_method,
max_workers=workers,
profiling=profiling,
),
as_iterator=as_iterator,
preprocessing_tasks=preprocessing_tasks,
tiles_tasks=tiles_tasks,
process_area=mp.config.init_area,
stac_item_path=stac_item_path,
profiling=profiling,
)
# explicitly exit the mp object on failure
except Exception as exc: # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion mapchete/commands/_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _empty_callback(*_):
# this is required to omit tiles touching the config area
if mp.config.area_at_zoom(zoom).intersection(t.bbox).area
],
multi=multi,
workers=multi,
):
if exists:
tiles[zoom].append(tile)
Expand Down
57 changes: 50 additions & 7 deletions mapchete/executor/dask.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
import os
from functools import cached_property, partial
from functools import cached_property
from typing import Any, Iterator, List

from dask.distributed import Client, LocalCluster, as_completed, wait

from mapchete.errors import JobCancelledError
from mapchete.executor.base import ExecutorBase
from mapchete.executor.future import MFuture
from mapchete.executor.types import Result
from mapchete.timer import Timer

logger = logging.getLogger(__name__)

Expand All @@ -21,8 +25,7 @@ def __init__(
max_workers=None,
**kwargs,
):
from dask.distributed import Client, LocalCluster, as_completed

self.cancel_signal = False
self._executor_client = dask_client
self._local_cluster = None
if self._executor_client: # pragma: no cover
Expand All @@ -47,14 +50,21 @@ def __init__(
def map(self, func, iterable, fargs=None, fkwargs=None) -> List[Any]:
fargs = fargs or []
fkwargs = fkwargs or {}

def _extract_result(future):
# Resolve the future and return a plain value: if the resolved object is a
# ``Result`` wrapper, only its ``output`` attribute is returned; any other
# value is passed through unchanged.
# NOTE(review): presumably the wrapped callable returns ``Result`` objects that
# also carry profiling data which map() intentionally drops — confirm.
result = future.result()
if isinstance(result, Result):
return result.output
return result

return [
f.result()
for f in self._executor.map(partial(func, *fargs, **fkwargs), iterable)
_extract_result(f)
for f in self._executor.map(
self.func_partial(func, *fargs, **fkwargs), iterable
)
]

def _wait(self):
from dask.distributed import wait

wait(self.running_futures)

def _as_completed(self, *args, **kwargs) -> Iterator[MFuture]:
Expand Down Expand Up @@ -184,6 +194,39 @@ def as_completed(
self._ac_iterator.clear()
self._submitted = 0

def compute_task_graph(
self,
dask_collection=None,
with_results=False,
raise_errors=False,
) -> Iterator[MFuture]:
# Submit a dask collection to the executor's client and yield one finished
# future per task as tasks complete.
#
# Parameters:
#   dask_collection: passed as-is to ``self._executor.compute``.
#   with_results: forwarded to ``as_completed``; when true, every batch item is
#     unpacked as a ``(future, result)`` pair, otherwise the item itself is
#     taken as the future and the result is ``None``.
#   raise_errors: forwarded to ``as_completed``.
#
# Yields whatever ``self._finished_future(future, result, _dask=True)`` returns.
# Raises ``JobCancelledError`` when ``self.cancel_signal`` is truthy.
#
# NOTE(review): indentation is not preserved in this view; the cancel check and
# the ``yield`` presumably sit inside the inner per-item loop — confirm against
# the repository before relying on this description.
# send to scheduler

# time only the submission call so the debug line below reports how long it took
with Timer() as t:
futures = self._executor.compute(
dask_collection, optimize_graph=True, traverse=True
)
logger.debug("%s tasks sent to scheduler in %s", len(futures), t)
# count the submitted tasks; the counter is decremented again per finished item
self._submitted += len(futures)

logger.debug("wait for tasks to finish...")
# iterate over completed futures in batches, using the client's event loop
for batch in as_completed(
futures,
with_results=with_results,
raise_errors=raise_errors,
loop=self._executor.loop,
).batches():
for item in batch:
self._submitted -= 1
# item shape depends on with_results (see parameter notes above)
if with_results:
future, result = item
else:
future, result = item, None
# abort the iteration as soon as a cancel has been signalled
if self.cancel_signal: # pragma: no cover
logger.debug("executor cancelled")
raise JobCancelledError()
yield self._finished_future(future, result, _dask=True)

def _submit_chunk(self, chunk=None, func=None, fargs=None, fkwargs=None):
if chunk:
logger.debug("submit chunk of %s items to cluster", len(chunk))
Expand Down
25 changes: 19 additions & 6 deletions mapchete/executor/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def from_future(
else:
name = str(future)

if hasattr(future, "profiling"):
profiling = future.profiling
else:
profiling = {}

if lazy:
# keep around Future for later and don't call Future.result()
return MFuture(
Expand All @@ -95,20 +100,28 @@ def from_future(
cancelled=future.cancelled(),
status=status,
name=name,
profiling=profiling,
)
else:
# immediately fetch Future.result() or use provided result
try:
result = result or future.result(timeout=timeout)
exception = future.exception(timeout=timeout)
except Exception as exc:
return MFuture(exception=exc, status=status, name=name)

return MFuture(result=result, exception=exception, status=status, name=name)
return MFuture(
exception=exc, status=status, name=name, profiling=profiling
)
return MFuture(
result=result,
exception=exception,
status=status,
name=name,
profiling=profiling,
)

@staticmethod
def from_result(result: Any) -> MFuture:
return MFuture(result=result)
def from_result(result: Any, profiling: Optional[dict] = None) -> MFuture:
return MFuture(result=result, profiling=profiling)

@staticmethod
def skip(skip_info: Optional[Any] = None, result: Optional[Any] = None) -> MFuture:
Expand All @@ -127,9 +140,9 @@ def from_func(
def from_func_partial(func: Callable, item: Any) -> MFuture:
try:
result = func(item)
return MFuture(result=result.output, profiling=result.profiling)
except Exception as exc: # pragma: no cover
return MFuture(exception=exc)
return MFuture(result=result.output, profiling=result.profiling)

def result(self, timeout: int = FUTURE_TIMEOUT, **kwargs) -> Any:
"""Return task result."""
Expand Down
56 changes: 22 additions & 34 deletions mapchete/processing/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from contextlib import ExitStack
from typing import Any, Iterator, Optional, Union

from distributed import as_completed

from mapchete.enums import Concurrency
from mapchete.errors import MapcheteNodataTile
from mapchete.executor import DaskExecutor, Executor, ExecutorBase
Expand All @@ -24,6 +22,7 @@
logger = logging.getLogger(__name__)


# TODO: this function probably belongs in the base module
def compute(
process,
zoom_levels: Optional[ZoomLevelsLike] = None,
Expand Down Expand Up @@ -116,6 +115,7 @@ def compute(
logger.info("computed %s tasks in %s", num_processed, duration)


# TODO: this function has a better place in the base module
def task_batches(
process, zoom=None, tile=None, skip_output_check=False, propagate_results=True
) -> Iterator[Union[TaskBatch, TileTaskBatch]]:
Expand Down Expand Up @@ -207,40 +207,29 @@ def task_batches(


def _compute_task_graph(
dask_collection=None,
executor=None,
with_results=False,
dask_collection,
executor: DaskExecutor,
with_results: bool = False,
write_in_parent_process: bool = False,
raise_errors=False,
raise_errors: bool = False,
output_writer: Optional[Any] = None,
**kwargs,
) -> Iterator[MFuture]:
# send to scheduler
with Timer() as t:
futures = executor._executor.compute(
dask_collection, optimize_graph=True, traverse=True
)
logger.debug("%s tasks sent to scheduler in %s", len(futures), t)

logger.debug("wait for tasks to finish...")
for batch in as_completed(
futures,
with_results=with_results,
raise_errors=raise_errors,
loop=executor._executor.loop,
).batches():
for future in batch:
if write_in_parent_process:
yield MFuture.from_result(
result=_write(
process_info=future.result(),
output_writer=output_writer,
append_output=True,
)
)
else:
yield MFuture.from_future(future)
futures.remove(future)
# send task graph to executor and yield as ready
for future in executor.compute_task_graph(
dask_collection, with_results=with_results, raise_errors=raise_errors
):
if write_in_parent_process:
yield MFuture.from_result(
result=_write(
process_info=future.result(),
output_writer=output_writer,
append_output=True,
),
profiling=future.profiling,
)
else:
yield MFuture.from_future(future)


def _compute_tasks(
Expand Down Expand Up @@ -508,7 +497,6 @@ def _run_multi_overviews(
# here we store the parents of processed tiles so we can update overviews
# also in "continue" mode in case there were updates at the baselevel
overview_parents = set()

for i, zoom in enumerate(zoom_levels.descending()):
logger.debug("sending tasks to executor %s...", executor)
# get generator list of tiles, whether they are to be skipped and skip_info
Expand Down Expand Up @@ -723,7 +711,7 @@ def _run_multi_no_overviews(
# output already has been written, so just use task process info
else:
process_info = future.result()
yield MFuture.from_result(result=process_info)
yield MFuture.from_result(result=process_info, profiling=future.profiling)


###############################
Expand Down
1 change: 0 additions & 1 deletion mapchete/processing/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(
self._process_area = process_area
self.bounds = Bounds(*process_area.bounds) if process_area is not None else None
self.stac_item_path = stac_item_path

if not as_iterator:
self._results = list(self._run())

Expand Down
4 changes: 3 additions & 1 deletion mapchete/processing/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Optional

from mapchete.tile import BufferedTile
Expand All @@ -12,6 +12,7 @@ class TileProcessInfo:
written: bool = False
write_msg: Optional[str] = None
data: Optional[Any] = None
profiling: dict = field(default_factory=dict)


@dataclass
Expand All @@ -22,3 +23,4 @@ class PreprocessingProcessInfo:
written: bool = False
write_msg: Optional[str] = None
data: Optional[Any] = None
profiling: dict = field(default_factory=dict)
23 changes: 23 additions & 0 deletions test/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,29 @@ def test_execute_preprocessing_tasks(concurrency, preprocess_cache_raster_vector
assert len(job)


@pytest.mark.parametrize(
"concurrency",
[
"threads",
"dask",
"processes",
None,
],
)
def test_execute_profiling(cleantopo_br_metatiling_1, concurrency):
# Run ``execute`` with ``profiling=True`` once per parametrized ``concurrency``
# value and check that every item yielded by the job has a truthy
# ``profiling`` attribute.
zoom = 5
# as_iterator=True so the job can be iterated below;
# dask_compute_graph=False is passed explicitly for every concurrency value
job = execute(
cleantopo_br_metatiling_1.dict,
zoom=zoom,
as_iterator=True,
profiling=True,
concurrency=concurrency,
dask_compute_graph=False,
)
for t in job:
# an empty or missing profiling payload fails the test
assert t.profiling


def test_convert_geodetic(cleantopo_br_tif, mp_tmpdir):
"""Automatic geodetic tile pyramid creation of raster files."""
job = convert(cleantopo_br_tif, mp_tmpdir, output_pyramid="geodetic")
Expand Down
Loading

0 comments on commit c70c81a

Please sign in to comment.