Skip to content

Commit

Permalink
Refactor mpu remove S3 specifics
Browse files Browse the repository at this point in the history
- use protocol instead of simple callback
- writer comes with extra info about constraints
  • Loading branch information
Kirill888 committed Sep 28, 2023
1 parent c236c30 commit b8c35a9
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 83 deletions.
57 changes: 41 additions & 16 deletions odc/geo/cog/_mpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,37 @@
Multi-part upload as a graph
"""
from functools import partial
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, Iterator, List, Optional, Protocol, Tuple, Union

from dask import bag as dask_bag

SomeData = Union[bytes, bytearray]
_5MB = 5 * (1 << 20)
PartsWriter = Callable[[int, bytearray], Dict[str, Any]]


class PartsWriter(Protocol):
"""Protocol for labeled parts data writer."""

def __call__(self, part: int, data: SomeData) -> Dict[str, Any]:
...

def finalise(self, parts: List[Dict[str, Any]]) -> Any:
...

@property
def min_write_sz(self) -> int:
...

@property
def max_write_sz(self) -> int:
...

@property
def min_part(self) -> int:
...

@property
def max_part(self) -> int:
...


class MPUChunk:
Expand Down Expand Up @@ -38,10 +62,6 @@ def __init__(
observed: Optional[List[Tuple[int, Any]]] = None,
is_final: bool = False,
) -> None:
assert (partId == 0 and write_credits == 0) or (
partId >= 1 and write_credits >= 0 and partId + write_credits <= 10_000
)

self.nextPartId = partId
self.write_credits = write_credits
self.data = bytearray() if data is None else data
Expand Down Expand Up @@ -127,34 +147,37 @@ def final_flush(
if extra_data is not None:
data += extra_data

def _flush_data(do_write: PartsWriter):
part = do_write(self.nextPartId, data)
def _flush_data(pw: PartsWriter):
assert pw.min_part <= self.nextPartId <= pw.max_part

part = pw(self.nextPartId, data)
self.parts.append(part)
self.data = bytearray()
self.nextPartId += 1
self.write_credits -= 1
return len(data)

# Have enough write credits
# AND (have enough bytes OR it's the last chunk)
can_flush = self.write_credits > 0 and (self.is_final or len(data) >= _5MB)
def can_flush(pw: PartsWriter):
return self.write_credits > 0 and (
self.is_final or len(data) >= pw.min_write_sz
)

if self.started_write:
# When starting to write we ensure that there is always enough
# data and write credits left to flush the remainder
#
# User must have provided `write` function
assert can_flush is True

if write is None:
raise RuntimeError("Flush required but no writer provided")

assert can_flush(write)
return _flush_data(write)

# Haven't started writing yet
# - Flush if possible and writer is provided
# - OR just move all the data to .left_data section
if can_flush and write is not None:
if write is not None and can_flush(write):
return _flush_data(write)

if self.left_data is None:
Expand All @@ -165,8 +188,10 @@ def _flush_data(do_write: PartsWriter):
return 0

def maybe_write(self, write: PartsWriter, spill_sz: int) -> int:
# if not last section keep 5MB and 1 partId around after flush
bytes_to_keep, parts_to_keep = (0, 0) if self.is_final else (_5MB, 1)
# if not last section keep 'min_write_sz' and 1 partId around after flush
bytes_to_keep, parts_to_keep = (
(0, 0) if self.is_final else (write.min_write_sz, 1)
)

if self.write_credits - 1 < parts_to_keep:
return 0
Expand Down
63 changes: 41 additions & 22 deletions odc/geo/cog/_s3.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""
S3 utils for COG to S3.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from cachetools import cached
from dask import bag as dask_bag

from ._mpu import MPUChunk, PartsWriter
from ._mpu import MPUChunk, PartsWriter, SomeData

if TYPE_CHECKING:
from botocore.credentials import ReadOnlyCredentials
Expand Down Expand Up @@ -52,27 +52,41 @@ def s3_client(self):
aws_session_token=creds.token,
)

def __call__(self, partId: int, data: bytearray) -> Dict[str, Any]:
def initiate(self) -> str:
assert self.uploadId == ""
s3 = self.s3_client()

rr = s3.create_multipart_upload(Bucket=self.bucket, Key=self.key)
uploadId = rr["UploadId"]
self.uploadId = uploadId
return uploadId

def __call__(self, part: int, data: SomeData) -> Dict[str, Any]:
s3 = self.s3_client()
assert self.uploadId != ""
rr = s3.upload_part(
PartNumber=partId,
PartNumber=part,
Body=data,
Bucket=self.bucket,
Key=self.key,
UploadId=self.uploadId,
)
etag = rr["ETag"]
return {"PartNumber": partId, "ETag": etag}
return {"PartNumber": part, "ETag": etag}

def initiate(self) -> str:
assert self.uploadId == ""
def finalise(self, parts: List[Dict[str, Any]]) -> str:
s3 = self.s3_client()
rr = s3.complete_multipart_upload(
Bucket=self.bucket,
Key=self.key,
UploadId=self.uploadId,
MultipartUpload={"Parts": parts},
)

rr = s3.create_multipart_upload(Bucket=self.bucket, Key=self.key)
uploadId = rr["UploadId"]
self.uploadId = uploadId
return uploadId
return rr["ETag"]

def complete(self, root: MPUChunk) -> str:
return self.finalise(root.parts)

@property
def started(self) -> bool:
Expand All @@ -88,17 +102,6 @@ def cancel(self, other: str = ""):
if uploadId == self.uploadId:
self.uploadId = ""

def complete(self, root: MPUChunk) -> str:
s3 = self.s3_client()
rr = s3.complete_multipart_upload(
Bucket=self.bucket,
Key=self.key,
UploadId=self.uploadId,
MultipartUpload={"Parts": root.parts},
)

return rr["ETag"]

def list_active(self):
s3 = self.s3_client()
rr = s3.list_multipart_uploads(Bucket=self.bucket, Prefix=self.key)
Expand Down Expand Up @@ -137,3 +140,19 @@ def substream(
spill_sz=spill_sz,
write=write,
)

@property
def min_write_sz(self) -> int:
return 5 * (1 << 20)

@property
def max_write_sz(self) -> int:
return 5 * (1 << 30)

@property
def min_part(self) -> int:
return 1

@property
def max_part(self) -> int:
return 10_000
Loading

0 comments on commit b8c35a9

Please sign in to comment.