Skip to content

Commit

Permalink
Ensure unique keys for MPUChunk objects
Browse files Browse the repository at this point in the history
Add .tk to MPUChunk to distinguish empty chunks from different
sources.
  • Loading branch information
Kirill888 committed Nov 5, 2023
1 parent 2fa7f90 commit 48d6edb
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions odc/geo/cog/_mpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

from functools import partial
from uuid import uuid4
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -73,6 +74,7 @@ class MPUChunk:
"observed",
"is_final",
"lhs_keep",
"tk",
)

def __init__(
Expand All @@ -85,7 +87,11 @@ def __init__(
observed: Optional[List[Tuple[int, Any]]] = None,
is_final: bool = False,
lhs_keep: int = 0,
tk: str | None = None
) -> None:
if tk is None:
tk = uuid4().hex

self.nextPartId = partId
self.write_credits = write_credits
self.data = bytearray() if data is None else data
Expand All @@ -94,6 +100,7 @@ def __init__(
self.observed: List[Tuple[int, Any]] = [] if observed is None else observed
self.is_final = is_final
self.lhs_keep = lhs_keep
self.tk = tk
# if supplying data must also supply observed
assert data is None or (observed is not None and len(observed) > 0)

Expand All @@ -107,6 +114,7 @@ def __dask_tokenize__(self):
self.parts,
self.observed,
self.is_final,
self.tk,
)

def __repr__(self) -> str:
Expand Down Expand Up @@ -152,6 +160,7 @@ def merge(
lhs.observed + rhs.observed,
rhs.is_final,
lhs.lhs_keep,
lhs.tk,
)

# Flush `lhs.data + rhs.left_data` if we can
Expand All @@ -167,6 +176,7 @@ def merge(
lhs.observed + rhs.observed,
rhs.is_final,
lhs.lhs_keep,
lhs.tk,
)

def flush_rhs(
Expand Down Expand Up @@ -296,6 +306,7 @@ def gen_bunch(
writes_per_chunk: int = 1,
mark_final: bool = False,
lhs_keep: int = 0,
tk: str | None = None,
) -> Iterator["MPUChunk"]:
for idx in range(n):
is_final = mark_final and idx == (n - 1)
Expand All @@ -304,6 +315,7 @@ def gen_bunch(
writes_per_chunk,
is_final=is_final,
lhs_keep=lhs_keep,
tk=tk,
)

@staticmethod
Expand All @@ -320,6 +332,9 @@ def from_dask_bag(
) -> "dask.bag.Item":
# pylint: disable=import-outside-toplevel
import dask.bag
from dask.base import tokenize

tk = tokenize(partId, chunks, writes_per_chunk, lhs_keep, write, spill_sz, split_every)

mpus = dask.bag.from_sequence(
MPUChunk.gen_bunch(
Expand All @@ -328,6 +343,7 @@ def from_dask_bag(
writes_per_chunk=writes_per_chunk,
mark_final=mark_final,
lhs_keep=lhs_keep,
tk=tk,
),
npartitions=chunks.npartitions,
)
Expand Down

0 comments on commit 48d6edb

Please sign in to comment.