diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py
index 8e305c5fd..f3bc1184b 100644
--- a/modal/_utils/function_utils.py
+++ b/modal/_utils/function_utils.py
@@ -508,7 +508,7 @@ async def _create_input(
     args_serialized = serialize((args, kwargs))
 
     if len(args_serialized) > MAX_OBJECT_SIZE_BYTES:
-        args_blob_id = await blob_upload(args_serialized, client.stub)
+        args_blob_id = await blob_upload(args_serialized, client.blobs_stub)
 
         return api_pb2.FunctionPutInputsItem(
             input=api_pb2.FunctionInput(
diff --git a/modal/mount.py b/modal/mount.py
index f7e602bd8..d23cb9f0b 100644
--- a/modal/mount.py
+++ b/modal/mount.py
@@ -483,7 +483,7 @@ async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile:
         logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)")
         async with blob_upload_concurrency:
             with file_spec.source() as fp:
-                blob_id = await blob_upload_file(fp, resolver.client.stub)
+                blob_id = await blob_upload_file(fp, resolver.client.blobs_stub)
         logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}")
         request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex)
     else:
diff --git a/modal/network_file_system.py b/modal/network_file_system.py
index 7d2f9a8bf..d7020d901 100644
--- a/modal/network_file_system.py
+++ b/modal/network_file_system.py
@@ -239,7 +239,7 @@ async def write_file(self, remote_path: str, fp: BinaryIO, progress_cb: Optional
         if data_size > LARGE_FILE_LIMIT:
             progress_task_id = progress_cb(name=remote_path, size=data_size)
             blob_id = await blob_upload_file(
-                fp, self._client.stub, progress_report_cb=functools.partial(progress_cb, progress_task_id)
+                fp, self._client.blobs_stub, progress_report_cb=functools.partial(progress_cb, progress_task_id)
             )
             req = api_pb2.SharedVolumePutFileRequest(
                 shared_volume_id=self.object_id,
diff --git a/modal/volume.py b/modal/volume.py
index 02cec5805..9ea2e4a09 100644
--- a/modal/volume.py
+++ b/modal/volume.py
@@ -632,7 +632,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
             logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)")
             with file_spec.source() as fp:
                 blob_id = await blob_upload_file(
-                    fp, self._client.stub, functools.partial(self._progress_cb, progress_task_id)
+                    fp, self._client.blobs_stub, functools.partial(self._progress_cb, progress_task_id)
                 )
             logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}")
             request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex)
diff --git a/test/blob_test.py b/test/blob_test.py
index 06687e39d..7e8607030 100644
--- a/test/blob_test.py
+++ b/test/blob_test.py
@@ -21,7 +21,7 @@
 @pytest.mark.asyncio
 async def test_blob_put_get(servicer, blob_server, client):
     # Upload
-    blob_id = await blob_upload.aio(b"Hello, world", client.stub)
+    blob_id = await blob_upload.aio(b"Hello, world", client.blobs_stub)
 
     # Download
     data = await blob_download.aio(blob_id, client.stub)
@@ -31,7 +31,7 @@ async def test_blob_put_get(servicer, blob_server, client):
 @pytest.mark.asyncio
 async def test_blob_put_failure(servicer, blob_server, client):
     with pytest.raises(ExecutionError):
-        await blob_upload.aio(b"FAILURE", client.stub)
+        await blob_upload.aio(b"FAILURE", client.blobs_stub)
 
 
 @pytest.mark.asyncio
@@ -43,7 +43,7 @@ async def test_blob_get_failure(servicer, blob_server, client):
 
 
 @pytest.mark.asyncio
 async def test_blob_large(servicer, blob_server, client):
     data = b"*" * 10_000_000
-    blob_id = await blob_upload.aio(data, client.stub)
+    blob_id = await blob_upload.aio(data, client.blobs_stub)
     assert await blob_download.aio(blob_id, client.stub) == data
@@ -57,17 +57,17 @@ async def test_blob_multipart(servicer, blob_server, client, monkeypatch, tmp_pa
     # - make last part significantly shorter than rest, creating uneven upload time.
     data_len = (256 * multipart_threshold) + (multipart_threshold // 2)
     data = random.randbytes(data_len)  # random data will not hide byte re-ordering corruption
-    blob_id = await blob_upload.aio(data, client.stub)
+    blob_id = await blob_upload.aio(data, client.blobs_stub)
     assert await blob_download.aio(blob_id, client.stub) == data
 
     data_len = (256 * multipart_threshold) + (multipart_threshold // 2)
     data = random.randbytes(data_len)  # random data will not hide byte re-ordering corruption
     data_filepath = tmp_path / "temp.bin"
     data_filepath.write_bytes(data)
-    blob_id = await blob_upload_file.aio(data_filepath.open("rb"), client.stub)
+    blob_id = await blob_upload_file.aio(data_filepath.open("rb"), client.blobs_stub)
     assert await blob_download.aio(blob_id, client.stub) == data
 
 
 def test_sync(blob_server, client):
     # just tests that tests running blocking calls that upload to blob storage don't deadlock
-    blob_upload(b"adsfadsf", client.stub)
+    blob_upload(b"adsfadsf", client.blobs_stub)
diff --git a/test/container_test.py b/test/container_test.py
index fefff9072..4113ede40 100644
--- a/test/container_test.py
+++ b/test/container_test.py
@@ -68,7 +68,7 @@ def _get_inputs(
     client: Optional[Client] = None,
 ) -> list[api_pb2.FunctionGetInputsResponse]:
     if upload_to_blob:
-        args_blob_id = blob_upload(serialize(args), client.stub)
+        args_blob_id = blob_upload(serialize(args), client.blobs_stub)
     input_pb = api_pb2.FunctionInput(
         args_blob_id=args_blob_id, data_format=api_pb2.DATA_FORMAT_PICKLE, method_name=method_name or ""
     )