diff --git a/odc/geo/cog/_s3.py b/odc/geo/cog/_s3.py
new file mode 100644
index 00000000..b6fd7334
--- /dev/null
+++ b/odc/geo/cog/_s3.py
@@ -0,0 +1,242 @@
+"""
+Helpers for writing COG data to S3 via multi-part upload.
+"""
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+from cachetools import cached
+
+SomeData = Union[bytes, bytearray]
+_5MB = 5 * (1 << 20)
+PartsWriter = Callable[[int, bytearray], Dict[str, Any]]
+if TYPE_CHECKING:
+    from botocore.credentials import ReadOnlyCredentials
+
+
+class MPUChunk:
+    """
+    A single chunk of a multi-part upload: caches bytes and writes them out as parts.
+    """
+
+    # pylint: disable=too-many-arguments
+
+    __slots__ = (
+        "nextPartId",
+        "write_credits",
+        "data",
+        "left_data",
+        "parts",
+        "observed",
+        "is_final",
+    )
+
+    def __init__(
+        self,
+        partId: int,
+        write_credits: int,
+        data: Optional[bytearray] = None,
+        left_data: Optional[bytearray] = None,
+        parts: Optional[List[Dict[str, Any]]] = None,
+        observed: Optional[List[Tuple[int, Any]]] = None,
+        is_final: bool = False,
+    ) -> None:
+        assert partId >= 1
+        assert write_credits >= 0
+        assert partId + write_credits <= 10_000
+
+        self.nextPartId = partId
+        self.write_credits = write_credits
+        self.data = bytearray() if data is None else data
+        self.left_data = left_data
+        self.parts: List[Dict[str, Any]] = [] if parts is None else parts
+        self.observed: List[Tuple[int, Any]] = [] if observed is None else observed
+        self.is_final = is_final
+
+    def __repr__(self) -> str:
+        s = f"MPUChunk: {self.nextPartId}#{self.write_credits} cache: {len(self.data)}"
+        if self.observed:
+            s = f"{s} observed[{len(self.observed)}]"
+        if self.parts:
+            s = f"{s} parts: [{len(self.parts)}]"
+        if self.is_final:
+            s = f"{s} final"
+        return s
+
+    def append(self, data: SomeData, chunk_id: Any = None):
+        sz = len(data)
+        self.observed.append((sz, chunk_id))
+        self.data += data
+
+    @property
+    def started_write(self) -> bool:
+        return len(self.parts) > 0
+
+    @staticmethod
+    def merge(lhs: "MPUChunk", rhs: "MPUChunk", write: PartsWriter) -> "MPUChunk":
+        if not rhs.started_write:
+            # no writes on the right
+            # Just append
+            assert rhs.left_data is None
+            assert len(rhs.parts) == 0
+
+            return MPUChunk(
+                lhs.nextPartId,
+                lhs.write_credits + rhs.write_credits,
+                lhs.data + rhs.data,
+                lhs.left_data,
+                lhs.parts,
+                lhs.observed + rhs.observed,
+                rhs.is_final,
+            )
+
+        # Flush `lhs.data + rhs.left_data` if we can
+        # or else move it into .left_data
+        lhs.final_flush(write, rhs.left_data)
+
+        return MPUChunk(
+            rhs.nextPartId,
+            rhs.write_credits,
+            rhs.data,
+            lhs.left_data,
+            lhs.parts + rhs.parts,
+            lhs.observed + rhs.observed,
+            rhs.is_final,
+        )
+
+    def final_flush(
+        self, write: PartsWriter, extra_data: Optional[bytearray] = None
+    ) -> int:
+        data = self.data
+        if extra_data is not None:
+            data += extra_data
+
+        def _flush_data():
+            part = write(self.nextPartId, data)
+            self.parts.append(part)
+            self.data = bytearray()
+            self.nextPartId += 1
+            self.write_credits -= 1
+            return len(data)
+
+        # Have enough write credits
+        # AND (have enough bytes OR it's the last chunk)
+        can_flush = self.write_credits > 0 and (self.is_final or len(data) >= _5MB)
+
+        if self.started_write:
+            assert can_flush is True
+            return _flush_data()
+
+        if can_flush:
+            return _flush_data()
+
+        assert self.left_data is None
+        self.left_data, self.data = data, bytearray()
+        return 0
+
+    def maybe_write(self, write: PartsWriter, min_sz: int) -> int:
+        # if not last section keep 5MB and 1 partId around after flush
+        bytes_to_keep, parts_to_keep = (0, 0) if self.is_final else (_5MB, 1)
+
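+        # writing a part consumes one part id, so skip the write if that
+        # would leave fewer than `parts_to_keep` ids for the final flush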
+        if self.write_credits - 1 < parts_to_keep:
+            return 0
+
+        bytes_to_write = len(self.data) - bytes_to_keep
+        if bytes_to_write < min_sz:
+            return 0
+
+        part = write(self.nextPartId, self.data[:bytes_to_write])
+
+        self.parts.append(part)
+        self.data = self.data[bytes_to_write:]
+        self.nextPartId += 1
+        self.write_credits -= 1
+
+        return bytes_to_write
+
+
+class MultiPartUpload:
+    """
+    Dumps Dask-generated bytes to S3 via multi-part upload.
+    """
+
+    def __init__(
+        self,
+        bucket: str,
+        key: str,
+        *,
+        uploadId: str = "",
+        profile: Optional[str] = None,
+        endpoint_url: Optional[str] = None,
+        creds: Optional["ReadOnlyCredentials"] = None,
+    ):
+        self.bucket = bucket
+        self.key = key
+        self.uploadId = uploadId
+        self.profile = profile
+        self.endpoint_url = endpoint_url
+        self.creds = creds
+
+    # @cached({}, key=lambda _self: (_self.profile, _self.endpoint_url, _self.creds))
+    @cached({})
+    def s3_client(self):
+        # pylint: disable=import-outside-toplevel,import-error
+        from botocore.session import Session
+
+        sess = Session(profile=self.profile)
+        creds = self.creds
+        if creds is None:
+            return sess.create_client("s3", endpoint_url=self.endpoint_url)
+        return sess.create_client(
+            "s3",
+            endpoint_url=self.endpoint_url,
+            aws_access_key_id=creds.access_key,
+            aws_secret_access_key=creds.secret_key,
+            aws_session_token=creds.token,
+        )
+
+    def __call__(self, partId: int, data: bytearray) -> Dict[str, Any]:
+        s3 = self.s3_client()
+        assert self.uploadId != ""
+        rr = s3.upload_part(
+            PartNumber=partId,
+            Body=data,
+            Bucket=self.bucket,
+            Key=self.key,
+            UploadId=self.uploadId,
+        )
+        etag = rr["ETag"]
+        return {"PartNumber": partId, "ETag": etag}
+
+    def initiate(self) -> str:
+        assert self.uploadId == ""
+        s3 = self.s3_client()
+
+        rr = s3.create_multipart_upload(Bucket=self.bucket, Key=self.key)
+        uploadId = rr["UploadId"]
+        self.uploadId = uploadId
+        return uploadId
+
+    def cancel(self, other: str = ""):
+        uploadId = other if other else self.uploadId
+        if not uploadId:
+            return
+
+        s3 = self.s3_client()
+        s3.abort_multipart_upload(Bucket=self.bucket, Key=self.key, UploadId=uploadId)
+        if uploadId == self.uploadId:
+            self.uploadId = ""
+
+    def complete(self, root: MPUChunk) -> str:
+        s3 = self.s3_client()
+        rr = s3.complete_multipart_upload(
+            Bucket=self.bucket,
+            Key=self.key,
+            UploadId=self.uploadId,
+            MultipartUpload={"Parts": root.parts},
+        )
+
+        return rr["ETag"]
+
+    def list_active(self):
+        s3 = self.s3_client()
+        rr = s3.list_multipart_uploads(Bucket=self.bucket, Prefix=self.key)
+        return [x["UploadId"] for x in rr.get("Uploads", [])]
diff --git a/tests/test_s3.py b/tests/test_s3.py
new file mode 100644
index 00000000..13a7e334
--- /dev/null
+++ b/tests/test_s3.py
@@ -0,0 +1,261 @@
+from hashlib import md5
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+
+from odc.geo.cog._s3 import MPUChunk, MultiPartUpload, PartsWriter
+
+FakeWriteResult = Tuple[int, bytearray, Dict[str, Any]]
+# pylint: disable=unbalanced-tuple-unpacking
+
+
+def _mb(x: float) -> int:
+    return int(x * (1 << 20))
+
+
+def mk_fake_parts_writer(dst: List[FakeWriteResult]) -> PartsWriter:
+    def writer(part: int, data: bytearray) -> Dict[str, Any]:
+        assert 1 <= part <= 10_000
+        etag = md5(data).hexdigest()
+        ww = {"PartNumber": part, "ETag": f'"{etag}"'}
+        dst.append((part, data, ww))
+        return ww
+
+    return writer
+
+
+def _mk_fake_data(sz: int) -> bytes:
+    return np.random.bytes(sz)
+
+
+def _split(data: bytes, offsets: List[int]) -> Tuple[bytes, ...]:
+    parts = []
+
+    for sz in offsets:
+        parts.append(data[:sz])
+        data = data[sz:]
+
+    if data:
+        parts.append(data)
+
+    return tuple(parts)
+
+
+def _data(parts_written: List[FakeWriteResult]) -> bytes:
+    return b"".join(data for _, data, _ in sorted(parts_written, key=lambda x: x[0]))
+
+
+def _parts(parts_written: List[FakeWriteResult]) -> List[Dict[str, Any]]:
+    return [part for _, _, part in sorted(parts_written, key=lambda x: x[0])]
+
+
+def test_s3_mpu_merge_small() -> None:
+    # Test situation where parts get joined and get written eventually in one chunk
+    parts_written: List[FakeWriteResult] = []
+    writer = mk_fake_parts_writer(parts_written)
+
+    data = _mk_fake_data(100)
+    da, db, dc = _split(data, [10, 20])
+
+    a = MPUChunk(1, 10)
+    b = MPUChunk(12, 2)
+    c = MPUChunk(14, 3, is_final=True)
+    assert a.is_final is False
+    assert b.is_final is False
+    assert c.is_final is True
+    assert a.started_write is False
+
+    assert a.maybe_write(writer, _mb(5)) == 0
+    assert a.started_write is False
+
+    b.append(db, "b1")
+    c.append(dc, "c1")
+    a.append(da, "a1")
+
+    # b + c
+    bc = MPUChunk.merge(b, c, writer)
+    assert bc.is_final is True
+    assert bc.nextPartId == b.nextPartId
+    assert bc.write_credits == (b.write_credits + c.write_credits)
+    assert bc.started_write is False
+    assert bc.left_data is None
+    assert bc.data == (db + dc)
+
+    # a + (b + c)
+    abc = MPUChunk.merge(a, bc, writer)
+
+    assert abc.is_final is True
+    assert abc.started_write is False
+    assert abc.data == data
+    assert abc.left_data is None
+    assert abc.nextPartId == a.nextPartId
+    assert abc.write_credits == (a.write_credits + b.write_credits + c.write_credits)
+    assert len(abc.observed) == 3
+    assert abc.observed == [(len(da), "a1"), (len(db), "b1"), (len(dc), "c1")]
+
+    assert len(parts_written) == 0
+
+    assert abc.final_flush(writer) == len(data)
+    assert len(abc.parts) == 1
+    assert len(parts_written) == 1
+    pid, _data, part = parts_written[0]
+    assert _data == data
+    assert pid == a.nextPartId
+    assert part["PartNumber"] == pid
+    assert abc.parts[0] == part
+
+
+def test_mpu_multi_writes() -> None:
+    parts_written: List[FakeWriteResult] = []
+    writer = mk_fake_parts_writer(parts_written)
+
+    data = _mk_fake_data(_mb(20))
+    da1, db1, db2, dc1 = _split(
+        data,
+        [
+            _mb(5.2),
+            _mb(7),
+            _mb(6),
+        ],
+    )
+    a = MPUChunk(1, 10)
+    b = MPUChunk(10, 2)
+    c = MPUChunk(12, 3, is_final=True)
+    assert (a.is_final, b.is_final, c.is_final) == (False, False, True)
+
+    b.append(db1, "b1")
+    # have enough data to write, but not enough to have left-over
+    assert b.maybe_write(writer, _mb(6)) == 0
+    b.append(db2, "b2")
+    assert b.observed == [(len(db1), "b1"), (len(db2), "b2")]
+    # have enough data to write and have left-over still
+    # (7 + 6) - 6 > 5
+    assert b.maybe_write(writer, _mb(6)) > 0
+    assert b.started_write is True
+    assert len(b.data) == _mb(5)
+    assert b.write_credits == 1
+    assert b.nextPartId == 11
+
+    a.append(da1, "a1")
+    assert a.maybe_write(writer, _mb(6)) == 0
+
+    c.append(dc1, "c1")
+    assert c.maybe_write(writer, _mb(6)) == 0
+
+    # Should flush a
+    assert len(parts_written) == 1
+    ab = MPUChunk.merge(a, b, writer)
+    assert len(parts_written) == 2
+    assert len(ab.parts) == 2
+    assert ab.started_write is True
+    assert ab.is_final is False
+    assert ab.data == (db1 + db2)[-_mb(5) :]
+    assert ab.left_data is None
+
+    abc = MPUChunk.merge(ab, c, writer)
+    assert len(abc.parts) == 2
+    assert abc.is_final is True
+    assert abc.observed == [
+        (len(da1), "a1"),
+        (len(db1), "b1"),
+        (len(db2), "b2"),
+        (len(dc1), "c1"),
+    ]
+
+    assert abc.final_flush(writer, None) > 0
+    assert len(abc.parts) == 3
+    assert _parts(parts_written) == abc.parts
+    assert _data(parts_written) == data
+
+
+def test_mpu_left_data() -> None:
+    parts_written: List[FakeWriteResult] = []
+    writer = mk_fake_parts_writer(parts_written)
+
+    data = _mk_fake_data(_mb(3 + 2 + (6 + 5.2)))
+    da1, db1, dc1, dc2 = _split(
+        data,
+        [
+            _mb(3),
+            _mb(2),
+            _mb(6),
+        ],
+    )
+    a = MPUChunk(1, 100)
+    b = MPUChunk(100, 100)
+    c = MPUChunk(200, 100)
+
+    c.append(dc1, "c1")
+    c.append(dc2, "c2")
+    assert c.maybe_write(writer, _mb(6)) > 0
+    assert c.started_write is True
+    assert len(c.parts) == 1
+    assert c.left_data is None
+
+    b.append(db1, "b1")
+    a.append(da1, "a1")
+
+    bc = MPUChunk.merge(b, c, writer)
+    assert bc.started_write is True
+    assert bc.left_data is not None
+    assert bc.nextPartId == 201
+    assert bc.write_credits == 99
+
+    # Expect (a.data + bc.left_data) to be written to PartId=1
+    abc = MPUChunk.merge(a, bc, writer)
+    assert len(abc.parts) == 2
+    assert abc.parts[0]["PartNumber"] == 1
+    assert abc.nextPartId == 201
+    assert abc.write_credits == 99
+    assert abc.is_final is False
+
+    assert abc.observed == [
+        (len(da1), "a1"),
+        (len(db1), "b1"),
+        (len(dc1), "c1"),
+        (len(dc2), "c2"),
+    ]
+
+    assert abc.final_flush(writer) > 0
+    assert abc.nextPartId == 202
+    assert abc.write_credits == 98
+    assert len(abc.parts) == 3
+    assert _data(parts_written) == data
+    assert _parts(parts_written) == abc.parts
+
+
+def test_mpu_misc() -> None:
+    parts_written: List[FakeWriteResult] = []
+    writer = mk_fake_parts_writer(parts_written)
+
+    a = MPUChunk(1, 10)
+    b = MPUChunk(10, 1)
+
+    data = _mk_fake_data(_mb(3 + (6 + 7)))
+    da1, db1, db2 = _split(
+        data,
+        [_mb(3), _mb(6), _mb(7)],
+    )
+    b.append(db1, "b1")
+    b.append(db2, "b2")
+
+    # not enough credits to write
+    assert b.maybe_write(writer, _mb(5)) == 0
+
+    a.append(da1, "a1")
+
+    ab = MPUChunk.merge(a, b, writer)
+    assert ab.started_write is False
+    assert len(parts_written) == 0
+    assert ab.nextPartId == 1
+    assert ab.write_credits == 11
+
+    assert ab.final_flush(writer) > 0
+    assert len(ab.parts) == 1
+    assert _data(parts_written) == data
+    assert _parts(parts_written) == ab.parts
+
+
+def test_mpu():
+    mpu = MultiPartUpload("bucket", "file.dat")
+    assert mpu.bucket == "bucket"
+    assert mpu.key == "file.dat"
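For reviewers, below is a minimal usage sketch (not part of the diff) of how the two classes are intended to compose: `MultiPartUpload` itself acts as the `PartsWriter` callable, while `MPUChunk` instances buffer bytes for disjoint part-id ranges and are reduced left-to-right with `MPUChunk.merge`. The bucket, key and payloads are placeholders, and a real run needs AWS credentials; the tests above exercise the same flow against an in-memory fake writer instead.

```python
from odc.geo.cog._s3 import MPUChunk, MultiPartUpload

# Placeholder bucket/key; a real run needs AWS credentials and an existing bucket.
mpu = MultiPartUpload("example-bucket", "example/output.tif")
mpu.initiate()  # starts the multi-part upload and records its UploadId

# Two chunks covering disjoint part-id ranges; the right-most chunk is marked final.
left = MPUChunk(1, 100)
right = MPUChunk(101, 100, is_final=True)

left.append(b"...header and early tiles...", "chunk-0")
right.append(b"...remaining tiles...", "chunk-1")

# Upload a part early if at least 6MB can be written while keeping the 5MB tail back.
left.maybe_write(mpu, 6 * (1 << 20))

# Reduce to a single chunk, flush whatever is still cached, then finish the upload.
root = MPUChunk.merge(left, right, mpu)
root.final_flush(mpu)
etag = mpu.complete(root)
```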