
Commit 8453b61

handle task dependencies; Executor initialized with 1 worker will not automatically return a SequentialExecutor

ungarj committed Nov 29, 2023
1 parent 46928b9 commit 8453b61
Showing 10 changed files with 79 additions and 58 deletions.

mapchete/config/models.py (1 change: 0 additions & 1 deletion)

@@ -50,7 +50,6 @@ class ProcessConfig(BaseModel, arbitrary_types_allowed=True):

 class DaskSettings(BaseModel):
     process_graph: bool = True
-    propagate_results: bool = False
     max_submitted_tasks: int = 500
     chunksize: int = 100
     scheduler: Optional[str] = None
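
With propagate_results removed from DaskSettings, result propagation is now requested directly through the propagate_results argument of execute() (see the updated tests in test/test_processing_base.py below). A minimal sketch, assuming the import path implied by the file location:

from mapchete.config.models import DaskSettings

# propagate_results is no longer a DaskSettings field after this commit;
# whether task results are kept is passed to execute(..., propagate_results=True) instead.
settings = DaskSettings(process_graph=False, max_submitted_tasks=500)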

mapchete/executor/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -19,7 +19,7 @@ def __new__(cls, *args, concurrency=None, **kwargs) -> ExecutorBase:
         if concurrency == "dask":
             return DaskExecutor(*args, **kwargs)

-        elif concurrency is None or kwargs.get("max_workers") == 1:
+        elif concurrency is None:
             return SequentialExecutor(*args, **kwargs)

         elif concurrency in ["processes", "threads"]:
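
In caller terms, requesting a concurrency backend with max_workers=1 no longer falls back to a SequentialExecutor; only concurrency=None does. A rough usage sketch (the Executor import is taken from the test changes below; which executor class the "processes"/"threads" branch returns is not visible in this diff):

from mapchete.executor import Executor

# Only concurrency=None yields a SequentialExecutor now.
sequential = Executor(concurrency=None)

# Before this commit this call also returned a SequentialExecutor because of
# max_workers=1; it is now dispatched to the "processes" branch instead.
single_worker = Executor(concurrency="processes", max_workers=1)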

mapchete/formats/default/gtiff.py (2 changes: 0 additions & 2 deletions)

@@ -586,8 +586,6 @@ def write(self, process_tile, data):
                 from_bounds(
                     *out_tile.bounds,
                     transform=self.dst.transform,
-                    height=self.dst.height,
-                    width=self.dst.width,
                 )
                 .round_lengths(pixel_precision=0)
                 .round_offsets(pixel_precision=0)

mapchete/formats/default/raster_file.py (4 changes: 4 additions & 0 deletions)

@@ -277,6 +277,10 @@ def read(self, indexes=None, resampling="nearest", **kwargs):
         data : array
         """
         if self._memory_cache_active:
+            logger.debug(
+                "available preprocessing tasks results: %s",
+                self.preprocessing_tasks_results,
+            )
             self._in_memory_raster = (
                 self._in_memory_raster
                 or self.preprocessing_tasks_results.get(self.cache_task_key)

mapchete/processing/base.py (5 changes: 1 addition & 4 deletions)

@@ -221,9 +221,6 @@ def execute(
         # we have to do this before it can be decided which type of processing can be applied
         tasks = self.tasks(zoom=zoom, tile=tile) if tasks is None else tasks

-        # TODO: check this again
-        propagate_results = propagate_results or dask_settings.propagate_results
-
         if len(tasks) == 0:
             return

@@ -394,7 +391,7 @@ def execute_tile(
         self.execute_preprocessing_tasks()
         try:
             return self.config.output.streamline_output(
-                TileTask(tile=process_tile, config=self.config).execute().output
+                TileTask(tile=process_tile, config=self.config).execute()
             )
         except MapcheteNodataTile:
             if raise_nodata:  # pragma: no cover

mapchete/processing/execute.py (4 changes: 3 additions & 1 deletion)

@@ -65,7 +65,7 @@ def batches(
         if batch.id != "preprocessing_tasks":
             for task in batch:
                 for id, result in preprocessing_tasks_results.items():
-                    task.set_preprocessing_task_result(id, result)
+                    task.add_dependency(id, result)

         for future in executor.as_completed(task_wrapper, batch):
             task_info = TaskInfo.from_future(future)
@@ -149,6 +149,7 @@ def execute_wrapper(
     if isinstance(output, TaskInfo):
         return output
     logger.debug((task.id, processor_message))
+
     if task.tile:
         return TaskInfo(
             id=default_tile_task_id(task.tile),
@@ -159,6 +160,7 @@
             write_msg=None,
             output=output if append_data else None,
         )
+
     else:
         return TaskInfo(
             id=task.id,

mapchete/processing/tasks.py (81 changes: 45 additions & 36 deletions)

@@ -6,7 +6,7 @@

 import numpy.ma as ma
 from dask.delayed import Delayed, DelayedLeaf, delayed
-from shapely.geometry import base, box, mapping
+from shapely.geometry import base, box, mapping, shape

 from mapchete.config import MapcheteConfig
 from mapchete.config.process_func import ProcessFunc
@@ -46,6 +46,7 @@ class Task(ABC):
     result_key_name: str
     geometry: Optional[Union[base.BaseGeometry, dict]] = None
     bounds: Optional[Bounds] = None
+    tile: Optional[BufferedTile] = None

     def __init__(
         self,
@@ -94,12 +95,8 @@ def add_dependencies(self, dependencies: Dict[str, TaskInfo]) -> None:
                 )
         self.dependencies.update(dependencies)

-    def execute(self, dependencies: Optional[Dict[str, TaskInfo]] = None) -> TaskInfo:
-        return TaskInfo(
-            id=self.id,
-            output=self.func(*self.fargs, **self.fkwargs),
-            processed=True,
-        )
+    def execute(self, dependencies: Optional[Dict[str, TaskInfo]] = None) -> Any:
+        return self.func(*self.fargs, **self.fkwargs)

     def has_geometry(self) -> bool:
         return self.geometry is not None
@@ -201,6 +198,8 @@ class TileTask(Task):
     config_baselevels: ZoomLevels
     process = Optional[ProcessFunc]
     config_dir = Optional[MPath]
+    tile: BufferedTile
+    _dependencies: dict

     def __init__(
         self,
@@ -237,12 +236,44 @@ def __init__(
         self.output_reader = (
             None if skip or not config.baselevels else config.output_reader
         )
+        self._dependencies = dict()
         super().__init__(id=self.id, geometry=tile.bbox)

     def __repr__(self):  # pragma: no cover
         return f"TileTask(id={self.id}, tile={self.tile}, bounds={self.bounds})"

-    def execute(self, dependencies: Optional[dict] = None) -> TaskInfo:
+    def add_dependency(self, task_key: str, result: Any, raise_error: bool = True):
+        """Append preprocessing task result to input."""
+        # if dependency has geo information, only add if it intersects with task!
+        try:
+            if not shape(result).intersects(shape(self)):
+                logger.debug("dependency does not intersect with task")
+                return
+        except AttributeError:
+            pass
+
+        if ":" in task_key:
+            inp_key, inp_task_key = task_key.split(":")[:2]
+        else:
+            raise KeyError(
+                "preprocessing task cannot be assigned to an input "
+                f"because of a malformed task key: {task_key}"
+            )
+
+        input_keys = {inp.input_key for inp in self.input.values()}
+        if inp_key not in input_keys:
+            if raise_error:
+                raise KeyError(
+                    f"task {inp_task_key} cannot be assigned to input with key {inp_key} "
+                    f"(available keys: {input_keys})"
+                )
+            else:
+                return
+
+        logger.debug("remember preprocessing task (%s) result for execution", task_key)
+        self._dependencies[task_key] = result
+
+    def execute(self, dependencies: Optional[dict] = None) -> Any:
         """
         Run the Mapchete process and return the result.
@@ -272,33 +303,7 @@ def execute(self, dependencies: Optional[dict] = None) -> TaskInfo:
                 raise MapcheteNodataTile
             elif process_output is None:
                 raise MapcheteProcessOutputError("process output is empty")
-        return TaskInfo(output=process_output, processed=True, tile=self.tile)
-
-    def set_preprocessing_task_result(
-        self, task_key: str, result: Any, raise_error: bool = False
-    ):
-        """Append preprocessing task result to input."""
-        # TODO: if result has geo information, only add if it intersects with task!
-        if ":" in task_key:
-            inp_key, inp_task_key = task_key.split(":")[:2]
-        else:
-            raise KeyError(
-                "preprocessing task cannot be assigned to an input "
-                f"because of a malformed task key: {task_key}"
-            )
-        input_keys = {inp.input_key for inp in self.input.values()}
-        for inp in self.input.values():
-            if inp_key == inp.input_key:
-                break
-        else:  # pragma: no cover
-            if raise_error:
-                raise KeyError(
-                    f"task {inp_task_key} cannot be assigned to input with key {inp_key} "
-                    f"(available keys: {input_keys})"
-                )
-            else:
-                return
-        inp.set_preprocessing_task_result(inp_task_key, result)
+        return process_output

     def _execute(self, dependencies: Optional[Dict[str, TaskInfo]] = None) -> Any:
         # If baselevel is active and zoom is outside of baselevel,
@@ -311,9 +316,13 @@ def _execute(self, dependencies: Optional[Dict[str, TaskInfo]] = None) -> Any:
         # Otherwise, execute from process file.
         try:
             with Timer() as duration:
+                if self._dependencies:
+                    dependencies.update(self._dependencies)
                 # append dependent preprocessing task results to input objects
                 if dependencies:
                     for task_key, task_result in dependencies.items():
+                        if isinstance(task_result, TaskInfo):
+                            task_result = task_result.output
                         if not task_key.startswith("tile_task"):
                             inp_key, task_key = task_key.split(":")[0], ":".join(
                                 task_key.split(":")[1:]
@@ -323,7 +332,7 @@
                             for inp in self.input.values():
                                 if inp.input_key == inp_key:
                                     inp.set_preprocessing_task_result(
-                                        task_key=task_key, result=task_result.output
+                                        task_key=task_key, result=task_result
                                     )
             # Actually run process.
             process_data = self.process(
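
The new TileTask.add_dependency expects dependency keys in the form "<input_key>:<preprocessing_task_key>" and, when the result exposes geo information, silently drops results that do not intersect the task. A small standalone sketch of the key convention (the "dem" and "hillshade" names are hypothetical):

# Hypothetical key: the result of a "hillshade" preprocessing task attached to
# the "dem" input of a mapchete process.
task_key = "dem:hillshade"
inp_key, inp_task_key = task_key.split(":")[:2]
assert (inp_key, inp_task_key) == ("dem", "hillshade")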

mapchete/tile.py (6 changes: 1 addition & 5 deletions)

@@ -367,11 +367,7 @@ def __ne__(self, other):
         return not self.__eq__(other)

     def __repr__(self):
-        return "BufferedTile(%s, tile_pyramid=%s, pixelbuffer=%s)" % (
-            self.id,
-            self.tp,
-            self.pixelbuffer,
-        )
+        return f"BufferedTile(zoom={self.zoom}, row={self.row}, col={self.col})"

     def __hash__(self):
         return hash(repr(self))

test/test_processing_base.py (31 changes: 23 additions & 8 deletions)

@@ -158,7 +158,7 @@ def test_baselevels(baselevels):
     """Baselevel interpolation."""
     with mapchete.open(baselevels.dict, mode="continue") as mp:
         # process data before getting baselevels
-        list(mp.execute())
+        list(mp.execute(concurrency=None))

     # get tile from lower zoom level
     for tile in mp.get_process_tiles(4):
@@ -679,6 +679,7 @@ def test_execute_continue(
         )
     else:
         execute_kwargs = dict(concurrency=concurrency)
+
     zoom = 3

     # run red_raster on tile 1, 0, 0
@@ -744,23 +745,37 @@ def test_execute_continue(
     )


-def test_execute_dask_without_results(baselevels, dask_executor):
+@pytest.mark.parametrize(
+    "concurrency,process_graph",
+    [
+        ("threads", None),
+        ("dask", True),
+        ("dask", False),
+        ("processes", None),
+        (None, None),
+    ],
+)
+def test_execute_without_results(baselevels, dask_executor, concurrency, process_graph):
+    if concurrency == "dask":
+        execute_kwargs = dict(
+            executor=dask_executor,
+            dask_settings=DaskSettings(process_graph=process_graph),
+        )
+    else:
+        execute_kwargs = dict(concurrency=concurrency)
+
     # make sure task results are appended to tasks
     with baselevels.mp() as mp:
         tile_tasks = 0
-        for task_info in mp.execute(
-            executor=dask_executor, dask_settings=DaskSettings(propagate_results=True)
-        ):
+        for task_info in mp.execute(**execute_kwargs, propagate_results=True):
             assert task_info.output is not None
             tile_tasks += 1
     assert tile_tasks == 6

     # make sure task results are None
     with baselevels.mp() as mp:
         tile_tasks = 0
-        for task_info in mp.execute(
-            concurrency="dask", dask_settings=DaskSettings(propagate_results=True)
-        ):
+        for task_info in mp.execute(**execute_kwargs, propagate_results=False):
             assert task_info.output is None
             tile_tasks += 1
     assert tile_tasks == 6

test/test_processing_tasks.py (1 change: 1 addition & 0 deletions)

@@ -4,6 +4,7 @@
 from shapely.geometry import shape

 from mapchete.errors import NoTaskGeometry
+from mapchete.executor import Executor
 from mapchete.processing.tasks import (
     Task,
     TaskBatch,
