[FS-149] New blobs upload API #2608

Open · wants to merge 13 commits into base: main
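
The diff below swaps the old presigned-S3-URL upload path for a session-based flow built on three RPCs that appear in the new code: BlobCreateUpload opens an upload session keyed by the blob's SHA-256 and size, BlobStagePart returns the HTTP request details for each part, and BlobCommitUpload finalizes the session and returns the blob hash. The sketch below pieces that flow together from the diff; the direct await stub.Rpc(req) calls, the throwaway aiohttp session, and the http_put helper are simplifications standing in for the retry and connection-pooling machinery the real code uses, not the PR's exact API.

import hashlib

import aiohttp

from modal_proto import api_pb2


async def http_put(details: api_pb2.BlobUploadRequestDetails, body: bytes) -> None:
    # Replay the server-described request: method, uri, and headers all come
    # from the BlobUploadRequestDetails message read in the diff below.
    headers = {h.name: h.value for h in details.headers}
    async with aiohttp.ClientSession() as session:
        async with session.request(details.method, details.uri, data=body, headers=headers) as resp:
            resp.raise_for_status()


async def upload_flow_sketch(stub, data: bytes) -> str:
    sha256_hex = hashlib.sha256(data).hexdigest()

    # 1) Open an upload session keyed by content hash and size.
    create_resp = await stub.BlobCreateUpload(
        api_pb2.BlobCreateUploadRequest(blob_hash=sha256_hex, blob_size=len(data))
    )
    status = create_resp.WhichOneof("upload_status")
    if status == "already_exists":
        return sha256_hex  # server already has this blob; nothing to upload

    token = create_resp.session_token
    if status == "single_part_upload":
        # The create response already carries the full HTTP request to issue.
        await http_put(create_resp.single_part_upload.upload_request, data)
    else:  # "multi_part_upload"
        part_size = create_resp.multi_part_upload.part_size
        for part, offset in enumerate(range(0, len(data), part_size)):
            # 2) Stage each part to get its own upload request details.
            stage = await stub.BlobStagePart(
                api_pb2.BlobStagePartRequest(session_token=token, part=part)
            )
            await http_put(stage.upload_request, data[offset:offset + part_size])

    # 3) Finalize the session; the server returns the canonical blob hash.
    commit = await stub.BlobCommitUpload(
        api_pb2.BlobCommitUploadRequest(session_token=token)
    )
    return commit.blob_hash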
174 changes: 78 additions & 96 deletions modal/_utils/blob_utils.py
@@ -15,13 +15,13 @@
from aiohttp import BytesIOPayload
from aiohttp.abc import AbstractStreamWriter

from modal_proto import api_pb2
from modal_proto import api_pb2, modal_api_grpc
from modal_proto.modal_api_grpc import ModalClientModal

from ..exception import ExecutionError
from .async_utils import TaskContext, retry
from .grpc_utils import retry_transient_errors
from .hash_utils import UploadHashes, get_sha256_hex, get_upload_hashes
from .hash_utils import get_sha256_hex
from .http_utils import ClientSessionRegistry
from .logger import logger

@@ -122,79 +122,78 @@ def remaining_bytes(self):


@retry(n_attempts=5, base_delay=0.5, timeout=None)
async def _upload_to_s3_url(
upload_url,
async def _upload_with_request_details(
request_details: api_pb2.BlobUploadRequestDetails,
payload: BytesIOSegmentPayload,
content_md5_b64: Optional[str] = None,
content_type: Optional[str] = "application/octet-stream", # set to None to force omission of ContentType header
) -> str:
"""Returns etag of s3 object which is a md5 hex checksum of the uploaded content"""
):
with payload.reset_on_error(): # ensure retries read the same data
method = request_details.method
uri = request_details.uri

headers = {}
if content_md5_b64 and use_md5(upload_url):
headers["Content-MD5"] = content_md5_b64
if content_type:
headers["Content-Type"] = content_type
for header in request_details.headers:
headers[header.name] = header.value

async with ClientSessionRegistry.get_session().put(
upload_url,
async with ClientSessionRegistry.get_session().request(
method,
uri,
data=payload,
headers=headers,
skip_auto_headers=["content-type"] if content_type is None else [],
) as resp:
# S3 signal to slow down request rate.
if resp.status == 503:
logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
await asyncio.sleep(1)

if resp.status != 200:
if resp.status not in [200, 204]:
try:
text = await resp.text()
except Exception:
text = "<no body>"
raise ExecutionError(f"Put to url {upload_url} failed with status {resp.status}: {text}")
raise ExecutionError(f"{method} to url {uri} failed with status {resp.status}: {text}")

# client side ETag checksum verification
# the s3 ETag of a single part upload is a quoted md5 hex of the uploaded content
etag = resp.headers["ETag"].strip()
if etag.startswith(("W/", "w/")): # see https://www.rfc-editor.org/rfc/rfc7232#section-2.3
etag = etag[2:]
if etag[0] == '"' and etag[-1] == '"':
etag = etag[1:-1]
remote_md5 = etag

local_md5_hex = payload.md5_checksum().hexdigest()
if local_md5_hex != remote_md5:
raise ExecutionError(f"Local data and remote data checksum mismatch ({local_md5_hex} vs {remote_md5})")

return remote_md5
async def _stage_and_upload(
stub: modal_api_grpc.ModalClientModal,
session_token: bytes,
part: int,
payload: BytesIOSegmentPayload,
):
req = api_pb2.BlobStagePartRequest(session_token=session_token, part=part)
resp = await retry_transient_errors(stub.BlobStagePart, req)
request_details = resp.upload_request
return await _upload_with_request_details(request_details, payload)


async def perform_multipart_upload(
async def _perform_multipart_upload(
data_file: Union[BinaryIO, io.BytesIO, io.FileIO],
*,
content_length: int,
stub: modal_api_grpc.ModalClientModal,
session_token: bytes,
blob_size: int,
max_part_size: int,
part_urls: list[str],
completion_url: str,
upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
progress_report_cb: Optional[Callable] = None,
) -> None:
):
def ceildiv(a, b):
return -(a // -b)
Contributor comment: 🤯
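The -(a // -b) expression is integer ceiling division: for example, ceildiv(10, 4) evaluates to -(10 // -4) = -(-3) = 3, so a 10-byte blob split into 4-byte parts needs 3 parts.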


upload_coros = []
file_offset = 0
num_bytes_left = content_length
num_parts = ceildiv(blob_size, max_part_size)
num_bytes_left = blob_size

# Give each part its own IO reader object to avoid needing to
# lock access to the reader's position pointer.
data_file_readers: list[BinaryIO]
if isinstance(data_file, io.BytesIO):
view = data_file.getbuffer() # does not copy data
data_file_readers = [io.BytesIO(view) for _ in range(len(part_urls))]
data_file_readers = [io.BytesIO(view) for _ in range(num_parts)]
else:
filename = data_file.name
data_file_readers = [open(filename, "rb") for _ in range(len(part_urls))]
data_file_readers = [open(filename, "rb") for _ in range(num_parts)]

for part_number, (data_file_rdr, part_url) in enumerate(zip(data_file_readers, part_urls), start=1):
for part, data_file_rdr in enumerate(data_file_readers):
part_length_bytes = min(num_bytes_left, max_part_size)
part_payload = BytesIOSegmentPayload(
data_file_rdr,
@@ -203,40 +202,14 @@ async def perform_multipart_upload(
chunk_size=upload_chunk_size,
progress_report_cb=progress_report_cb,
)
upload_coros.append(_upload_to_s3_url(part_url, payload=part_payload, content_type=None))
upload_coros.append(_stage_and_upload(stub, session_token, part, part_payload))
num_bytes_left -= part_length_bytes
file_offset += part_length_bytes

part_etags = await TaskContext.gather(*upload_coros)

# The body of the complete_multipart_upload command needs some data in xml format:
completion_body = "<CompleteMultipartUpload>\n"
for part_number, etag in enumerate(part_etags, 1):
completion_body += f"""<Part>\n<PartNumber>{part_number}</PartNumber>\n<ETag>"{etag}"</ETag>\n</Part>\n"""
completion_body += "</CompleteMultipartUpload>"

# etag of combined object should be md5 hex of concatenated md5 *bytes* from parts + `-{num_parts}`
bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags]

expected_multipart_etag = hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}"
resp = await ClientSessionRegistry.get_session().post(
completion_url, data=completion_body.encode("ascii"), skip_auto_headers=["content-type"]
)
if resp.status != 200:
try:
msg = await resp.text()
except Exception:
msg = "<no body>"
raise ExecutionError(f"Error when completing multipart upload: {resp.status}\n{msg}")
else:
response_body = await resp.text()
if expected_multipart_etag not in response_body:
raise ExecutionError(
f"Hash mismatch on multipart upload assembly: {expected_multipart_etag} not in {response_body}"
)
await TaskContext.gather(*upload_coros)


def get_content_length(data: BinaryIO) -> int:
def _get_blob_size(data: BinaryIO) -> int:
# *Remaining* length of file from current seek position
pos = data.tell()
data.seek(0, os.SEEK_END)
@@ -246,69 +219,78 @@ def get_content_length(data: BinaryIO) -> int:


async def _blob_upload(
upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub, progress_report_cb: Optional[Callable] = None
sha256_hex: str,
data: Union[bytes, BinaryIO],
stub: modal_api_grpc.ModalClientModal,
progress_report_cb: Optional[Callable] = None
) -> str:
if isinstance(data, bytes):
data = io.BytesIO(data)

content_length = get_content_length(data)
blob_size = _get_blob_size(data)

req = api_pb2.BlobCreateRequest(
content_md5=upload_hashes.md5_base64,
content_sha256_base64=upload_hashes.sha256_base64,
content_length=content_length,
create_req = api_pb2.BlobCreateUploadRequest(
blob_hash=sha256_hex,
blob_size=blob_size,
)
resp = await retry_transient_errors(stub.BlobCreate, req)
create_resp = await retry_transient_errors(stub.BlobCreateUpload, create_req)

blob_id = resp.blob_id
session_token = create_resp.session_token

if resp.WhichOneof("upload_type_oneof") == "multipart":
await perform_multipart_upload(
which_oneof = create_resp.WhichOneof("upload_status")
if which_oneof == "already_exists":
return sha256_hex
elif which_oneof == "multi_part_upload":
await _perform_multipart_upload(
data,
content_length=content_length,
max_part_size=resp.multipart.part_length,
part_urls=resp.multipart.upload_urls,
completion_url=resp.multipart.completion_url,
stub=stub,
session_token=session_token,
blob_size=blob_size,
max_part_size=create_resp.multi_part_upload.part_size,
upload_chunk_size=DEFAULT_SEGMENT_CHUNK_SIZE,
progress_report_cb=progress_report_cb,
)
else:
elif which_oneof == "single_part_upload":
request_details = create_resp.single_part_upload.upload_request
payload = BytesIOSegmentPayload(
data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
data, segment_start=0, segment_length=blob_size, progress_report_cb=progress_report_cb
)
await _upload_to_s3_url(
resp.upload_url,
await _upload_with_request_details(
request_details,
payload,
# for single part uploads, we use server side md5 checksums
content_md5_b64=upload_hashes.md5_base64,
)
else:
raise NotImplementedError(f"unsupported upload mode from CreateBlobUploadResponse: {which_oneof}")

commit_req = api_pb2.BlobCommitUploadRequest(session_token=session_token)
commit_resp = await retry_transient_errors(stub.BlobCommitUpload, commit_req)

if progress_report_cb:
progress_report_cb(complete=True)

return blob_id
return commit_resp.blob_hash


async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
async def blob_upload(payload: bytes, stub: modal_api_grpc.ModalClientModal) -> str:
size_mib = len(payload) / 1024 / 1024
logger.debug(f"Uploading large blob of size {size_mib:.2f} MiB")
t0 = time.time()
if isinstance(payload, str):
logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
payload = payload.encode("utf8")
upload_hashes = get_upload_hashes(payload)
blob_id = await _blob_upload(upload_hashes, payload, stub)
sha256_hex = get_sha256_hex(payload)
blob_id = await _blob_upload(sha256_hex, payload, stub)
dur_s = max(time.time() - t0, 0.001) # avoid division by zero
throughput_mib_s = (size_mib) / dur_s
throughput_mib_s = size_mib / dur_s
logger.debug(f"Uploaded large blob of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)." f" {blob_id}")
return blob_id


async def blob_upload_file(
file_obj: BinaryIO, stub: ModalClientModal, progress_report_cb: Optional[Callable] = None
file_obj: BinaryIO, stub: modal_api_grpc.ModalClientModal, progress_report_cb: Optional[Callable] = None
) -> str:
upload_hashes = get_upload_hashes(file_obj)
return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)
sha256_hex = get_sha256_hex(file_obj)
return await _blob_upload(sha256_hex, file_obj, stub, progress_report_cb)
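
For orientation, the two public entry points keep the same call shape as before; a hypothetical caller-side snippet (obtaining a connected stub is outside this diff):

from modal._utils.blob_utils import blob_upload, blob_upload_file


async def example(stub) -> None:
    # Upload an in-memory payload; the returned id is the blob hash
    # (the SHA-256 hex in the new flow).
    blob_id = await blob_upload(b"some payload", stub)

    # Upload from a file object, optionally reporting progress.
    with open("weights.bin", "rb") as f:
        blob_id = await blob_upload_file(f, stub, progress_report_cb=None)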


@retry(n_attempts=5, base_delay=0.1, timeout=None)
17 changes: 0 additions & 17 deletions modal/_utils/hash_utils.py
@@ -1,6 +1,5 @@
# Copyright Modal Labs 2022
import base64
import dataclasses
import hashlib
from typing import BinaryIO, Callable, Union

@@ -41,19 +40,3 @@ def get_md5_base64(data: Union[bytes, BinaryIO]) -> str:
hasher = hashlib.md5()
_update([hasher.update], data)
return base64.b64encode(hasher.digest()).decode("utf-8")


@dataclasses.dataclass
class UploadHashes:
md5_base64: str
sha256_base64: str


def get_upload_hashes(data: Union[bytes, BinaryIO]) -> UploadHashes:
md5 = hashlib.md5()
sha256 = hashlib.sha256()
_update([md5.update, sha256.update], data)
return UploadHashes(
md5_base64=base64.b64encode(md5.digest()).decode("ascii"),
sha256_base64=base64.b64encode(sha256.digest()).decode("ascii"),
)
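
With UploadHashes and get_upload_hashes removed, the only hash helper the blob path still relies on is get_sha256_hex, imported in blob_utils.py above. Its body is not shown in this hunk; a minimal stand-in consistent with how it is used, accepting either bytes or a binary file object and returning a hex digest, could look like:

import hashlib
from typing import BinaryIO, Union


def sha256_hex_standin(data: Union[bytes, BinaryIO]) -> str:
    # Hypothetical stand-in for hash_utils.get_sha256_hex; the real helper may differ.
    hasher = hashlib.sha256()
    if isinstance(data, bytes):
        hasher.update(data)
    else:
        pos = data.tell()
        for chunk in iter(lambda: data.read(2**20), b""):  # stream in 1 MiB chunks
            hasher.update(chunk)
        data.seek(pos)  # restore the position so the caller can re-read the payload
    return hasher.hexdigest()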