From c70c81af9a3ca75661942b923e6ac98ed6d8e2fc Mon Sep 17 00:00:00 2001 From: Joachim Ungar Date: Wed, 8 Nov 2023 08:25:57 +0100 Subject: [PATCH] make sure profiling information is passed on to MFuture --- mapchete/commands/_cp.py | 4 +-- mapchete/commands/_execute.py | 2 +- mapchete/commands/_rm.py | 2 +- mapchete/executor/dask.py | 57 +++++++++++++++++++++++++++++----- mapchete/executor/future.py | 25 +++++++++++---- mapchete/processing/compute.py | 56 +++++++++++++-------------------- mapchete/processing/job.py | 1 - mapchete/processing/types.py | 4 ++- test/test_commands.py | 23 ++++++++++++++ test/test_formats_geotiff.py | 12 +++---- test/test_mapchete.py | 4 +-- 11 files changed, 129 insertions(+), 61 deletions(-) diff --git a/mapchete/commands/_cp.py b/mapchete/commands/_cp.py index ffea7132..1d13ab8a 100644 --- a/mapchete/commands/_cp.py +++ b/mapchete/commands/_cp.py @@ -216,13 +216,13 @@ def _copy_tiles( # check which source tiles exist logger.debug("looking for existing source tiles...") src_tiles_exist = dict( - tiles_exist(config=src_mp.config, output_tiles=tiles, multi=workers) + tiles_exist(config=src_mp.config, output_tiles=tiles, workers=workers) ) # check which destination tiles exist logger.debug("looking for existing destination tiles...") dst_tiles_exist = dict( - tiles_exist(config=dst_mp.config, output_tiles=tiles, multi=workers) + tiles_exist(config=dst_mp.config, output_tiles=tiles, workers=workers) ) # copy diff --git a/mapchete/commands/_execute.py b/mapchete/commands/_execute.py index ac2089ab..2a614802 100644 --- a/mapchete/commands/_execute.py +++ b/mapchete/commands/_execute.py @@ -192,13 +192,13 @@ def _empty_callback(_): dask_client=dask_client, multiprocessing_start_method=multiprocessing_start_method, max_workers=workers, - profiling=profiling, ), as_iterator=as_iterator, preprocessing_tasks=preprocessing_tasks, tiles_tasks=tiles_tasks, process_area=mp.config.init_area, stac_item_path=stac_item_path, + profiling=profiling, ) # explicitly exit the mp object on failure except Exception as exc: # pragma: no cover diff --git a/mapchete/commands/_rm.py b/mapchete/commands/_rm.py index f33fbf76..4a772305 100644 --- a/mapchete/commands/_rm.py +++ b/mapchete/commands/_rm.py @@ -106,7 +106,7 @@ def _empty_callback(*_): # this is required to omit tiles touching the config area if mp.config.area_at_zoom(zoom).intersection(t.bbox).area ], - multi=multi, + workers=multi, ): if exists: tiles[zoom].append(tile) diff --git a/mapchete/executor/dask.py b/mapchete/executor/dask.py index 361d4f96..8a6c6453 100644 --- a/mapchete/executor/dask.py +++ b/mapchete/executor/dask.py @@ -1,11 +1,15 @@ import logging import os -from functools import cached_property, partial +from functools import cached_property from typing import Any, Iterator, List +from dask.distributed import Client, LocalCluster, as_completed, wait + from mapchete.errors import JobCancelledError from mapchete.executor.base import ExecutorBase from mapchete.executor.future import MFuture +from mapchete.executor.types import Result +from mapchete.timer import Timer logger = logging.getLogger(__name__) @@ -21,8 +25,7 @@ def __init__( max_workers=None, **kwargs, ): - from dask.distributed import Client, LocalCluster, as_completed - + self.cancel_signal = False self._executor_client = dask_client self._local_cluster = None if self._executor_client: # pragma: no cover @@ -47,14 +50,21 @@ def __init__( def map(self, func, iterable, fargs=None, fkwargs=None) -> List[Any]: fargs = fargs or [] fkwargs = fkwargs or {} + + def _extract_result(future): + result = future.result() + if isinstance(result, Result): + return result.output + return result + return [ - f.result() - for f in self._executor.map(partial(func, *fargs, **fkwargs), iterable) + _extract_result(f) + for f in self._executor.map( + self.func_partial(func, *fargs, **fkwargs), iterable + ) ] def _wait(self): - from dask.distributed import wait - wait(self.running_futures) def _as_completed(self, *args, **kwargs) -> Iterator[MFuture]: @@ -184,6 +194,39 @@ def as_completed( self._ac_iterator.clear() self._submitted = 0 + def compute_task_graph( + self, + dask_collection=None, + with_results=False, + raise_errors=False, + ) -> Iterator[MFuture]: + # send to scheduler + + with Timer() as t: + futures = self._executor.compute( + dask_collection, optimize_graph=True, traverse=True + ) + logger.debug("%s tasks sent to scheduler in %s", len(futures), t) + self._submitted += len(futures) + + logger.debug("wait for tasks to finish...") + for batch in as_completed( + futures, + with_results=with_results, + raise_errors=raise_errors, + loop=self._executor.loop, + ).batches(): + for item in batch: + self._submitted -= 1 + if with_results: + future, result = item + else: + future, result = item, None + if self.cancel_signal: # pragma: no cover + logger.debug("executor cancelled") + raise JobCancelledError() + yield self._finished_future(future, result, _dask=True) + def _submit_chunk(self, chunk=None, func=None, fargs=None, fkwargs=None): if chunk: logger.debug("submit chunk of %s items to cluster", len(chunk)) diff --git a/mapchete/executor/future.py b/mapchete/executor/future.py index 863c0086..05781911 100644 --- a/mapchete/executor/future.py +++ b/mapchete/executor/future.py @@ -87,6 +87,11 @@ def from_future( else: name = str(future) + if hasattr(future, "profiling"): + profiling = future.profiling + else: + profiling = {} + if lazy: # keep around Future for later and don't call Future.result() return MFuture( @@ -95,6 +100,7 @@ def from_future( cancelled=future.cancelled(), status=status, name=name, + profiling=profiling, ) else: # immediately fetch Future.result() or use provided result @@ -102,13 +108,20 @@ def from_future( result = result or future.result(timeout=timeout) exception = future.exception(timeout=timeout) except Exception as exc: - return MFuture(exception=exc, status=status, name=name) - - return MFuture(result=result, exception=exception, status=status, name=name) + return MFuture( + exception=exc, status=status, name=name, profiling=profiling + ) + return MFuture( + result=result, + exception=exception, + status=status, + name=name, + profiling=profiling, + ) @staticmethod - def from_result(result: Any) -> MFuture: - return MFuture(result=result) + def from_result(result: Any, profiling: Optional[dict] = None) -> MFuture: + return MFuture(result=result, profiling=profiling) @staticmethod def skip(skip_info: Optional[Any] = None, result: Optional[Any] = None) -> MFuture: @@ -127,9 +140,9 @@ def from_func( def from_func_partial(func: Callable, item: Any) -> MFuture: try: result = func(item) - return MFuture(result=result.output, profiling=result.profiling) except Exception as exc: # pragma: no cover return MFuture(exception=exc) + return MFuture(result=result.output, profiling=result.profiling) def result(self, timeout: int = FUTURE_TIMEOUT, **kwargs) -> Any: """Return task result.""" diff --git a/mapchete/processing/compute.py b/mapchete/processing/compute.py index 5f33170a..7fe26919 100644 --- a/mapchete/processing/compute.py +++ b/mapchete/processing/compute.py @@ -3,8 +3,6 @@ from contextlib import ExitStack from typing import Any, Iterator, Optional, Union -from distributed import as_completed - from mapchete.enums import Concurrency from mapchete.errors import MapcheteNodataTile from mapchete.executor import DaskExecutor, Executor, ExecutorBase @@ -24,6 +22,7 @@ logger = logging.getLogger(__name__) +# TODO: this function probably better goes to base def compute( process, zoom_levels: Optional[ZoomLevelsLike] = None, @@ -116,6 +115,7 @@ def compute( logger.info("computed %s tasks in %s", num_processed, duration) +# TODO: this function has a better place in the base module def task_batches( process, zoom=None, tile=None, skip_output_check=False, propagate_results=True ) -> Iterator[Union[TaskBatch, TileTaskBatch]]: @@ -207,40 +207,29 @@ def task_batches( def _compute_task_graph( - dask_collection=None, - executor=None, - with_results=False, + dask_collection, + executor: DaskExecutor, + with_results: bool = False, write_in_parent_process: bool = False, - raise_errors=False, + raise_errors: bool = False, output_writer: Optional[Any] = None, **kwargs, ) -> Iterator[MFuture]: - # send to scheduler - with Timer() as t: - futures = executor._executor.compute( - dask_collection, optimize_graph=True, traverse=True - ) - logger.debug("%s tasks sent to scheduler in %s", len(futures), t) - - logger.debug("wait for tasks to finish...") - for batch in as_completed( - futures, - with_results=with_results, - raise_errors=raise_errors, - loop=executor._executor.loop, - ).batches(): - for future in batch: - if write_in_parent_process: - yield MFuture.from_result( - result=_write( - process_info=future.result(), - output_writer=output_writer, - append_output=True, - ) - ) - else: - yield MFuture.from_future(future) - futures.remove(future) + # send task graph to executor and yield as ready + for future in executor.compute_task_graph( + dask_collection, with_results=with_results, raise_errors=raise_errors + ): + if write_in_parent_process: + yield MFuture.from_result( + result=_write( + process_info=future.result(), + output_writer=output_writer, + append_output=True, + ), + profiling=future.profiling, + ) + else: + yield MFuture.from_future(future) def _compute_tasks( @@ -508,7 +497,6 @@ def _run_multi_overviews( # here we store the parents of processed tiles so we can update overviews # also in "continue" mode in case there were updates at the baselevel overview_parents = set() - for i, zoom in enumerate(zoom_levels.descending()): logger.debug("sending tasks to executor %s...", executor) # get generator list of tiles, whether they are to be skipped and skip_info @@ -723,7 +711,7 @@ def _run_multi_no_overviews( # output already has been written, so just use task process info else: process_info = future.result() - yield MFuture.from_result(result=process_info) + yield MFuture.from_result(result=process_info, profiling=future.profiling) ############################### diff --git a/mapchete/processing/job.py b/mapchete/processing/job.py index 9bcbb0fa..fadc262a 100644 --- a/mapchete/processing/job.py +++ b/mapchete/processing/job.py @@ -57,7 +57,6 @@ def __init__( self._process_area = process_area self.bounds = Bounds(*process_area.bounds) if process_area is not None else None self.stac_item_path = stac_item_path - if not as_iterator: self._results = list(self._run()) diff --git a/mapchete/processing/types.py b/mapchete/processing/types.py index 170edf2a..03ee4097 100644 --- a/mapchete/processing/types.py +++ b/mapchete/processing/types.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Optional from mapchete.tile import BufferedTile @@ -12,6 +12,7 @@ class TileProcessInfo: written: bool = False write_msg: Optional[str] = None data: Optional[Any] = None + profiling: dict = field(default_factory=dict) @dataclass @@ -22,3 +23,4 @@ class PreprocessingProcessInfo: written: bool = False write_msg: Optional[str] = None data: Optional[Any] = None + profiling: dict = field(default_factory=dict) diff --git a/test/test_commands.py b/test/test_commands.py index 5c06fbd4..8f58ec60 100644 --- a/test/test_commands.py +++ b/test/test_commands.py @@ -168,6 +168,29 @@ def test_execute_preprocessing_tasks(concurrency, preprocess_cache_raster_vector assert len(job) +@pytest.mark.parametrize( + "concurrency", + [ + "threads", + "dask", + "processes", + None, + ], +) +def test_execute_profiling(cleantopo_br_metatiling_1, concurrency): + zoom = 5 + job = execute( + cleantopo_br_metatiling_1.dict, + zoom=zoom, + as_iterator=True, + profiling=True, + concurrency=concurrency, + dask_compute_graph=False, + ) + for t in job: + assert t.profiling + + def test_convert_geodetic(cleantopo_br_tif, mp_tmpdir): """Automatic geodetic tile pyramid creation of raster files.""" job = convert(cleantopo_br_tif, mp_tmpdir, output_pyramid="geodetic") diff --git a/test/test_formats_geotiff.py b/test/test_formats_geotiff.py index 2c2db8d9..77f331cd 100644 --- a/test/test_formats_geotiff.py +++ b/test/test_formats_geotiff.py @@ -215,7 +215,7 @@ def test_output_single_gtiff(output_single_gtiff): # check if tile exists assert not mp.config.output.tiles_exist(process_tile) # write - mp.batch_process(multi=2) + mp.batch_process(workers=2) # check if tile exists assert mp.config.output.tiles_exist(process_tile) # read again, this time with data @@ -343,7 +343,7 @@ def test_output_single_gtiff_s3(output_single_gtiff_s3): # check if tile exists assert not mp.config.output.tiles_exist(process_tile) # write - mp.batch_process(multi=2) + mp.batch_process(workers=2) # check if tile exists assert mp.config.output.tiles_exist(process_tile) # read again, this time with data @@ -379,7 +379,7 @@ def test_output_single_gtiff_s3_tempfile(output_single_gtiff_s3): # check if tile exists assert not mp.config.output.tiles_exist(process_tile) # write - mp.batch_process(multi=2) + mp.batch_process(workers=2) # check if tile exists assert mp.config.output.tiles_exist(process_tile) # read again, this time with data @@ -409,7 +409,7 @@ def test_output_single_gtiff_cog(output_single_gtiff_cog): # check if tile exists assert not mp.config.output.tiles_exist(process_tile) # write - mp.batch_process(multi=2) + mp.batch_process(workers=2) # check if tile exists assert mp.config.output.tiles_exist(process_tile) # read again, this time with data @@ -445,7 +445,7 @@ def test_output_single_gtiff_cog_tempfile(output_single_gtiff_cog): # check if tile exists assert not mp.config.output.tiles_exist(process_tile) # write - mp.batch_process(multi=2) + mp.batch_process(workers=2) # check if tile exists assert mp.config.output.tiles_exist(process_tile) # read again, this time with data @@ -477,7 +477,7 @@ def test_output_single_gtiff_cog_s3(output_single_gtiff_cog_s3): # check if tile exists assert not mp.config.output.tiles_exist(process_tile) # write - mp.batch_process(multi=2) + mp.batch_process(workers=2) # check if tile exists assert mp.config.output.tiles_exist(process_tile) # read again, this time with data diff --git a/test/test_mapchete.py b/test/test_mapchete.py index a981e7dd..726b278b 100644 --- a/test/test_mapchete.py +++ b/test/test_mapchete.py @@ -513,9 +513,9 @@ def test_batch_process(cleantopo_tl): # process single tile mp.batch_process(tile=(2, 0, 0)) # process using multiprocessing - mp.batch_process(zoom=2, multi=2) + mp.batch_process(zoom=2, workers=2) # process without multiprocessing - mp.batch_process(zoom=2, multi=1) + mp.batch_process(zoom=2, workers=1) def test_skip_tiles(cleantopo_tl):