Skip to content

Commit

Permalink
rest: Pass in the right stub to the upload APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
dflemstr committed Dec 4, 2024
1 parent e80922d commit 6c6b211
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions modal/_runtime/container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def serialize_data_format(self, obj: Any, data_format: int) -> bytes:

async def format_blob_data(self, data: bytes) -> dict[str, Any]:
return (
{"data_blob_id": await blob_upload(data, self._client.stub)}
{"data_blob_id": await blob_upload(data, self._client.blobs_stub)}
if len(data) > MAX_OBJECT_SIZE_BYTES
else {"data": data}
)
Expand All @@ -523,7 +523,7 @@ async def put_data_out(
for i, message_bytes in enumerate(messages_bytes):
chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i) # type: ignore
if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub)
chunk.data_blob_id = await blob_upload(message_bytes, self._client.blobs_stub)
else:
chunk.data = message_bytes
data_chunks.append(chunk)
Expand Down
2 changes: 1 addition & 1 deletion modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ async def _create_input(
args_serialized = serialize((args, kwargs))

if len(args_serialized) > MAX_OBJECT_SIZE_BYTES:
args_blob_id = await blob_upload(args_serialized, client.stub)
args_blob_id = await blob_upload(args_serialized, client.blobs_stub)

return api_pb2.FunctionPutInputsItem(
input=api_pb2.FunctionInput(
Expand Down
2 changes: 1 addition & 1 deletion modal/mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile:
logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)")
async with blob_upload_concurrency:
with file_spec.source() as fp:
blob_id = await blob_upload_file(fp, resolver.client.stub)
blob_id = await blob_upload_file(fp, resolver.client.blobs_stub)
logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}")
request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex)
else:
Expand Down
2 changes: 1 addition & 1 deletion modal/network_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ async def write_file(self, remote_path: str, fp: BinaryIO, progress_cb: Optional
if data_size > LARGE_FILE_LIMIT:
progress_task_id = progress_cb(name=remote_path, size=data_size)
blob_id = await blob_upload_file(
fp, self._client.stub, progress_report_cb=functools.partial(progress_cb, progress_task_id)
fp, self._client.blobs_stub, progress_report_cb=functools.partial(progress_cb, progress_task_id)
)
req = api_pb2.SharedVolumePutFileRequest(
shared_volume_id=self.object_id,
Expand Down
2 changes: 1 addition & 1 deletion modal/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)")
with file_spec.source() as fp:
blob_id = await blob_upload_file(
fp, self._client.stub, functools.partial(self._progress_cb, progress_task_id)
fp, self._client.blobs_stub, functools.partial(self._progress_cb, progress_task_id)
)
logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}")
request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex)
Expand Down
12 changes: 6 additions & 6 deletions test/blob_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@pytest.mark.asyncio
async def test_blob_put_get(servicer, blob_server, client):
# Upload
blob_id = await blob_upload.aio(b"Hello, world", client.stub)
blob_id = await blob_upload.aio(b"Hello, world", client.blobs_stub)

# Download
data = await blob_download.aio(blob_id, client.stub)
Expand All @@ -31,7 +31,7 @@ async def test_blob_put_get(servicer, blob_server, client):
@pytest.mark.asyncio
async def test_blob_put_failure(servicer, blob_server, client):
with pytest.raises(ExecutionError):
await blob_upload.aio(b"FAILURE", client.stub)
await blob_upload.aio(b"FAILURE", client.blobs_stub)


@pytest.mark.asyncio
Expand All @@ -43,7 +43,7 @@ async def test_blob_get_failure(servicer, blob_server, client):
@pytest.mark.asyncio
async def test_blob_large(servicer, blob_server, client):
data = b"*" * 10_000_000
blob_id = await blob_upload.aio(data, client.stub)
blob_id = await blob_upload.aio(data, client.blobs_stub)
assert await blob_download.aio(blob_id, client.stub) == data


Expand All @@ -57,17 +57,17 @@ async def test_blob_multipart(servicer, blob_server, client, monkeypatch, tmp_pa
# - make last part significantly shorter than rest, creating uneven upload time.
data_len = (256 * multipart_threshold) + (multipart_threshold // 2)
data = random.randbytes(data_len) # random data will not hide byte re-ordering corruption
blob_id = await blob_upload.aio(data, client.stub)
blob_id = await blob_upload.aio(data, client.blobs_stub)
assert await blob_download.aio(blob_id, client.stub) == data

data_len = (256 * multipart_threshold) + (multipart_threshold // 2)
data = random.randbytes(data_len) # random data will not hide byte re-ordering corruption
data_filepath = tmp_path / "temp.bin"
data_filepath.write_bytes(data)
blob_id = await blob_upload_file.aio(data_filepath.open("rb"), client.stub)
blob_id = await blob_upload_file.aio(data_filepath.open("rb"), client.blobs_stub)
assert await blob_download.aio(blob_id, client.stub) == data


def test_sync(blob_server, client):
# just tests that tests running blocking calls that upload to blob storage don't deadlock
blob_upload(b"adsfadsf", client.stub)
blob_upload(b"adsfadsf", client.blobs_stub)
2 changes: 1 addition & 1 deletion test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _get_inputs(
client: Optional[Client] = None,
) -> list[api_pb2.FunctionGetInputsResponse]:
if upload_to_blob:
args_blob_id = blob_upload(serialize(args), client.stub)
args_blob_id = blob_upload(serialize(args), client.blobs_stub)
input_pb = api_pb2.FunctionInput(
args_blob_id=args_blob_id, data_format=api_pb2.DATA_FORMAT_PICKLE, method_name=method_name or ""
)
Expand Down

0 comments on commit 6c6b211

Please sign in to comment.