Skip to content

Commit

Permalink
add TaskBatches class to rule them all
Browse files Browse the repository at this point in the history
  • Loading branch information
ungarj committed Nov 16, 2023
1 parent 4ef22d5 commit a4774ee
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 47 deletions.
37 changes: 13 additions & 24 deletions mapchete/commands/_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def execute(
raw_conf=raw_conf(mapchete_config),
)

# automatically use dask Executor if dask scheduler is defined
if dask_scheduler or dask_client or concurrency == "dask":
concurrency = "dask"

# be careful opening mapchete not as context manager
with mapchete.open(
mapchete_config,
Expand All @@ -148,30 +152,14 @@ def execute(
all_observers.notify(status=Status.initializing)

# determine tasks
preprocessing_tasks = mp.config.preprocessing_tasks_count()
tiles_tasks = 1 if tile else mp.count_tiles()
total_tasks = preprocessing_tasks + tiles_tasks
all_observers.notify(
message=f"processing {preprocessing_tasks} preprocessing tasks and {tiles_tasks} tile tasks on {workers} worker(s)"
)
if total_tasks == 0:
tasks = mp._task_batches(zoom=zoom, tile=tile, concurrency=concurrency)

if len(tasks) == 0:
all_observers.notify(status=Status.done)
return

# automatically use dask Executor if dask scheduler is defined
if dask_scheduler or dask_client or concurrency == "dask":
concurrency = "dask"

# use sequential Executor if only one tile or only one worker is defined
elif total_tasks == 1 or workers == 1:
logger.debug(
"using sequential Executor because there is only one %s",
"task" if total_tasks == 1 else "worker",
)
concurrency = None
all_observers.notify(message=f"processing X tasks on {workers} worker(s)")

all_observers.notify(message="waiting for executor ...")

with executor_getter(
concurrency=concurrency,
dask_scheduler=dask_scheduler,
Expand All @@ -181,10 +169,10 @@ def execute(
) as executor:
all_observers.notify(
status=Status.running,
progress=Progress(total=total_tasks),
message=f"sending {total_tasks} tasks to {executor} ...",
progress=Progress(total=len(tasks)),
message=f"sending {len(tasks)} tasks to {executor} ...",
)

# TODO it would be nice to track the time it took sending tasks to the executor
try:
for ii, future in enumerate(
mp.compute(
Expand All @@ -197,6 +185,7 @@ def execute(
dask_propagate_results=dask_propagate_results,
profiling=profiling,
),
# executor.compute(tasks),
1,
):
result = TaskResult.from_future(future)
Expand All @@ -216,7 +205,7 @@ def execute(
all_observers.notify(message=msg)

all_observers.notify(
progress=Progress(total=total_tasks, current=ii),
progress=Progress(total=len(tasks), current=ii),
task_result=result,
)

Expand Down
85 changes: 81 additions & 4 deletions mapchete/processing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
_run_on_single_tile,
compute,
)
from mapchete.processing.tasks import TaskBatch, TileTask, TileTaskBatch
from mapchete.processing.tasks import TaskBatch, TaskBatches, TileTask, TileTaskBatch
from mapchete.stac import tile_direcotry_item_to_dict, update_tile_directory_stac_item
from mapchete.tile import BufferedTile, count_tiles
from mapchete.timer import Timer
Expand Down Expand Up @@ -179,6 +179,81 @@ def task_batches(
profilers=profilers,
)

def _task_batches(
    self,
    zoom: Optional[ZoomLevelsLike] = None,
    tile: Optional[TileLike] = None,
    concurrency: Concurrency = Concurrency.processes,
    no_task_graph: bool = False,
) -> TaskBatches:
    """
    Determine the work to do and return it as a TaskBatches collection.

    The batches are obtained from ``self.task_batches()`` and wrapped,
    together with the configured processing mode, in a ``TaskBatches``
    object.

    NOTE: ``concurrency`` and ``no_task_graph`` are accepted but not yet
    evaluated by this method.

    TODO: streaming?
    TODO: decide whether to build a task graph; the current idea is to do so
    only if concurrency is dask and ``no_task_graph`` is not set.
    """
    # hard-coded for now: no profilers, always check existing output
    batch_profilers = []
    skip_existing_check = False
    # fallback which is used if the output is not written in the parent process
    propagate_fallback = True

    propagate = self.config.output.write_in_parent_process or propagate_fallback

    # Design note: under certain conditions, dependencies between tasks do not
    # have to be preserved and graph processing can be skipped altogether:
    # - no baselevels and no preprocessing tasks
    # - only one zoom level and no preprocessing tasks
    # Which batches can be thrown away depends on the process mode, which is
    # why the mode is handed over to the collection.
    return TaskBatches(
        task_batches=self.task_batches(
            zoom=zoom,
            tile=tile,
            skip_output_check=skip_existing_check,
            propagate_results=propagate,
            profilers=batch_profilers,
        ),
        mode=self.config.mode,
    )

def execute_task_collection(
    self,
    executor: ExecutorBase,
    task_collection: TaskBatches,
):
    """
    Run all tasks of a task collection on the given executor.

    Not implemented yet: calling this method always raises
    NotImplementedError.
    """
    raise NotImplementedError()

def execute(
    self,
    executor: ExecutorBase,
    zoom: Optional[ZoomLevelsLike] = None,
    tile: Optional[TileLike] = None,
    concurrency: Concurrency = Concurrency.processes,
    no_task_graph: bool = False,
) -> None:
    """
    Determine all tasks and execute them on the given executor.

    The tasks are collected by ``self._task_batches()`` and then handed over
    to ``self.execute_task_collection()``.

    Parameters
    ----------
    executor : ExecutorBase
        Executor the tasks are sent to.
    zoom : ZoomLevelsLike, optional
        Passed on to ``_task_batches()``.
    tile : TileLike, optional
        Passed on to ``_task_batches()``.
    concurrency : Concurrency
        Passed on to ``_task_batches()``.
    no_task_graph : bool
        Passed on to ``_task_batches()``.

    Raises
    ------
    NotImplementedError
        As long as ``execute_task_collection()`` is not implemented.
    """
    # _task_batches() (not task_batches()) is the method which accepts the
    # concurrency and no_task_graph arguments and returns a TaskBatches
    # object as required by execute_task_collection()
    task_collection = self._task_batches(
        zoom=zoom,
        tile=tile,
        concurrency=concurrency,
        no_task_graph=no_task_graph,
    )
    # pass both arguments by keyword: executor is the first positional
    # parameter of execute_task_collection(), so passing the collection
    # positionally would collide with the executor keyword
    self.execute_task_collection(
        executor=executor,
        task_collection=task_collection,
    )

def compute(
self,
zoom: Optional[ZoomLevelsLike] = None,
Expand Down Expand Up @@ -476,7 +551,9 @@ def count_tiles(
logger.debug("tiles counted in %s", t)
return self._count_tiles_cache[(minzoom, maxzoom)]

def execute(self, process_tile: BufferedTile, raise_nodata: bool = False) -> Any:
def execute_tile(
self, process_tile: BufferedTile, raise_nodata: bool = False
) -> Any:
"""
Run Mapchete process on a tile.
Expand Down Expand Up @@ -694,7 +771,7 @@ def _process_and_overwrite_output(self, tile, process_tile):
if self.with_cache:
output = self._execute_using_cache(process_tile)
else:
output = self.execute(process_tile)
output = self.execute_tile(process_tile)

self.write(process_tile, output)
return self._extract(in_tile=process_tile, in_data=output, out_tile=tile)
Expand Down Expand Up @@ -723,7 +800,7 @@ def _execute_using_cache(self, process_tile):
return self.process_tile_cache[process_tile.id]
else:
try:
output = self.execute(process_tile)
output = self.execute_tile(process_tile)
self.process_tile_cache[process_tile.id] = output
if self.config.mode in ["continue", "overwrite"]:
self.write(process_tile, output)
Expand Down
69 changes: 66 additions & 3 deletions mapchete/processing/tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from typing import Callable, Iterator, List, Optional
from typing import Callable, Iterator, List, Optional, Union
from uuid import uuid4

from dask.delayed import Delayed, DelayedLeaf, delayed
from shapely.geometry import box, mapping

from mapchete.enums import ProcessingMode
from mapchete.errors import (
MapcheteNodataTile,
MapcheteProcessOutputError,
Expand Down Expand Up @@ -403,9 +405,70 @@ def _validate(self, items):
yield item.tile, item


def to_dask_collection(batches):
from dask.delayed import delayed
class TaskBatches:
    """
    Collection of all task batches of a process run.

    The batches are provided as an iterator and are only consumed
    ("materialized") once they are needed. On materialization they are split
    into preprocessing batches and tile batches.
    """

    # cached total number of tasks; None until __len__() was called once
    _len: Optional[int] = None
    # source iterator of the batches; consumed by materialize()
    _task_batches_generator: Iterator[Union[TaskBatch, TileTaskBatch]]
    # storage behind the preprocessing_batches and tile_batches properties
    _preprocessing_batches: List[TaskBatch]
    _tile_batches: List[TileTaskBatch]
    # True as soon as the source iterator was fully consumed
    materialized: bool = False

    def __init__(
        self,
        task_batches: Iterator[Union[TaskBatch, TileTaskBatch]],
        mode: ProcessingMode = ProcessingMode.CONTINUE,
    ):
        """
        Parameters
        ----------
        task_batches : iterator of TaskBatch or TileTaskBatch
            Batches to be collected; not consumed before materialization.
        mode : ProcessingMode
            Processing mode; stored on the instance but not evaluated yet.
        """
        self._task_batches_generator = task_batches
        self.mode = mode

    def __len__(self) -> int:
        """
        Return the total number of tasks over all batches.

        Triggers materialization on the first call; the result is cached.
        """
        if self._len is None:
            self.materialize()
            # NOTE(review): assumes every batch implements __len__ returning
            # its number of tasks -- confirm against TaskBatch/TileTaskBatch
            self._len = sum(len(batch) for batch in self._preprocessing_batches) + sum(
                len(batch) for batch in self._tile_batches
            )
        return self._len

    def materialize(self) -> None:
        """
        Consume the batches iterator and sort batches by type.

        Calling this method more than once has no further effect.
        """
        if self.materialized:
            return
        # Collect into local lists and append directly instead of going
        # through the properties: the properties call materialize() themselves
        # which would recurse as long as the materialized flag is not set.
        preprocessing_batches = []
        tile_batches = []
        for batch in self._task_batches_generator:
            if isinstance(batch, TileTaskBatch):
                tile_batches.append(batch)
            else:
                preprocessing_batches.append(batch)
        self._preprocessing_batches = preprocessing_batches
        self._tile_batches = tile_batches
        self.materialized = True

    @property
    def preprocessing_batches(self) -> List[TaskBatch]:
        """All batches which are not tile batches, in iteration order."""
        self.materialize()
        return self._preprocessing_batches

    @property
    def tile_batches(self) -> List[TileTaskBatch]:
        """All TileTaskBatch instances, in iteration order."""
        self.materialize()
        return self._tile_batches

    def clean_up(self) -> None:
        """Not implemented yet."""
        raise NotImplementedError

    def as_dask_graph(self) -> List[Union[Delayed, DelayedLeaf]]:
        """
        Convert all batches into a dask collection.

        Preprocessing batches are passed to to_dask_collection() before the
        tile batches.
        """
        return to_dask_collection(
            (
                batch
                for phase in (self.preprocessing_batches, self.tile_batches)
                for batch in phase
            )
        )

    def as_one_batch(self) -> List[Task]:
        """Not implemented yet."""
        raise NotImplementedError

    def as_layered_batches(self) -> List[List[Task]]:
        """Not implemented yet."""
        raise NotImplementedError


def to_dask_collection(
batches: Iterator[Union[TaskBatch, TileTaskBatch]],
) -> List[Union[Delayed, DelayedLeaf]]:
tasks = {}
previous_batch = None
for batch in batches:
Expand Down
3 changes: 2 additions & 1 deletion test/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,8 @@ def test_init_overrides_config(example_mapchete):
def test_custom_process(example_custom_process_mapchete):
with mapchete.open(example_custom_process_mapchete.dict) as mp:
assert (
mp.execute(example_custom_process_mapchete.first_process_tile()) is not None
mp.execute_tile(example_custom_process_mapchete.first_process_tile())
is not None
)


Expand Down
8 changes: 4 additions & 4 deletions test/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def test_execute(example_mapchete):
# in readonly mode
with mapchete.open(example_mapchete.dict, mode="readonly") as mp:
with pytest.raises(AttributeError):
mp.execute(next(mp.get_process_tiles()))
mp.execute_tile(next(mp.get_process_tiles()))
# wrong tile type
with mapchete.open(example_mapchete.dict) as mp:
with pytest.raises(TypeError):
mp.execute("invalid")
mp.execute_tile("invalid")


def test_read(example_mapchete):
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_process_exception(mp_tmpdir, cleantopo_br, process_error_py):
config.update(process=process_error_py)
with mapchete.open(config) as mp:
with pytest.raises(AssertionError):
mp.execute((5, 0, 0))
mp.execute_tile((5, 0, 0))


def test_output_error(mp_tmpdir, cleantopo_br, output_error_py):
Expand All @@ -287,7 +287,7 @@ def test_output_error(mp_tmpdir, cleantopo_br, output_error_py):
config.update(process=output_error_py)
with mapchete.open(config) as mp:
with pytest.raises(errors.MapcheteProcessOutputError):
mp.execute((5, 0, 0))
mp.execute_tile((5, 0, 0))


def _raise_error(i):
Expand Down
2 changes: 1 addition & 1 deletion test/test_formats_flatgeobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_s3_output_data_rw(flatgeobuf_s3):
# write empty
mp.write(tile, None)
# write data
raw_output = mp.execute(tile)
raw_output = mp.execute_tile(tile)
assert isinstance(raw_output, list)
assert len(raw_output)
mp.write(tile, raw_output)
Expand Down
2 changes: 1 addition & 1 deletion test/test_formats_geobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_s3_output_data_rw(geobuf_s3):
# write empty
mp.write(tile, None)
# write data
raw_output = mp.execute(tile)
raw_output = mp.execute_tile(tile)
mp.write(tile, raw_output)
# read data
read_output = mp.get_raw_output(tile)
Expand Down
2 changes: 1 addition & 1 deletion test/test_formats_geojson.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_s3_output_data_rw(geojson_s3):
# write empty
mp.write(tile, None)
# write data
raw_output = mp.execute(tile)
raw_output = mp.execute_tile(tile)
mp.write(tile, raw_output)
# read data
read_output = mp.get_raw_output(tile)
Expand Down
2 changes: 1 addition & 1 deletion test/test_formats_geotiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_write_geotiff_tags(mp_tmpdir, cleantopo_br, write_rasterfile_tags_py):
conf.update(process=write_rasterfile_tags_py)
with mapchete.open(conf) as mp:
for tile in mp.get_process_tiles():
data, tags = mp.execute(tile)
data, tags = mp.execute_tile(tile)
assert data.any()
assert isinstance(tags, dict)
mp.write(process_tile=tile, data=(data, tags))
Expand Down
2 changes: 1 addition & 1 deletion test/test_io_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def test_output_s3_single_gtiff_error(output_s3_single_gtiff_error):
# the process file will raise an exception on purpose
with pytest.raises(AssertionError):
with output_s3_single_gtiff_error.mp() as mp:
mp.execute(output_s3_single_gtiff_error.first_process_tile())
mp.execute_tile(output_s3_single_gtiff_error.first_process_tile())
# make sure no output has been written
assert not path_exists(mp.config.output.path)

Expand Down
Loading

0 comments on commit a4774ee

Please sign in to comment.