diff --git a/modal/_utils/blob_utils.py b/modal/_utils/blob_utils.py index a73eab1eb..f8d033727 100644 --- a/modal/_utils/blob_utils.py +++ b/modal/_utils/blob_utils.py @@ -15,13 +15,13 @@ from aiohttp import BytesIOPayload from aiohttp.abc import AbstractStreamWriter -from modal_proto import api_pb2 +from modal_proto import api_pb2, modal_api_grpc from modal_proto.modal_api_grpc import ModalClientModal from ..exception import ExecutionError from .async_utils import TaskContext, retry from .grpc_utils import retry_transient_errors -from .hash_utils import UploadHashes, get_sha256_hex, get_upload_hashes +from .hash_utils import get_sha256_hex from .http_utils import ClientSessionRegistry from .logger import logger @@ -122,79 +122,78 @@ def remaining_bytes(self): @retry(n_attempts=5, base_delay=0.5, timeout=None) -async def _upload_to_s3_url( - upload_url, +async def _upload_with_request_details( + request_details: api_pb2.BlobUploadRequestDetails, payload: BytesIOSegmentPayload, - content_md5_b64: Optional[str] = None, - content_type: Optional[str] = "application/octet-stream", # set to None to force omission of ContentType header -) -> str: - """Returns etag of s3 object which is a md5 hex checksum of the uploaded content""" +): with payload.reset_on_error(): # ensure retries read the same data + method = request_details.method + uri = request_details.uri + headers = {} - if content_md5_b64 and use_md5(upload_url): - headers["Content-MD5"] = content_md5_b64 - if content_type: - headers["Content-Type"] = content_type + for header in request_details.headers: + headers[header.name] = header.value - async with ClientSessionRegistry.get_session().put( - upload_url, + async with ClientSessionRegistry.get_session().request( + method, + uri, data=payload, headers=headers, - skip_auto_headers=["content-type"] if content_type is None else [], ) as resp: # S3 signal to slow down request rate. 
if resp.status == 503: logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.") await asyncio.sleep(1) - if resp.status != 200: + if resp.status not in [200, 204]: try: text = await resp.text() except Exception: text = "" - raise ExecutionError(f"Put to url {upload_url} failed with status {resp.status}: {text}") + raise ExecutionError(f"{method} to url {uri} failed with status {resp.status}: {text}") - # client side ETag checksum verification - # the s3 ETag of a single part upload is a quoted md5 hex of the uploaded content - etag = resp.headers["ETag"].strip() - if etag.startswith(("W/", "w/")): # see https://www.rfc-editor.org/rfc/rfc7232#section-2.3 - etag = etag[2:] - if etag[0] == '"' and etag[-1] == '"': - etag = etag[1:-1] - remote_md5 = etag - local_md5_hex = payload.md5_checksum().hexdigest() - if local_md5_hex != remote_md5: - raise ExecutionError(f"Local data and remote data checksum mismatch ({local_md5_hex} vs {remote_md5})") - - return remote_md5 +async def _stage_and_upload( + stub: modal_api_grpc.ModalClientModal, + session_token: bytes, + part: int, + payload: BytesIOSegmentPayload, +): + req = api_pb2.BlobStagePartRequest(session_token=session_token, part=part) + resp = await retry_transient_errors(stub.BlobStagePart, req) + request_details = resp.upload_request + return await _upload_with_request_details(request_details, payload) -async def perform_multipart_upload( +async def _perform_multipart_upload( data_file: Union[BinaryIO, io.BytesIO, io.FileIO], *, - content_length: int, + stub: modal_api_grpc.ModalClientModal, + session_token: bytes, + blob_size: int, max_part_size: int, - part_urls: list[str], - completion_url: str, upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE, progress_report_cb: Optional[Callable] = None, -) -> None: +): + def ceildiv(a, b): + return -(a // -b) + upload_coros = [] file_offset = 0 - num_bytes_left = content_length + num_parts = ceildiv(blob_size, max_part_size) + num_bytes_left = blob_size # Give each part its own IO reader object to avoid needing to # lock access to the reader's position pointer. 
data_file_readers: list[BinaryIO] if isinstance(data_file, io.BytesIO): view = data_file.getbuffer() # does not copy data - data_file_readers = [io.BytesIO(view) for _ in range(len(part_urls))] + data_file_readers = [io.BytesIO(view) for _ in range(num_parts)] else: filename = data_file.name - data_file_readers = [open(filename, "rb") for _ in range(len(part_urls))] + data_file_readers = [open(filename, "rb") for _ in range(num_parts)] - for part_number, (data_file_rdr, part_url) in enumerate(zip(data_file_readers, part_urls), start=1): + for part, data_file_rdr in enumerate(data_file_readers): part_length_bytes = min(num_bytes_left, max_part_size) part_payload = BytesIOSegmentPayload( data_file_rdr, @@ -203,40 +202,14 @@ async def perform_multipart_upload( chunk_size=upload_chunk_size, progress_report_cb=progress_report_cb, ) - upload_coros.append(_upload_to_s3_url(part_url, payload=part_payload, content_type=None)) + upload_coros.append(_stage_and_upload(stub, session_token, part, part_payload)) num_bytes_left -= part_length_bytes file_offset += part_length_bytes - part_etags = await TaskContext.gather(*upload_coros) - - # The body of the complete_multipart_upload command needs some data in xml format: - completion_body = "\n" - for part_number, etag in enumerate(part_etags, 1): - completion_body += f"""\n{part_number}\n"{etag}"\n\n""" - completion_body += "" - - # etag of combined object should be md5 hex of concatendated md5 *bytes* from parts + `-{num_parts}` - bin_hash_parts = [bytes.fromhex(etag) for etag in part_etags] - - expected_multipart_etag = hashlib.md5(b"".join(bin_hash_parts)).hexdigest() + f"-{len(part_etags)}" - resp = await ClientSessionRegistry.get_session().post( - completion_url, data=completion_body.encode("ascii"), skip_auto_headers=["content-type"] - ) - if resp.status != 200: - try: - msg = await resp.text() - except Exception: - msg = "" - raise ExecutionError(f"Error when completing multipart upload: {resp.status}\n{msg}") - else: - response_body = await resp.text() - if expected_multipart_etag not in response_body: - raise ExecutionError( - f"Hash mismatch on multipart upload assembly: {expected_multipart_etag} not in {response_body}" - ) + await TaskContext.gather(*upload_coros) -def get_content_length(data: BinaryIO) -> int: +def _get_blob_size(data: BinaryIO) -> int: # *Remaining* length of file from current seek position pos = data.tell() data.seek(0, os.SEEK_END) @@ -246,69 +219,78 @@ def get_content_length(data: BinaryIO) -> int: async def _blob_upload( - upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub, progress_report_cb: Optional[Callable] = None + sha256_hex: str, + data: Union[bytes, BinaryIO], + stub: modal_api_grpc.ModalClientModal, + progress_report_cb: Optional[Callable] = None ) -> str: if isinstance(data, bytes): data = io.BytesIO(data) - content_length = get_content_length(data) + blob_size = _get_blob_size(data) - req = api_pb2.BlobCreateRequest( - content_md5=upload_hashes.md5_base64, - content_sha256_base64=upload_hashes.sha256_base64, - content_length=content_length, + create_req = api_pb2.BlobCreateUploadRequest( + blob_hash=sha256_hex, + blob_size=blob_size, ) - resp = await retry_transient_errors(stub.BlobCreate, req) + create_resp = await retry_transient_errors(stub.BlobCreateUpload, create_req) - blob_id = resp.blob_id + session_token = create_resp.session_token - if resp.WhichOneof("upload_type_oneof") == "multipart": - await perform_multipart_upload( + which_oneof = create_resp.WhichOneof("upload_status") + if 
which_oneof == "already_exists": + return sha256_hex + elif which_oneof == "multi_part_upload": + await _perform_multipart_upload( data, - content_length=content_length, - max_part_size=resp.multipart.part_length, - part_urls=resp.multipart.upload_urls, - completion_url=resp.multipart.completion_url, + stub=stub, + session_token=session_token, + blob_size=blob_size, + max_part_size=create_resp.multi_part_upload.part_size, upload_chunk_size=DEFAULT_SEGMENT_CHUNK_SIZE, progress_report_cb=progress_report_cb, ) - else: + elif which_oneof == "single_part_upload": + request_details = create_resp.single_part_upload.upload_request payload = BytesIOSegmentPayload( - data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb + data, segment_start=0, segment_length=blob_size, progress_report_cb=progress_report_cb ) - await _upload_to_s3_url( - resp.upload_url, + await _upload_with_request_details( + request_details, payload, - # for single part uploads, we use server side md5 checksums - content_md5_b64=upload_hashes.md5_base64, ) + else: + raise NotImplementedError(f"unsupported upload mode from CreateBlobUploadResponse: {which_oneof}") + + commit_req = api_pb2.BlobCommitUploadRequest(session_token=session_token) + commit_resp = await retry_transient_errors(stub.BlobCommitUpload, commit_req) if progress_report_cb: progress_report_cb(complete=True) - return blob_id + return commit_resp.blob_hash -async def blob_upload(payload: bytes, stub: ModalClientModal) -> str: +async def blob_upload(payload: bytes, stub: modal_api_grpc.ModalClientModal) -> str: size_mib = len(payload) / 1024 / 1024 logger.debug(f"Uploading large blob of size {size_mib:.2f} MiB") t0 = time.time() if isinstance(payload, str): logger.warning("Blob uploading string, not bytes - auto-encoding as utf8") payload = payload.encode("utf8") - upload_hashes = get_upload_hashes(payload) - blob_id = await _blob_upload(upload_hashes, payload, stub) + sha256_hex = get_sha256_hex(payload) + blob_id = await _blob_upload(sha256_hex, payload, stub) dur_s = max(time.time() - t0, 0.001) # avoid division by zero - throughput_mib_s = (size_mib) / dur_s + throughput_mib_s = size_mib / dur_s logger.debug(f"Uploaded large blob of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s)." 
f" {blob_id}") return blob_id async def blob_upload_file( - file_obj: BinaryIO, stub: ModalClientModal, progress_report_cb: Optional[Callable] = None + file_obj: BinaryIO, stub: modal_api_grpc.ModalClientModal, progress_report_cb: Optional[Callable] = None ) -> str: - upload_hashes = get_upload_hashes(file_obj) - return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb) + sha256_hex = get_sha256_hex(file_obj) + return await _blob_upload(sha256_hex, file_obj, stub, progress_report_cb) @retry(n_attempts=5, base_delay=0.1, timeout=None) diff --git a/modal/_utils/hash_utils.py b/modal/_utils/hash_utils.py index 4c9abb268..21800d630 100644 --- a/modal/_utils/hash_utils.py +++ b/modal/_utils/hash_utils.py @@ -1,6 +1,5 @@ # Copyright Modal Labs 2022 import base64 -import dataclasses import hashlib from typing import BinaryIO, Callable, Union @@ -41,19 +40,3 @@ def get_md5_base64(data: Union[bytes, BinaryIO]) -> str: hasher = hashlib.md5() _update([hasher.update], data) return base64.b64encode(hasher.digest()).decode("utf-8") - - -@dataclasses.dataclass -class UploadHashes: - md5_base64: str - sha256_base64: str - - -def get_upload_hashes(data: Union[bytes, BinaryIO]) -> UploadHashes: - md5 = hashlib.md5() - sha256 = hashlib.sha256() - _update([md5.update, sha256.update], data) - return UploadHashes( - md5_base64=base64.b64encode(md5.digest()).decode("ascii"), - sha256_base64=base64.b64encode(sha256.digest()).decode("ascii"), - ) diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 97ee5ff5c..f7515c532 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -547,6 +547,19 @@ message BaseImage { reserved 4; } +// Request payload for `Blobs.CommitBlobUpload` +message BlobCommitUploadRequest { + // The session token to check for completion and commit. + bytes session_token = 1; +} + +// Response payload for `Blobs.CommitBlobUpload` +message BlobCommitUploadResponse { + // The hash of the final committed blob. This blob will now be visible across + // other blobnet-related APIs. + string blob_hash = 1; +} + message BlobCreateRequest { // TODO(erikbern): how are these garbage collected? // Shouldn't they belong to an app? @@ -563,6 +576,54 @@ message BlobCreateResponse { } } +// Request payload for `Blobs.CreateBlobUpload` +message BlobCreateUploadRequest { + // The blob hash that we expect the resulting blob to have, as a hex-encoded + // SHA-256 digest. This value will not be trusted when creating the resulting + // blob; the service will hash and verify the actually uploaded chunks + // instead. However, the value can be used to skip uploads of files that + // already exist, and verify data integrity once the upload has completed. + string blob_hash = 1; + + // The size of the blob to be uploaded. This value informs how many parts + // will be required for uploading the blob, so this size value must be + // accurate. + uint64 blob_size = 2; +} + +// Response payload for `Blobs.CreateBlobUpload` +message BlobCreateUploadResponse { + // Use this token for subsequent calls + bytes session_token = 1; + + oneof upload_status { + // The blob with the given `blob_hash` already exists; no upload is + // necessary. + bool already_exists = 2; + // The blob does not exist and is small enough to be uploaded as a single + // part. Note that `Blobs.CommitBlobUpload` must be called after uploading the + // part. + SinglePartUpload single_part_upload = 3; + // The blob does not exist and needs to be uploaded in multiple parts. 
Note + // that `Blobs.CommitBlobUpload` must be called after uploading the parts. + MultiPartUpload multi_part_upload = 4; + } + + // Details about a multi-part upload process involving calls to + // `Blobs.StageBlobPart`. + message MultiPartUpload { + // Upload parts split by this part size. All except the last part need to + // have exactly this size. + uint64 part_size = 1; + } + + // Details about a single-part upload process. + message SinglePartUpload { + // The part should be uploaded using the following request details. + BlobUploadRequestDetails upload_request = 1; + } +} + message BlobGetRequest { string blob_id = 1; } @@ -571,6 +632,39 @@ message BlobGetResponse { string download_url = 1; } +// Request payload for `Blobs.StageBlobPart` +message BlobStagePartRequest { + // Session token previously received from `Blobs.CreateBlobUpload`. + bytes session_token = 1; + // The part that we are uploading in this request. + // + // The first part has index 0. + uint64 part = 2; +} + +// Response payload for `Blobs.StageBlobPart` +message BlobStagePartResponse { + // The part should be uploaded using the following request details. The size + // must not exceed the `MultiPartUpload.part_size` that was previously + // received. Also, all but the last part must have exactly that size. + BlobUploadRequestDetails upload_request = 1; +} + +// Generic HTTP-style details for a request +message BlobUploadRequestDetails { + // The request should be made to the following URI + string uri = 1; + // The request should use the following HTTP method + string method = 2; + // The request should have the following headers + repeated Header headers = 3; + + message Header { + string name = 1; + string value = 2; + } +} + message BuildFunction { string definition = 1; bytes globals = 2; @@ -2655,8 +2749,17 @@ service ModalClient { rpc AppStop(AppStopRequest) returns (google.protobuf.Empty); // Blobs + + // Commit a staged blob upload after all parts have been uploaded using + // `StageBlobPart`. + rpc BlobCommitUpload(BlobCommitUploadRequest) returns (BlobCommitUploadResponse); rpc BlobCreate(BlobCreateRequest) returns (BlobCreateResponse); + // Initiate a new blob upload. + rpc BlobCreateUpload(BlobCreateUploadRequest) returns (BlobCreateUploadResponse); rpc BlobGet(BlobGetRequest) returns (BlobGetResponse); + // Stage a part of a blob for a `session_token` that was previously returned + // from `CreateBlobUpload`. 
+  rpc BlobStagePart(BlobStagePartRequest) returns (BlobStagePartResponse);
 
   // Classes
   rpc ClassCreate(ClassCreateRequest) returns (ClassCreateResponse);
diff --git a/test/conftest.py b/test/conftest.py
index e468f244f..68ff9c8ea 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -194,6 +194,7 @@ def __init__(self, blob_host, blobs, credentials):
         self.fail_blob_create = []
         self.blob_create_metadata = None
         self.blob_multipart_threshold = 10_000_000
+        self.blobnet_chunk_size = 8 * 1024 * 1024
 
         self.precreated_functions = set()
@@ -630,6 +631,70 @@ async def BlobGet(self, stream):
         download_url = f"{self.blob_host}/download?blob_id={request.blob_id}"
         await stream.send_message(api_pb2.BlobGetResponse(download_url=download_url))
 
+    async def BlobCreateUpload(self, stream):
+        request: api_pb2.BlobCreateUploadRequest = await stream.recv_message()
+        blob_hash = request.blob_hash
+        blob_size = request.blob_size
+
+        session_token = f"{blob_hash}:{blob_size}".encode("utf-8")
+
+        if blob_size <= self.blobnet_chunk_size:
+            response = api_pb2.BlobCreateUploadResponse(
+                session_token=session_token,
+                single_part_upload=api_pb2.BlobCreateUploadResponse.SinglePartUpload(
+                    upload_request=api_pb2.BlobUploadRequestDetails(
+                        uri=f"{self.blob_host}/upload?blob_id={blob_hash}&part_number=0",
+                        method="PUT",
+                        headers=[api_pb2.BlobUploadRequestDetails.Header(name="Content-Length", value=str(blob_size))],
+                    )
+                )
+            )
+        else:
+            response = api_pb2.BlobCreateUploadResponse(
+                session_token=session_token,
+                multi_part_upload=api_pb2.BlobCreateUploadResponse.MultiPartUpload(
+                    part_size=self.blobnet_chunk_size,
+                )
+            )
+
+        await stream.send_message(response)
+
+    async def BlobStagePart(self, stream):
+        request: api_pb2.BlobStagePartRequest = await stream.recv_message()
+        blob_hash, blob_size_str = request.session_token.decode("utf-8").split(":")
+        blob_size = int(blob_size_str)
+        part = request.part
+
+        def ceildiv(a, b):
+            return -(a // -b)
+
+        num_parts = ceildiv(blob_size, self.blobnet_chunk_size)
+        assert request.part < num_parts
+        part_size = min(blob_size - part * self.blobnet_chunk_size, self.blobnet_chunk_size)
+        upload_request = api_pb2.BlobUploadRequestDetails(
+            uri=f"{self.blob_host}/upload?blob_id={blob_hash}&part_number={part}",
+            method="PUT",
+            headers=[api_pb2.BlobUploadRequestDetails.Header(name="Content-Length", value=str(part_size))],
+        )
+
+        response = api_pb2.BlobStagePartResponse(upload_request=upload_request)
+        await stream.send_message(response)
+
+    async def BlobCommitUpload(self, stream):
+        request: api_pb2.BlobCommitUploadRequest = await stream.recv_message()
+        blob_hash, blob_size_str = request.session_token.decode("utf-8").split(":")
+        blob_size = int(blob_size_str)
+
+        url = f"{self.blob_host}/commit?blob_id={blob_hash}&expected_size={blob_size}"
+        async with aiohttp.request("POST", url) as r:
+            r.raise_for_status()
+            actual_hash = await r.text()
+
+        assert blob_hash == actual_hash
+
+        response = api_pb2.BlobCommitUploadResponse(blob_hash=blob_hash)
+        await stream.send_message(response)
+
     ### Class
 
     async def ClassCreate(self, stream):
@@ -1734,32 +1799,33 @@ def blob_server():
     async def upload(request):
         blob_id = request.query["blob_id"]
+        part_number = int(request.query["part_number"])
         content = await request.content.read()
         if content == b"FAILURE":
             return aiohttp.web.Response(status=500)
-        content_md5 = hashlib.md5(content).hexdigest()
-        etag = f'"{content_md5}"'
-        if "part_number" in request.query:
-            part_number = int(request.query["part_number"])
-            blob_parts[blob_id][part_number] = content
-        else:
-            blobs[blob_id] = content
-        return aiohttp.web.Response(text="Hello, world", headers={"ETag": etag})
+        blob_parts[blob_id][part_number] = content
+        return aiohttp.web.Response(status=200)
 
-    async def complete_multipart(request):
+    async def commit(request):
         blob_id = request.query["blob_id"]
-        blob_nums = range(min(blob_parts[blob_id].keys()), max(blob_parts[blob_id].keys()) + 1)
+        expected_size = int(request.query["expected_size"])
+        part_keys = blob_parts[blob_id].keys()
+
+        if len(part_keys) == 0:
+            return aiohttp.web.Response(status=412, text="No parts uploaded")
+
+        blob_nums = range(max(part_keys) + 1)
         content = b""
-        part_hashes = b""
         for num in blob_nums:
             part_content = blob_parts[blob_id][num]
             content += part_content
-            part_hashes += hashlib.md5(part_content).digest()
 
-        content_md5 = hashlib.md5(part_hashes).hexdigest()
-        etag = f'"{content_md5}-{len(blob_parts[blob_id])}"'
+        if len(content) != expected_size:
+            return aiohttp.web.Response(status=412, text="Uploaded content size doesn't match expected blob_size")
+
+        content_sha256 = hashlib.sha256(content).hexdigest()
         blobs[blob_id] = content
-        return aiohttp.web.Response(text=f"{etag}")
+        return aiohttp.web.Response(text=content_sha256)
 
     async def download(request):
         blob_id = request.query["blob_id"]
@@ -1770,7 +1836,7 @@ async def download(request):
 
     app = aiohttp.web.Application()
     app.add_routes([aiohttp.web.put("/upload", upload)])
    app.add_routes([aiohttp.web.get("/download", download)])
-    app.add_routes([aiohttp.web.post("/complete_multipart", complete_multipart)])
+    app.add_routes([aiohttp.web.post("/commit", commit)])
 
     started = threading.Event()
     stop_server = threading.Event()
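
For reviewers, here is a rough end-to-end sketch of the upload protocol this diff introduces (create an upload session, PUT each part using the server-provided request details, then commit). It mirrors `_blob_upload` above but is illustrative only: `upload_blob_sketch` is a hypothetical helper, it opens a throwaway aiohttp session instead of going through `ClientSessionRegistry`, and it omits the retries, file/stream hashing, and progress reporting that the real client keeps.

# Illustrative sketch, not library code. Assumes `stub` is an already-connected
# ModalClientModal exposing the new RPCs from this diff.
import hashlib

import aiohttp

from modal_proto import api_pb2


async def upload_blob_sketch(stub, data: bytes) -> str:
    blob_hash = hashlib.sha256(data).hexdigest()
    create_resp = await stub.BlobCreateUpload(
        api_pb2.BlobCreateUploadRequest(blob_hash=blob_hash, blob_size=len(data))
    )
    token = create_resp.session_token

    async def put(details: api_pb2.BlobUploadRequestDetails, body: bytes) -> None:
        # Replay the server-provided method, URI, and headers verbatim.
        headers = {h.name: h.value for h in details.headers}
        async with aiohttp.ClientSession() as session:
            async with session.request(details.method, details.uri, data=body, headers=headers) as resp:
                resp.raise_for_status()

    which = create_resp.WhichOneof("upload_status")
    if which == "single_part_upload":
        await put(create_resp.single_part_upload.upload_request, data)
    elif which == "multi_part_upload":
        part_size = create_resp.multi_part_upload.part_size
        for part, offset in enumerate(range(0, len(data), part_size)):
            # Stage each part to obtain its upload request details, then PUT it.
            stage = await stub.BlobStagePart(
                api_pb2.BlobStagePartRequest(session_token=token, part=part)
            )
            await put(stage.upload_request, data[offset:offset + part_size])
    # which == "already_exists": nothing to upload, commit is still safe.

    commit = await stub.BlobCommitUpload(api_pb2.BlobCommitUploadRequest(session_token=token))
    return commit.blob_hash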