Skip to content

Commit

Permalink
amend me: s3 chunked writer
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Sep 26, 2023
1 parent a54feea commit 784eea6
Show file tree
Hide file tree
Showing 2 changed files with 503 additions and 0 deletions.
242 changes: 242 additions & 0 deletions odc/geo/cog/_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
"""
S3 utils for COG to S3.
"""
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from cachetools import cached

SomeData = Union[bytes, bytearray]
_5MB = 5 * (1 << 20)
PartsWriter = Callable[[int, bytearray], Dict[str, Any]]
if TYPE_CHECKING:
from botocore.credentials import ReadOnlyCredentials


class MPUChunk:
    """
    Chunk cache and writer for one slice of an S3 multipart upload.

    Accumulates bytes into ``data`` and, when allowed, flushes them as
    multipart "parts" via a ``PartsWriter`` callback.  Each chunk owns a
    contiguous range of part ids: ``nextPartId`` is the first id it may use
    and ``write_credits`` is how many ids remain in its range (part ids are
    capped at 10_000, the S3 multipart limit).  Chunks are combined pairwise
    with :meth:`merge` (e.g. in a tree reduction); ``left_data`` holds bytes
    that could not be written yet because they must first be joined with the
    tail of the chunk to the left.
    """

    # pylint: disable=too-many-arguments

    __slots__ = (
        "nextPartId",  # next part id this chunk will write to
        "write_credits",  # number of part ids still available to this chunk
        "data",  # accumulated bytes not yet written out
        "left_data",  # bytes deferred until merged with the chunk on the left
        "parts",  # part descriptors returned by the writer so far
        "observed",  # (size, chunk_id) log of every append, in order
        "is_final",  # True for the last chunk of the whole stream
    )

    def __init__(
        self,
        partId: int,
        write_credits: int,
        data: Optional[bytearray] = None,
        left_data: Optional[bytearray] = None,
        parts: Optional[List[Dict[str, Any]]] = None,
        observed: Optional[List[Tuple[int, Any]]] = None,
        is_final: bool = False,
    ) -> None:
        # S3 part numbers are 1-based and the id range
        # [partId, partId + write_credits) must stay within the 10_000 limit.
        assert partId >= 1
        assert write_credits >= 0
        assert partId + write_credits <= 10_000

        self.nextPartId = partId
        self.write_credits = write_credits
        self.data = bytearray() if data is None else data
        self.left_data = left_data
        self.parts: List[Dict[str, Any]] = [] if parts is None else parts
        self.observed: List[Tuple[int, Any]] = [] if observed is None else observed
        self.is_final = is_final

    def __repr__(self) -> str:
        """Compact debug summary: part id, credits, cache size, and flags."""
        s = f"MPUChunk: {self.nextPartId}#{self.write_credits} cache: {len(self.data)}"
        if self.observed:
            s = f"{s} observed[{len(self.observed)}]"
        if self.parts:
            s = f"{s} parts: [{len(self.parts)}]"
        if self.is_final:
            s = f"{s} final"
        return s

    def append(self, data: SomeData, chunk_id: Any = None) -> None:
        """
        Add ``data`` to the in-memory cache.

        ``chunk_id`` is an opaque label recorded (with the byte count) in
        ``observed`` for later inspection; it does not affect writing.
        """
        sz = len(data)
        self.observed.append((sz, chunk_id))
        self.data += data

    @property
    def started_write(self) -> bool:
        # True once this chunk has uploaded at least one part.
        return len(self.parts) > 0

    @staticmethod
    def merge(lhs: "MPUChunk", rhs: "MPUChunk", write: PartsWriter) -> "MPUChunk":
        """
        Combine two adjacent chunks (``lhs`` immediately left of ``rhs``)
        into one, preserving byte order and part ordering.

        May call ``write`` to flush ``lhs``'s cached bytes joined with
        ``rhs.left_data``; returns a new :class:`MPUChunk` covering both.
        """
        if not rhs.started_write:
            # no writes on the right
            # Just append
            assert rhs.left_data is None
            assert len(rhs.parts) == 0

            # Simple case: concatenate caches and pool the id range/credits.
            return MPUChunk(
                lhs.nextPartId,
                lhs.write_credits + rhs.write_credits,
                lhs.data + rhs.data,
                lhs.left_data,
                lhs.parts,
                lhs.observed + rhs.observed,
                rhs.is_final,
            )

        # Flush `lhs.data + rhs.left_data` if we can
        # or else move it into .left_data
        # NOTE: must run before constructing the merged chunk below, since it
        # may update lhs.left_data and lhs.parts.
        lhs.final_flush(write, rhs.left_data)

        # rhs has already written parts, so the merged chunk continues from
        # rhs's part id range; lhs's range is exhausted at this point.
        return MPUChunk(
            rhs.nextPartId,
            rhs.write_credits,
            rhs.data,
            lhs.left_data,
            lhs.parts + rhs.parts,
            lhs.observed + rhs.observed,
            rhs.is_final,
        )

    def final_flush(
        self, write: PartsWriter, extra_data: Optional[bytearray] = None
    ) -> int:
        """
        Write out all cached bytes (plus ``extra_data``, if any) as one part,
        or park them in ``left_data`` when flushing is not allowed.

        Returns the number of bytes written (0 when parked).
        """
        data = self.data
        if extra_data is not None:
            # In-place extend aliases self.data, but every exit path below
            # re-assigns self.data, so the aliasing is harmless.
            data += extra_data

        def _flush_data():
            part = write(self.nextPartId, data)
            self.parts.append(part)
            self.data = bytearray()
            self.nextPartId += 1
            self.write_credits -= 1
            return len(data)

        # Have enough write credits
        # AND (have enough bytes OR it's the last chunk)
        # (S3 requires every part except the last to be at least 5MB)
        can_flush = self.write_credits > 0 and (self.is_final or len(data) >= _5MB)

        if self.started_write:
            # A chunk that has written before must always be able to finish;
            # anything else is a credit-accounting bug upstream.
            assert can_flush is True
            return _flush_data()

        if can_flush:
            return _flush_data()

        # Can't write yet: defer bytes for the chunk to our left to pick up.
        assert self.left_data is None
        self.left_data, self.data = data, bytearray()
        return 0

    def maybe_write(self, write: PartsWriter, min_sz: int) -> int:
        """
        Flush cached bytes as one part if at least ``min_sz`` bytes can be
        written without breaking invariants; returns bytes written (0 if not).
        """
        # if not last section keep 5MB and 1 partId around after flush
        bytes_to_keep, parts_to_keep = (0, 0) if self.is_final else (_5MB, 1)

        # Writing consumes one credit; make sure enough remain afterwards.
        if self.write_credits - 1 < parts_to_keep:
            return 0

        bytes_to_write = len(self.data) - bytes_to_keep
        if bytes_to_write < min_sz:
            return 0

        part = write(self.nextPartId, self.data[:bytes_to_write])

        self.parts.append(part)
        self.data = self.data[bytes_to_write:]
        self.nextPartId += 1
        self.write_credits -= 1

        return bytes_to_write


class MultiPartUpload:
    """
    Dask to S3 dumper.

    Thin wrapper around the S3 multipart-upload API for a single
    ``bucket/key`` object.  Instances are callable: ``mpu(partId, data)``
    uploads one part (see :meth:`__call__`), which makes an instance usable
    directly as a ``PartsWriter``.
    """

    def __init__(
        self,
        bucket: str,
        key: str,
        *,
        uploadId: str = "",
        profile: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        creds: Optional["ReadOnlyCredentials"] = None,
    ):
        """
        :param bucket: destination S3 bucket
        :param key: destination object key
        :param uploadId: attach to an existing upload; "" means not started yet
        :param profile: botocore session profile name
        :param endpoint_url: custom S3 endpoint (e.g. for non-AWS stores)
        :param creds: explicit credentials; when ``None`` the session default
            credential chain is used
        """
        self.bucket = bucket
        self.key = key
        self.uploadId = uploadId
        self.profile = profile
        self.endpoint_url = endpoint_url
        self.creds = creds
        # Lazily constructed botocore client, cached per instance.  A
        # module-level ``@cached({})`` keyed on ``self`` (as before) would
        # keep every instance and its client alive for the process lifetime.
        self._s3: Any = None

    def s3_client(self):
        """Construct (once per instance) and return a botocore S3 client."""
        # pylint: disable=import-outside-toplevel,import-error
        if self._s3 is not None:
            return self._s3

        from botocore.session import Session

        sess = Session(profile=self.profile)
        creds = self.creds
        if creds is None:
            # Fall back to the session's default credential resolution.
            self._s3 = sess.create_client("s3", endpoint_url=self.endpoint_url)
        else:
            self._s3 = sess.create_client(
                "s3",
                endpoint_url=self.endpoint_url,
                aws_access_key_id=creds.access_key,
                aws_secret_access_key=creds.secret_key,
                aws_session_token=creds.token,
            )
        return self._s3

    def __call__(self, partId: int, data: bytearray) -> Dict[str, Any]:
        """
        Upload one part; returns the ``{"PartNumber", "ETag"}`` descriptor
        that :meth:`complete` needs.  Requires :meth:`initiate` to have run.
        """
        s3 = self.s3_client()
        assert self.uploadId != ""
        rr = s3.upload_part(
            PartNumber=partId,
            Body=data,
            Bucket=self.bucket,
            Key=self.key,
            UploadId=self.uploadId,
        )
        etag = rr["ETag"]
        return {"PartNumber": partId, "ETag": etag}

    def initiate(self) -> str:
        """Start a new multipart upload; stores and returns its upload id."""
        assert self.uploadId == ""
        s3 = self.s3_client()

        rr = s3.create_multipart_upload(Bucket=self.bucket, Key=self.key)
        uploadId = rr["UploadId"]
        self.uploadId = uploadId
        return uploadId

    def cancel(self, other: str = ""):
        """
        Abort this upload (or ``other`` upload id if given).

        No-op when there is nothing to abort.
        """
        uploadId = other if other else self.uploadId
        if not uploadId:
            return

        s3 = self.s3_client()
        s3.abort_multipart_upload(Bucket=self.bucket, Key=self.key, UploadId=uploadId)
        if uploadId == self.uploadId:
            # Mark this instance as no longer attached to an active upload.
            self.uploadId = ""

    def complete(self, root: "MPUChunk") -> str:
        """Finalize the upload from ``root.parts``; returns the object ETag."""
        s3 = self.s3_client()
        rr = s3.complete_multipart_upload(
            Bucket=self.bucket,
            Key=self.key,
            UploadId=self.uploadId,
            MultipartUpload={"Parts": root.parts},
        )

        return rr["ETag"]

    def list_active(self):
        """Return upload ids of all in-progress multipart uploads for this key prefix."""
        s3 = self.s3_client()
        rr = s3.list_multipart_uploads(Bucket=self.bucket, Prefix=self.key)
        return [x["UploadId"] for x in rr.get("Uploads", [])]
Loading

0 comments on commit 784eea6

Please sign in to comment.