Skip to content

Commit

Permalink
Avoid aliasing in MPUChunk.merge
Browse files Browse the repository at this point in the history
Merging two chunks should not modify its inputs, nor return a
chunk that aliases members of the inputs.
  • Loading branch information
Kirill888 committed Nov 5, 2023
1 parent 48d6edb commit 48cfa71
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions odc/geo/cog/_mpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from __future__ import annotations

from copy import copy
from functools import partial
from uuid import uuid4
from typing import (
Expand Down Expand Up @@ -87,7 +88,7 @@ def __init__(
observed: Optional[List[Tuple[int, Any]]] = None,
is_final: bool = False,
lhs_keep: int = 0,
tk: str | None = None
tk: str | None = None,
) -> None:
if tk is None:
tk = uuid4().hex
Expand All @@ -104,6 +105,19 @@ def __init__(
# if supplying data must also supply observed
assert data is None or (observed is not None and len(observed) > 0)

def clone(self) -> "MPUChunk":
    """
    Return an independent copy of this chunk.

    The list-valued members (``data``, ``left_data``, ``parts``,
    ``observed``) are shallow-copied so that appending to the clone's
    containers does not mutate this chunk; scalar members are passed
    through unchanged.
    """
    duplicated = (
        copy(self.data),
        copy(self.left_data),
        copy(self.parts),
        copy(self.observed),
    )
    return MPUChunk(
        self.nextPartId,
        self.write_credits,
        *duplicated,
        is_final=self.is_final,
        lhs_keep=self.lhs_keep,
        tk=self.tk,
    )

def __dask_tokenize__(self):
return (
"MPUChunk",
Expand Down Expand Up @@ -155,8 +169,8 @@ def merge(
lhs.nextPartId,
lhs.write_credits + rhs.write_credits,
lhs.data + rhs.data,
lhs.left_data,
lhs.parts,
copy(lhs.left_data),
copy(lhs.parts),
lhs.observed + rhs.observed,
rhs.is_final,
lhs.lhs_keep,
Expand All @@ -165,12 +179,13 @@ def merge(

# Flush `lhs.data + rhs.left_data` if we can
# or else move it into .left_data
lhs = lhs.clone()
lhs.flush_rhs(write, rhs.left_data)

return MPUChunk(
rhs.nextPartId,
rhs.write_credits,
rhs.data,
copy(rhs.data),
lhs.left_data,
lhs.parts + rhs.parts,
lhs.observed + rhs.observed,
Expand Down Expand Up @@ -334,7 +349,9 @@ def from_dask_bag(
import dask.bag
from dask.base import tokenize

tk = tokenize(partId, chunks, writes_per_chunk, lhs_keep, write, spill_sz, split_every)
tk = tokenize(
partId, chunks, writes_per_chunk, lhs_keep, write, spill_sz, split_every
)

mpus = dask.bag.from_sequence(
MPUChunk.gen_bunch(
Expand Down

0 comments on commit 48cfa71

Please sign in to comment.