amend me: s3 chunked writer
Kirill888 committed Sep 28, 2023
1 parent ddfde94 commit 6000b2e
Showing 2 changed files with 658 additions and 0 deletions.
394 changes: 394 additions & 0 deletions odc/geo/cog/_s3.py
@@ -0,0 +1,394 @@
"""
S3 utils for COG to S3.
"""
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
)

from cachetools import cached
from dask import bag as dask_bag

SomeData = Union[bytes, bytearray]
# S3 multi-part uploads require every part except the last to be at least 5 MiB.
_5MB = 5 * (1 << 20)
# Uploads one part (part number, data) and returns its completed-part record.
PartsWriter = Callable[[int, bytearray], Dict[str, Any]]
if TYPE_CHECKING:
from botocore.credentials import ReadOnlyCredentials


class MPUChunk:
"""
chunk cache and writer
"""

# pylint: disable=too-many-arguments

__slots__ = (
"nextPartId",
"write_credits",
"data",
"left_data",
"parts",
"observed",
"is_final",
)

def __init__(
self,
partId: int,
write_credits: int,
data: Optional[bytearray] = None,
left_data: Optional[bytearray] = None,
parts: Optional[List[Dict[str, Any]]] = None,
observed: Optional[List[Tuple[int, Any]]] = None,
is_final: bool = False,
) -> None:
        # S3 allows at most 10,000 parts per upload; part numbers start at 1.
        assert (partId == 0 and write_credits == 0) or (
            partId >= 1 and write_credits >= 0 and partId + write_credits <= 10_000
        )

self.nextPartId = partId
self.write_credits = write_credits
self.data = bytearray() if data is None else data
self.left_data = left_data
self.parts: List[Dict[str, Any]] = [] if parts is None else parts
self.observed: List[Tuple[int, Any]] = [] if observed is None else observed
self.is_final = is_final

def __dask_tokenize__(self):
return (
"MPUChunk",
self.nextPartId,
self.write_credits,
self.data,
self.left_data,
self.parts,
self.observed,
self.is_final,
)

def __repr__(self) -> str:
s = f"MPUChunk: {self.nextPartId}#{self.write_credits} cache: {len(self.data)}"
if self.observed:
s = f"{s} observed[{len(self.observed)}]"
if self.parts:
s = f"{s} parts: [{len(self.parts)}]"
if self.is_final:
s = f"{s} final"
return s

def append(self, data: SomeData, chunk_id: Any = None):
sz = len(data)
self.observed.append((sz, chunk_id))
self.data += data

@property
def started_write(self) -> bool:
return len(self.parts) > 0

@staticmethod
def merge(
lhs: "MPUChunk",
rhs: "MPUChunk",
write: Optional[PartsWriter] = None,
) -> "MPUChunk":
"""
If ``write=`` is not provided but flush is needed, RuntimeError will be raised.
"""
if not rhs.started_write:
# no writes on the right
# Just append
assert rhs.left_data is None
assert len(rhs.parts) == 0
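            # Illustrative (hypothetical values): merging A(partId=1, credits=2,
            # 3 MiB cached) with B(partId=3, credits=2, 4 MiB cached) when
            # neither has written a part yet just concatenates them into one
            # chunk with partId=1, credits=4 and 7 MiB cached; nothing is
            # uploaded at this point.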

return MPUChunk(
lhs.nextPartId,
lhs.write_credits + rhs.write_credits,
lhs.data + rhs.data,
lhs.left_data,
lhs.parts,
lhs.observed + rhs.observed,
rhs.is_final,
)

# Flush `lhs.data + rhs.left_data` if we can
# or else move it into .left_data
lhs.final_flush(write, rhs.left_data)

return MPUChunk(
rhs.nextPartId,
rhs.write_credits,
rhs.data,
lhs.left_data,
lhs.parts + rhs.parts,
lhs.observed + rhs.observed,
rhs.is_final,
)

    def final_flush(
        self, write: Optional[PartsWriter], extra_data: Optional[bytearray] = None
    ) -> int:
        """
        Flush any remaining cached bytes (plus ``extra_data``) as one part, or
        move them into ``left_data`` when a part cannot be written yet;
        returns the number of bytes written.
        """
data = self.data
if extra_data is not None:
data += extra_data

def _flush_data(do_write: PartsWriter):
part = do_write(self.nextPartId, data)
self.parts.append(part)
self.data = bytearray()
self.nextPartId += 1
self.write_credits -= 1
return len(data)

# Have enough write credits
# AND (have enough bytes OR it's the last chunk)
can_flush = self.write_credits > 0 and (self.is_final or len(data) >= _5MB)

if self.started_write:
# When starting to write we ensure that there is always enough
# data and write credits left to flush the remainder
#
# User must have provided `write` function
assert can_flush is True

if write is None:
raise RuntimeError("Flush required but no writer provided")

return _flush_data(write)

# Haven't started writing yet
# - Flush if possible and writer is provided
# - OR just move all the data to .left_data section
if can_flush and write is not None:
return _flush_data(write)

if self.left_data is None:
self.left_data, self.data = data, bytearray()
else:
self.left_data, self.data = self.left_data + data, bytearray()

return 0

    def maybe_write(self, write: PartsWriter, spill_sz: int) -> int:
        """
        Upload cached bytes as one part when at least ``spill_sz`` bytes and a
        part number can be spared; returns the number of bytes written.
        """
        # unless this is the final section, keep 5MB and 1 partId in reserve after the flush
bytes_to_keep, parts_to_keep = (0, 0) if self.is_final else (_5MB, 1)
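        # Illustrative (hypothetical numbers): a non-final chunk with 30 MiB
        # cached, 3 write credits and spill_sz=20 MiB writes one 25 MiB part
        # and keeps the 5 MiB reserve cached for the final flush.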

if self.write_credits - 1 < parts_to_keep:
return 0

bytes_to_write = len(self.data) - bytes_to_keep
if bytes_to_write < spill_sz:
return 0

part = write(self.nextPartId, self.data[:bytes_to_write])

self.parts.append(part)
self.data = self.data[bytes_to_write:]
self.nextPartId += 1
self.write_credits -= 1

return bytes_to_write

@staticmethod
def gen_bunch(
partId: int,
n: int,
*,
writes_per_chunk: int = 1,
mark_final: bool = False,
) -> Iterator["MPUChunk"]:
for idx in range(n):
is_final = mark_final and idx == (n - 1)
yield MPUChunk(
partId + idx * writes_per_chunk, writes_per_chunk, is_final=is_final
)

@staticmethod
def from_dask_bag(
partId: int,
chunks: dask_bag.Bag,
*,
writes_per_chunk: int = 1,
mark_final: bool = False,
write: Optional[PartsWriter] = None,
spill_sz: int = 0,
) -> dask_bag.Item:
mpus = dask_bag.from_sequence(
MPUChunk.gen_bunch(
partId,
chunks.npartitions,
writes_per_chunk=writes_per_chunk,
mark_final=mark_final,
),
npartitions=chunks.npartitions,
)

mpus = dask_bag.map_partitions(
_mpu_append_chunks_op,
mpus,
chunks,
token="mpu.append",
)

return mpus.fold(partial(_merge_and_spill_op, write=write, spill_sz=spill_sz))


class MultiPartUpload:
"""
Dask to S3 dumper.
"""

def __init__(
self,
bucket: str,
key: str,
*,
uploadId: str = "",
profile: Optional[str] = None,
endpoint_url: Optional[str] = None,
creds: Optional["ReadOnlyCredentials"] = None,
):
self.bucket = bucket
self.key = key
self.uploadId = uploadId
self.profile = profile
self.endpoint_url = endpoint_url
self.creds = creds

# @cached({}, key=lambda _self: (_self.profile, _self.endpoint_url, _self.creds))
@cached({})
def s3_client(self):
# pylint: disable=import-outside-toplevel,import-error
from botocore.session import Session

sess = Session(profile=self.profile)
creds = self.creds
if creds is None:
return sess.create_client("s3", endpoint_url=self.endpoint_url)
return sess.create_client(
"s3",
endpoint_url=self.endpoint_url,
aws_access_key_id=creds.access_key,
aws_secret_access_key=creds.secret_key,
aws_session_token=creds.token,
)

def __call__(self, partId: int, data: bytearray) -> Dict[str, Any]:
s3 = self.s3_client()
assert self.uploadId != ""
rr = s3.upload_part(
PartNumber=partId,
Body=data,
Bucket=self.bucket,
Key=self.key,
UploadId=self.uploadId,
)
etag = rr["ETag"]
return {"PartNumber": partId, "ETag": etag}

def initiate(self) -> str:
assert self.uploadId == ""
s3 = self.s3_client()

rr = s3.create_multipart_upload(Bucket=self.bucket, Key=self.key)
uploadId = rr["UploadId"]
self.uploadId = uploadId
return uploadId

@property
def started(self) -> bool:
return len(self.uploadId) > 0

def cancel(self, other: str = ""):
uploadId = other if other else self.uploadId
if not uploadId:
return

s3 = self.s3_client()
s3.abort_multipart_upload(Bucket=self.bucket, Key=self.key, UploadId=uploadId)
if uploadId == self.uploadId:
self.uploadId = ""

def complete(self, root: MPUChunk) -> str:
s3 = self.s3_client()
rr = s3.complete_multipart_upload(
Bucket=self.bucket,
Key=self.key,
UploadId=self.uploadId,
MultipartUpload={"Parts": root.parts},
)

return rr["ETag"]

def list_active(self):
s3 = self.s3_client()
rr = s3.list_multipart_uploads(Bucket=self.bucket, Prefix=self.key)
return [x["UploadId"] for x in rr.get("Uploads", [])]

def read(self, **kw):
s3 = self.s3_client()
return s3.get_object(Bucket=self.bucket, Key=self.key, **kw)["Body"].read()

def __dask_tokenize__(self):
return (
self.bucket,
self.key,
self.uploadId,
)

def substream(
self,
partId: int,
chunks: dask_bag.Bag,
*,
writes_per_chunk: int = 1,
mark_final: bool = False,
spill_sz: int = 20 * (1 << 20),
) -> dask_bag.Item:
write: Optional[PartsWriter] = None
if spill_sz > 0:
if not self.started:
self.initiate()
write = self
return MPUChunk.from_dask_bag(
partId,
chunks,
writes_per_chunk=writes_per_chunk,
mark_final=mark_final,
spill_sz=spill_sz,
write=write,
)


def _mpu_append_chunks_op(
mpus: Iterable[MPUChunk], chunks: Iterable[Tuple[bytes, Any]]
):
# expect 1 MPUChunk per partition
(mpu,) = mpus
for chunk in chunks:
data, chunk_id = chunk
mpu.append(data, chunk_id)
return [mpu]


def _merge_and_spill_op(
lhs: MPUChunk,
rhs: MPUChunk,
write: Optional[PartsWriter] = None,
spill_sz: int = 0,
) -> MPUChunk:
mm = MPUChunk.merge(lhs, rhs, write)
if write is None or spill_sz == 0:
return mm

mm.maybe_write(write, spill_sz)
return mm
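
Usage sketch (illustrative only, not part of the commit): the bucket, key and
chunk contents below are placeholders, and AWS credentials are assumed to be
available in the environment. Byte chunks arrive as a dask bag of
``(bytes, chunk_id)`` tuples, are streamed through ``MultiPartUpload.substream``,
and the folded result is flushed and completed.

from dask import bag as dask_bag

from odc.geo.cog._s3 import MultiPartUpload

# hypothetical chunk source; real code would produce COG tile bytes here
chunks = dask_bag.from_sequence(
    [(b"x" * (1 << 20), idx) for idx in range(64)],
    npartitions=8,
)

mpu = MultiPartUpload("example-bucket", "example/output.tif")
rr = mpu.substream(partId=1, chunks=chunks, writes_per_chunk=100, mark_final=True)

root = rr.compute()        # MPUChunk holding any still-unwritten tail bytes
root.final_flush(mpu)      # upload the tail as the final part
etag = mpu.complete(root)  # stitch all parts into the final S3 object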