Skip to content

Commit

Permalink
Custom protoc plugin to generate a grpclib wrapper (#2181)
Browse files Browse the repository at this point in the history
* Adds custom proto generator plugin

Generates an additional source file for wrapping the grpclib-generated api stub
in a way that can facilitate generic Modal functionality for all calls
  • Loading branch information
freider authored Sep 4, 2024
1 parent 5d4c1ed commit 10282b3
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 25 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ jobs:
run: |
python -m venv venv
source venv/bin/activate
pip install grpcio-tools==1.59.2
pip install grpcio-tools==1.59.2 grpclib==0.4.7
python -m grpc_tools.protoc --python_out=. --grpclib_python_out=. --grpc_python_out=. -I . modal_proto/api.proto modal_proto/options.proto
python -m grpc_tools.protoc --plugin=protoc-gen-modal-grpclib-python=protoc_plugin/plugin.py --modal-grpclib-python_out=. -I . modal_proto/api.proto modal_proto/options.proto
deactivate
- name: Check entrypoint import
Expand Down
10 changes: 5 additions & 5 deletions modal/_utils/blob_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from aiohttp.abc import AbstractStreamWriter

from modal_proto import api_pb2
from modal_proto.api_grpc import ModalClientStub
from modal_proto.modal_api_grpc import ModalClientModal

from ..exception import ExecutionError
from .async_utils import TaskContext, retry
Expand Down Expand Up @@ -287,7 +287,7 @@ async def _blob_upload(
return blob_id


async def blob_upload(payload: bytes, stub: ModalClientStub) -> str:
async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
if isinstance(payload, str):
logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
payload = payload.encode("utf8")
Expand All @@ -296,7 +296,7 @@ async def blob_upload(payload: bytes, stub: ModalClientStub) -> str:


async def blob_upload_file(
file_obj: BinaryIO, stub: ModalClientStub, progress_report_cb: Optional[Callable] = None
file_obj: BinaryIO, stub: ModalClientModal, progress_report_cb: Optional[Callable] = None
) -> str:
upload_hashes = get_upload_hashes(file_obj)
return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)
Expand All @@ -316,15 +316,15 @@ async def _download_from_url(download_url: str) -> bytes:
return await s3_resp.read()


async def blob_download(blob_id: str, stub: ModalClientStub) -> bytes:
async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
# convenience function reading all of the downloaded file into memory
req = api_pb2.BlobGetRequest(blob_id=blob_id)
resp = await retry_transient_errors(stub.BlobGet, req)

return await _download_from_url(resp.download_url)


async def blob_iter(blob_id: str, stub: ModalClientStub) -> AsyncIterator[bytes]:
async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]:
req = api_pb2.BlobGetRequest(blob_id=blob_id)
resp = await retry_transient_errors(stub.BlobGet, req)
download_url = resp.download_url
Expand Down
56 changes: 52 additions & 4 deletions modal/_utils/grpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
from typing import (
Any,
AsyncIterator,
Collection,
Dict,
Generic,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

import grpclib.client
import grpclib.config
import grpclib.events
import grpclib.protocol
import grpclib.stream
from google.protobuf.message import Message
from grpclib import GRPCError, Status
from grpclib.exceptions import StreamTerminatedError
Expand Down Expand Up @@ -111,11 +118,52 @@ async def send_request(event: grpclib.events.SendRequest) -> None:
return channel


_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]


class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]):
self.wrapped_method = wrapped_method

@property
def name(self) -> str:
return self.wrapped_method.name

async def __call__(
self,
req: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> ResponseType:
# TODO: implement Client tracking and retries
return await self.wrapped_method(req, timeout=timeout, metadata=metadata)


class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]):
self.wrapped_method = wrapped_method

def open(
self,
*,
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> grpclib.client.Stream[RequestType, ResponseType]:
# TODO: implement Client tracking and unary_stream-wrapper
return self.wrapped_method.open(timeout=timeout, metadata=metadata)


async def unary_stream(
method: grpclib.client.UnaryStreamMethod[_SendType, _RecvType],
request: _SendType,
method: UnaryStreamWrapper[RequestType, ResponseType],
request: RequestType,
metadata: Optional[Any] = None,
) -> AsyncIterator[_RecvType]:
) -> AsyncIterator[ResponseType]:
"""Helper for making a unary-streaming gRPC request."""
async with method.open(metadata=metadata) as stream:
await stream.send_message(request, end=True)
Expand All @@ -124,7 +172,7 @@ async def unary_stream(


async def retry_transient_errors(
fn: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType],
fn: UnaryUnaryWrapper[RequestType, ResponseType],
*args,
base_delay: float = 0.1,
max_delay: float = 1,
Expand Down
9 changes: 5 additions & 4 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from grpclib import GRPCError, Status
from synchronicity.async_wrap import asynccontextmanager

from modal_proto import api_grpc, api_pb2
from modal_proto import api_grpc, api_pb2, modal_api_grpc
from modal_version import __version__

from ._utils import async_utils
Expand Down Expand Up @@ -99,11 +99,11 @@ def __init__(
self.image_builder_version: Optional[str] = None
self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None
self._channel: Optional[grpclib.client.Channel] = None
self._stub: Optional[api_grpc.ModalClientStub] = None
self._stub: Optional[modal_api_grpc.ModalClientModal] = None
self._snapshotted = False

@property
def stub(self) -> api_grpc.ModalClientStub:
def stub(self) -> modal_api_grpc.ModalClientModal:
"""mdmd:hidden"""
assert self._stub
return self._stub
Expand All @@ -125,7 +125,8 @@ async def _open(self):
assert self._stub is None
metadata = _get_metadata(self.client_type, self._credentials, self.version)
self._channel = create_channel(self.server_url, metadata=metadata)
self._stub = api_grpc.ModalClientStub(self._channel) # type: ignore
grpclib_stub = api_grpc.ModalClientStub(self._channel)
self._stub = modal_api_grpc.ModalClientModal(grpclib_stub)

async def _close(self, prep_for_restore: bool = False):
if self._pre_stop is not None:
Expand Down
7 changes: 4 additions & 3 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from synchronicity.combined_types import MethodWithAio

from modal._output import FunctionCreationStatus
from modal_proto import api_grpc, api_pb2
from modal_proto import api_pb2
from modal_proto.modal_api_grpc import ModalClientModal

from ._location import parse_cloud_provider
from ._output import OutputManager
Expand Down Expand Up @@ -99,9 +100,9 @@
class _Invocation:
"""Internal client representation of a single-input call to a Modal Function or Generator"""

stub: api_grpc.ModalClientStub
stub: ModalClientModal

def __init__(self, stub: api_grpc.ModalClientStub, function_call_id: str, client: _Client):
def __init__(self, stub: ModalClientModal, function_call_id: str, client: _Client):
self.stub = stub
self.client = client # Used by the deserializer.
self.function_call_id = function_call_id # TODO: remove and use only input_id
Expand Down
2 changes: 1 addition & 1 deletion modal_proto/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# Copyright Modal Labs 2022
# Copyright Modal Labs 2024
Loading

0 comments on commit 10282b3

Please sign in to comment.