From 6000b2ed66cfab2dd068084e77a3d4cf24ed21a0 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Mon, 25 Sep 2023 22:45:12 +1000 Subject: [PATCH] amend me: s3 chunked writer --- odc/geo/cog/_s3.py | 394 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_s3.py | 264 ++++++++++++++++++++++++++++++ 2 files changed, 658 insertions(+) create mode 100644 odc/geo/cog/_s3.py create mode 100644 tests/test_s3.py diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py new file mode 100644 index 00000000..c0b695be --- /dev/null +++ b/odc/geo/cog/_s3.py @@ -0,0 +1,394 @@ +""" +S3 utils for COG to S3. +""" +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + Union, +) + +from cachetools import cached +from dask import bag as dask_bag + +SomeData = Union[bytes, bytearray] +_5MB = 5 * (1 << 20) +PartsWriter = Callable[[int, bytearray], Dict[str, Any]] +if TYPE_CHECKING: + from botocore.credentials import ReadOnlyCredentials + + +class MPUChunk: + """ + chunk cache and writer + """ + + # pylint: disable=too-many-arguments + + __slots__ = ( + "nextPartId", + "write_credits", + "data", + "left_data", + "parts", + "observed", + "is_final", + ) + + def __init__( + self, + partId: int, + write_credits: int, + data: Optional[bytearray] = None, + left_data: Optional[bytearray] = None, + parts: Optional[List[Dict[str, Any]]] = None, + observed: Optional[List[Tuple[int, Any]]] = None, + is_final: bool = False, + ) -> None: + assert (partId == 0 and write_credits == 0) or ( + partId >= 1 and write_credits >= 0 and partId + write_credits <= 10_000 + ) + + self.nextPartId = partId + self.write_credits = write_credits + self.data = bytearray() if data is None else data + self.left_data = left_data + self.parts: List[Dict[str, Any]] = [] if parts is None else parts + self.observed: List[Tuple[int, Any]] = [] if observed is None else observed + self.is_final = is_final + + def __dask_tokenize__(self): + return ( + "MPUChunk", + self.nextPartId, + self.write_credits, + self.data, + self.left_data, + self.parts, + self.observed, + self.is_final, + ) + + def __repr__(self) -> str: + s = f"MPUChunk: {self.nextPartId}#{self.write_credits} cache: {len(self.data)}" + if self.observed: + s = f"{s} observed[{len(self.observed)}]" + if self.parts: + s = f"{s} parts: [{len(self.parts)}]" + if self.is_final: + s = f"{s} final" + return s + + def append(self, data: SomeData, chunk_id: Any = None): + sz = len(data) + self.observed.append((sz, chunk_id)) + self.data += data + + @property + def started_write(self) -> bool: + return len(self.parts) > 0 + + @staticmethod + def merge( + lhs: "MPUChunk", + rhs: "MPUChunk", + write: Optional[PartsWriter] = None, + ) -> "MPUChunk": + """ + If ``write=`` is not provided but flush is needed, RuntimeError will be raised. 
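+
+        ``lhs`` covers the bytes that come before ``rhs``.  If ``rhs`` has not
+        written any parts yet the two caches are simply concatenated.
+        Otherwise ``lhs`` flushes its cached bytes together with
+        ``rhs.left_data`` when it can (spare write credit, enough data or
+        final chunk, ``write=`` supplied), or else parks them in
+        ``.left_data`` for a later merge.
+
+        A minimal sketch (neither side has written parts yet, so no ``write=``
+        is needed and the caches just concatenate; payloads are illustrative)::
+
+            a, b = MPUChunk(1, 10), MPUChunk(11, 10)
+            a.append(b"aaaa")
+            b.append(b"bbbb")
+            ab = MPUChunk.merge(a, b)  # no upload happens here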
+ """ + if not rhs.started_write: + # no writes on the right + # Just append + assert rhs.left_data is None + assert len(rhs.parts) == 0 + + return MPUChunk( + lhs.nextPartId, + lhs.write_credits + rhs.write_credits, + lhs.data + rhs.data, + lhs.left_data, + lhs.parts, + lhs.observed + rhs.observed, + rhs.is_final, + ) + + # Flush `lhs.data + rhs.left_data` if we can + # or else move it into .left_data + lhs.final_flush(write, rhs.left_data) + + return MPUChunk( + rhs.nextPartId, + rhs.write_credits, + rhs.data, + lhs.left_data, + lhs.parts + rhs.parts, + lhs.observed + rhs.observed, + rhs.is_final, + ) + + def final_flush( + self, write: Optional[PartsWriter], extra_data: Optional[bytearray] = None + ) -> int: + data = self.data + if extra_data is not None: + data += extra_data + + def _flush_data(do_write: PartsWriter): + part = do_write(self.nextPartId, data) + self.parts.append(part) + self.data = bytearray() + self.nextPartId += 1 + self.write_credits -= 1 + return len(data) + + # Have enough write credits + # AND (have enough bytes OR it's the last chunk) + can_flush = self.write_credits > 0 and (self.is_final or len(data) >= _5MB) + + if self.started_write: + # When starting to write we ensure that there is always enough + # data and write credits left to flush the remainder + # + # User must have provided `write` function + assert can_flush is True + + if write is None: + raise RuntimeError("Flush required but no writer provided") + + return _flush_data(write) + + # Haven't started writing yet + # - Flush if possible and writer is provided + # - OR just move all the data to .left_data section + if can_flush and write is not None: + return _flush_data(write) + + if self.left_data is None: + self.left_data, self.data = data, bytearray() + else: + self.left_data, self.data = self.left_data + data, bytearray() + + return 0 + + def maybe_write(self, write: PartsWriter, spill_sz: int) -> int: + # if not last section keep 5MB and 1 partId around after flush + bytes_to_keep, parts_to_keep = (0, 0) if self.is_final else (_5MB, 1) + + if self.write_credits - 1 < parts_to_keep: + return 0 + + bytes_to_write = len(self.data) - bytes_to_keep + if bytes_to_write < spill_sz: + return 0 + + part = write(self.nextPartId, self.data[:bytes_to_write]) + + self.parts.append(part) + self.data = self.data[bytes_to_write:] + self.nextPartId += 1 + self.write_credits -= 1 + + return bytes_to_write + + @staticmethod + def gen_bunch( + partId: int, + n: int, + *, + writes_per_chunk: int = 1, + mark_final: bool = False, + ) -> Iterator["MPUChunk"]: + for idx in range(n): + is_final = mark_final and idx == (n - 1) + yield MPUChunk( + partId + idx * writes_per_chunk, writes_per_chunk, is_final=is_final + ) + + @staticmethod + def from_dask_bag( + partId: int, + chunks: dask_bag.Bag, + *, + writes_per_chunk: int = 1, + mark_final: bool = False, + write: Optional[PartsWriter] = None, + spill_sz: int = 0, + ) -> dask_bag.Item: + mpus = dask_bag.from_sequence( + MPUChunk.gen_bunch( + partId, + chunks.npartitions, + writes_per_chunk=writes_per_chunk, + mark_final=mark_final, + ), + npartitions=chunks.npartitions, + ) + + mpus = dask_bag.map_partitions( + _mpu_append_chunks_op, + mpus, + chunks, + token="mpu.append", + ) + + return mpus.fold(partial(_merge_and_spill_op, write=write, spill_sz=spill_sz)) + + +class MultiPartUpload: + """ + Dask to S3 dumper. 
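+
+    A rough end-to-end sketch (bucket/key are illustrative; the call order is
+    inferred from ``substream`` and the unit tests)::
+
+        mpu = MultiPartUpload("my-bucket", "some/key.tif")
+        # chunks: a dask.bag.Bag of (bytes, chunk_id) tuples
+        rv = mpu.substream(1, chunks, mark_final=True)
+        root = rv.compute()            # parts upload as data spills past spill_sz
+        if root.data:
+            root.final_flush(mpu)      # flush the tail that never reached spill_sz
+        etag = mpu.complete(root)      # stitch the uploaded parts together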
+ """ + + def __init__( + self, + bucket: str, + key: str, + *, + uploadId: str = "", + profile: Optional[str] = None, + endpoint_url: Optional[str] = None, + creds: Optional["ReadOnlyCredentials"] = None, + ): + self.bucket = bucket + self.key = key + self.uploadId = uploadId + self.profile = profile + self.endpoint_url = endpoint_url + self.creds = creds + + # @cached({}, key=lambda _self: (_self.profile, _self.endpoint_url, _self.creds)) + @cached({}) + def s3_client(self): + # pylint: disable=import-outside-toplevel,import-error + from botocore.session import Session + + sess = Session(profile=self.profile) + creds = self.creds + if creds is None: + return sess.create_client("s3", endpoint_url=self.endpoint_url) + return sess.create_client( + "s3", + endpoint_url=self.endpoint_url, + aws_access_key_id=creds.access_key, + aws_secret_access_key=creds.secret_key, + aws_session_token=creds.token, + ) + + def __call__(self, partId: int, data: bytearray) -> Dict[str, Any]: + s3 = self.s3_client() + assert self.uploadId != "" + rr = s3.upload_part( + PartNumber=partId, + Body=data, + Bucket=self.bucket, + Key=self.key, + UploadId=self.uploadId, + ) + etag = rr["ETag"] + return {"PartNumber": partId, "ETag": etag} + + def initiate(self) -> str: + assert self.uploadId == "" + s3 = self.s3_client() + + rr = s3.create_multipart_upload(Bucket=self.bucket, Key=self.key) + uploadId = rr["UploadId"] + self.uploadId = uploadId + return uploadId + + @property + def started(self) -> bool: + return len(self.uploadId) > 0 + + def cancel(self, other: str = ""): + uploadId = other if other else self.uploadId + if not uploadId: + return + + s3 = self.s3_client() + s3.abort_multipart_upload(Bucket=self.bucket, Key=self.key, UploadId=uploadId) + if uploadId == self.uploadId: + self.uploadId = "" + + def complete(self, root: MPUChunk) -> str: + s3 = self.s3_client() + rr = s3.complete_multipart_upload( + Bucket=self.bucket, + Key=self.key, + UploadId=self.uploadId, + MultipartUpload={"Parts": root.parts}, + ) + + return rr["ETag"] + + def list_active(self): + s3 = self.s3_client() + rr = s3.list_multipart_uploads(Bucket=self.bucket, Prefix=self.key) + return [x["UploadId"] for x in rr.get("Uploads", [])] + + def read(self, **kw): + s3 = self.s3_client() + return s3.get_object(Bucket=self.bucket, Key=self.key, **kw)["Body"].read() + + def __dask_tokenize__(self): + return ( + self.bucket, + self.key, + self.uploadId, + ) + + def substream( + self, + partId: int, + chunks: dask_bag.Bag, + *, + writes_per_chunk: int = 1, + mark_final: bool = False, + spill_sz: int = 20 * (1 << 20), + ) -> dask_bag.Item: + write: Optional[PartsWriter] = None + if spill_sz > 0: + if not self.started: + self.initiate() + write = self + return MPUChunk.from_dask_bag( + partId, + chunks, + writes_per_chunk=writes_per_chunk, + mark_final=mark_final, + spill_sz=spill_sz, + write=write, + ) + + +def _mpu_append_chunks_op( + mpus: Iterable[MPUChunk], chunks: Iterable[Tuple[bytes, Any]] +): + # expect 1 MPUChunk per partition + (mpu,) = mpus + for chunk in chunks: + data, chunk_id = chunk + mpu.append(data, chunk_id) + return [mpu] + + +def _merge_and_spill_op( + lhs: MPUChunk, + rhs: MPUChunk, + write: Optional[PartsWriter] = None, + spill_sz: int = 0, +) -> MPUChunk: + mm = MPUChunk.merge(lhs, rhs, write) + if write is None or spill_sz == 0: + return mm + + mm.maybe_write(write, spill_sz) + return mm diff --git a/tests/test_s3.py b/tests/test_s3.py new file mode 100644 index 00000000..3e0bcf22 --- /dev/null +++ b/tests/test_s3.py @@ -0,0 
+1,264 @@ +from hashlib import md5 +from typing import Any, Dict, List, Tuple + +import numpy as np + +from odc.geo.cog._s3 import MPUChunk, MultiPartUpload, PartsWriter + +FakeWriteResult = Tuple[int, bytearray, Dict[str, Any]] +# pylint: disable=unbalanced-tuple-unpacking + + +def etag(data): + return f'"{md5(data).hexdigest()}"' + + +def _mb(x: float) -> int: + return int(x * (1 << 20)) + + +def mk_fake_parts_writer(dst: List[FakeWriteResult]) -> PartsWriter: + def write(part: int, data: bytearray) -> Dict[str, Any]: + assert 1 <= part <= 10_000 + ww = {"PartNumber": part, "ETag": f'"{etag(data)}"'} + dst.append((part, data, ww)) + return ww + + return write + + +def _mk_fake_data(sz: int) -> bytes: + return np.random.bytes(sz) + + +def _split(data: bytes, offsets=List[int]) -> Tuple[bytes, ...]: + parts = [] + + for sz in offsets: + parts.append(data[:sz]) + data = data[sz:] + if data: + parts.append(data) + + return tuple(parts) + + +def _data(parts_written: List[FakeWriteResult]) -> bytes: + return b"".join(data for _, data, _ in sorted(parts_written, key=lambda x: x[0])) + + +def _parts(parts_written: List[FakeWriteResult]) -> List[Dict[str, Any]]: + return [part for _, _, part in sorted(parts_written, key=lambda x: x[0])] + + +def test_s3_mpu_merge_small() -> None: + # Test situation where parts get joined and get written eventually in one chunk + parts_written: List[FakeWriteResult] = [] + write = mk_fake_parts_writer(parts_written) + + data = _mk_fake_data(100) + da, db, dc = _split(data, [10, 20]) + + a = MPUChunk(1, 10) + b = MPUChunk(12, 2) + c = MPUChunk(14, 3, is_final=True) + assert a.is_final is False + assert b.is_final is False + assert c.is_final is True + assert a.started_write is False + + assert a.maybe_write(write, _mb(5)) == 0 + assert a.started_write is False + + b.append(db, "b1") + c.append(dc, "c1") + a.append(da, "a1") + + # b + c + bc = MPUChunk.merge(b, c, write) + assert bc.is_final is True + assert bc.nextPartId == b.nextPartId + assert bc.write_credits == (b.write_credits + c.write_credits) + assert bc.started_write is False + assert bc.left_data is None + assert bc.data == (db + dc) + + # a + (b + c) + abc = MPUChunk.merge(a, bc, write) + + assert abc.is_final is True + assert abc.started_write is False + assert abc.data == data + assert abc.left_data is None + assert abc.nextPartId == a.nextPartId + assert abc.write_credits == (a.write_credits + b.write_credits + c.write_credits) + assert len(abc.observed) == 3 + assert abc.observed == [(len(da), "a1"), (len(db), "b1"), (len(dc), "c1")] + + assert len(parts_written) == 0 + + assert abc.final_flush(write) == len(data) + assert len(abc.parts) == 1 + assert len(parts_written) == 1 + pid, _data, part = parts_written[0] + assert _data == data + assert pid == a.nextPartId + assert part["PartNumber"] == pid + assert abc.parts[0] == part + + +def test_mpu_multi_writes() -> None: + parts_written: List[FakeWriteResult] = [] + write = mk_fake_parts_writer(parts_written) + + data = _mk_fake_data(_mb(20)) + da1, db1, db2, dc1 = _split( + data, + [ + _mb(5.2), + _mb(7), + _mb(6), + ], + ) + a = MPUChunk(1, 10) + b = MPUChunk(10, 2) + c = MPUChunk(12, 3, is_final=True) + assert (a.is_final, b.is_final, c.is_final) == (False, False, True) + + b.append(db1, "b1") + # have enough data to write, but not enough to have left-over + assert b.maybe_write(write, _mb(6)) == 0 + b.append(db2, "b2") + assert b.observed == [(len(db1), "b1"), (len(db2), "b2")] + # have enough data to write and have left-over still + # (7 + 6) - 6 > 
5 + assert b.maybe_write(write, _mb(6)) > 0 + assert b.started_write is True + assert len(b.data) == _mb(5) + assert b.write_credits == 1 + assert b.nextPartId == 11 + + a.append(da1, "a1") + assert a.maybe_write(write, _mb(6)) == 0 + + c.append(dc1, "c1") + assert c.maybe_write(write, _mb(6)) == 0 + + # Should flush a + assert len(parts_written) == 1 + ab = MPUChunk.merge(a, b, write) + assert len(parts_written) == 2 + assert len(ab.parts) == 2 + assert ab.started_write is True + assert ab.is_final is False + assert ab.data == (db1 + db2)[-_mb(5) :] + assert ab.left_data is None + + abc = MPUChunk.merge(ab, c, write) + assert len(abc.parts) == 2 + assert abc.is_final is True + assert abc.observed == [ + (len(da1), "a1"), + (len(db1), "b1"), + (len(db2), "b2"), + (len(dc1), "c1"), + ] + + assert abc.final_flush(write, None) > 0 + assert len(abc.parts) == 3 + assert _parts(parts_written) == abc.parts + assert _data(parts_written) == data + + +def test_mpu_left_data() -> None: + parts_written: List[FakeWriteResult] = [] + write = mk_fake_parts_writer(parts_written) + + data = _mk_fake_data(_mb(3 + 2 + (6 + 5.2))) + da1, db1, dc1, dc2 = _split( + data, + [ + _mb(3), + _mb(2), + _mb(6), + ], + ) + a = MPUChunk(1, 100) + b = MPUChunk(100, 100) + c = MPUChunk(200, 100) + + c.append(dc1, "c1") + c.append(dc2, "c2") + assert c.maybe_write(write, _mb(6)) > 0 + assert c.started_write is True + assert len(c.parts) == 1 + assert c.left_data is None + + b.append(db1, "b1") + a.append(da1, "a1") + + bc = MPUChunk.merge(b, c, write) + assert bc.started_write is True + assert bc.left_data is not None + assert bc.nextPartId == 201 + assert bc.write_credits == 99 + + # Expect (a.data + bc.left_data) to be written to PartId=1 + abc = MPUChunk.merge(a, bc, write) + assert len(abc.parts) == 2 + assert abc.parts[0]["PartNumber"] == 1 + assert abc.nextPartId == 201 + assert abc.write_credits == 99 + assert abc.is_final is False + + assert abc.observed == [ + (len(da1), "a1"), + (len(db1), "b1"), + (len(dc1), "c1"), + (len(dc2), "c2"), + ] + + assert abc.final_flush(write) > 0 + assert abc.nextPartId == 202 + assert abc.write_credits == 98 + assert len(abc.parts) == 3 + assert _data(parts_written) == data + assert _parts(parts_written) == abc.parts + + +def test_mpu_misc() -> None: + parts_written: List[FakeWriteResult] = [] + write = mk_fake_parts_writer(parts_written) + + a = MPUChunk(1, 10) + b = MPUChunk(10, 1) + + data = _mk_fake_data(_mb(3 + (6 + 7))) + da1, db1, db2 = _split( + data, + [_mb(3), _mb(6), _mb(7)], + ) + b.append(db1, "b1") + b.append(db2, "b2") + + # not enough credits to write + assert b.maybe_write(write, _mb(5)) == 0 + + a.append(da1, "a1") + + ab = MPUChunk.merge(a, b, write) + assert ab.started_write is False + assert len(parts_written) == 0 + assert ab.nextPartId == 1 + assert ab.write_credits == 11 + + assert ab.final_flush(write) > 0 + assert len(ab.parts) == 1 + assert _data(parts_written) == data + assert _parts(parts_written) == ab.parts + + +def test_mpu(): + mpu = MultiPartUpload("bucket", "file.dat") + assert mpu.bucket == "bucket" + assert mpu.key == "file.dat"
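+
+
+def test_mpu_gen_bunch_sketch():
+    # Illustrative sketch: MPUChunk.gen_bunch gives every generated chunk
+    # `writes_per_chunk` part-id credits, advances the starting part id by the
+    # same stride, and (with mark_final=True) marks only the last chunk final.
+    chunks = list(MPUChunk.gen_bunch(3, 4, writes_per_chunk=10, mark_final=True))
+    assert [c.nextPartId for c in chunks] == [3, 13, 23, 33]
+    assert all(c.write_credits == 10 for c in chunks)
+    assert [c.is_final for c in chunks] == [False, False, False, True]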