Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom protoc plugin to generate a grpclib wrapper #2181

Merged
merged 17 commits into from
Sep 4, 2024
10 changes: 5 additions & 5 deletions modal/_utils/blob_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from aiohttp.abc import AbstractStreamWriter

from modal_proto import api_pb2
from modal_proto.api_grpc import ModalClientStub
from modal_proto.modal_api_grpc import ModalClientModal

from ..exception import ExecutionError
from .async_utils import TaskContext, retry
Expand Down Expand Up @@ -287,7 +287,7 @@ async def _blob_upload(
return blob_id


async def blob_upload(payload: bytes, stub: ModalClientStub) -> str:
async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
if isinstance(payload, str):
logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
payload = payload.encode("utf8")
Expand All @@ -296,7 +296,7 @@ async def blob_upload(payload: bytes, stub: ModalClientStub) -> str:


async def blob_upload_file(
file_obj: BinaryIO, stub: ModalClientStub, progress_report_cb: Optional[Callable] = None
file_obj: BinaryIO, stub: ModalClientModal, progress_report_cb: Optional[Callable] = None
) -> str:
upload_hashes = get_upload_hashes(file_obj)
return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)
Expand All @@ -316,15 +316,15 @@ async def _download_from_url(download_url: str) -> bytes:
return await s3_resp.read()


async def blob_download(blob_id: str, stub: ModalClientStub) -> bytes:
async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
# convenience function reading all of the downloaded file into memory
req = api_pb2.BlobGetRequest(blob_id=blob_id)
resp = await retry_transient_errors(stub.BlobGet, req)

return await _download_from_url(resp.download_url)


async def blob_iter(blob_id: str, stub: ModalClientStub) -> AsyncIterator[bytes]:
async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]:
req = api_pb2.BlobGetRequest(blob_id=blob_id)
resp = await retry_transient_errors(stub.BlobGet, req)
download_url = resp.download_url
Expand Down
56 changes: 52 additions & 4 deletions modal/_utils/grpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
from typing import (
Any,
AsyncIterator,
Collection,
Dict,
Generic,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

import grpclib.client
import grpclib.config
import grpclib.events
import grpclib.protocol
import grpclib.stream
from google.protobuf.message import Message
from grpclib import GRPCError, Status
from grpclib.exceptions import StreamTerminatedError
Expand Down Expand Up @@ -111,11 +118,52 @@ async def send_request(event: grpclib.events.SendRequest) -> None:
return channel


_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]


class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]):
self.wrapped_method = wrapped_method

@property
def name(self) -> str:
return self.wrapped_method.name

async def __call__(
self,
req: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> ResponseType:
# TODO: implement Client tracking and retries
return await self.wrapped_method(req, timeout=timeout, metadata=metadata)


class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]):
self.wrapped_method = wrapped_method

def open(
self,
*,
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> grpclib.client.Stream[RequestType, ResponseType]:
# TODO: implement Client tracking and unary_stream-wrapper
return self.wrapped_method.open(timeout=timeout, metadata=metadata)


async def unary_stream(
method: grpclib.client.UnaryStreamMethod[_SendType, _RecvType],
request: _SendType,
method: UnaryStreamWrapper[RequestType, ResponseType],
request: RequestType,
metadata: Optional[Any] = None,
) -> AsyncIterator[_RecvType]:
) -> AsyncIterator[ResponseType]:
"""Helper for making a unary-streaming gRPC request."""
async with method.open(metadata=metadata) as stream:
await stream.send_message(request, end=True)
Expand All @@ -124,7 +172,7 @@ async def unary_stream(


async def retry_transient_errors(
fn: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType],
fn: UnaryUnaryWrapper[RequestType, ResponseType],
*args,
base_delay: float = 0.1,
max_delay: float = 1,
Expand Down
9 changes: 5 additions & 4 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from grpclib import GRPCError, Status
from synchronicity.async_wrap import asynccontextmanager

from modal_proto import api_grpc, api_pb2
from modal_proto import api_grpc, api_pb2, modal_api_grpc
from modal_version import __version__

from ._utils import async_utils
Expand Down Expand Up @@ -99,11 +99,11 @@ def __init__(
self.image_builder_version: Optional[str] = None
self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None
self._channel: Optional[grpclib.client.Channel] = None
self._stub: Optional[api_grpc.ModalClientStub] = None
self._stub: Optional[modal_api_grpc.ModalClientModal] = None
self._snapshotted = False

@property
def stub(self) -> api_grpc.ModalClientStub:
def stub(self) -> modal_api_grpc.ModalClientModal:
"""mdmd:hidden"""
assert self._stub
return self._stub
Expand All @@ -125,7 +125,8 @@ async def _open(self):
assert self._stub is None
metadata = _get_metadata(self.client_type, self._credentials, self.version)
self._channel = create_channel(self.server_url, metadata=metadata)
self._stub = api_grpc.ModalClientStub(self._channel) # type: ignore
grpclib_stub = api_grpc.ModalClientStub(self._channel)
self._stub = modal_api_grpc.ModalClientModal(grpclib_stub)

async def _close(self, prep_for_restore: bool = False):
if self._pre_stop is not None:
Expand Down
7 changes: 4 additions & 3 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from synchronicity.combined_types import MethodWithAio

from modal._output import FunctionCreationStatus
from modal_proto import api_grpc, api_pb2
from modal_proto import api_pb2
from modal_proto.modal_api_grpc import ModalClientModal

from ._location import parse_cloud_provider
from ._output import OutputManager
Expand Down Expand Up @@ -99,9 +100,9 @@
class _Invocation:
"""Internal client representation of a single-input call to a Modal Function or Generator"""

stub: api_grpc.ModalClientStub
stub: ModalClientModal

def __init__(self, stub: api_grpc.ModalClientStub, function_call_id: str, client: _Client):
def __init__(self, stub: ModalClientModal, function_call_id: str, client: _Client):
self.stub = stub
self.client = client # Used by the deserializer.
self.function_call_id = function_call_id # TODO: remove and use only input_id
Expand Down
1 change: 0 additions & 1 deletion modal_proto/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
# Copyright Modal Labs 2022
203 changes: 203 additions & 0 deletions protoc_plugin/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
#!/usr/bin/env python
# built by modifying grpclib.plugin.main, see https://github.com/vmagamedov/grpclib
# original: Copyright (c) 2019 , Vladimir Magamedov
import os
import sys
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Collection, Deque, Dict, Iterator, List, NamedTuple, Optional, Tuple

from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse
from google.protobuf.descriptor_pb2 import DescriptorProto, FileDescriptorProto
from grpclib import const

_CARDINALITY = {
(False, False): const.Cardinality.UNARY_UNARY,
(True, False): const.Cardinality.STREAM_UNARY,
(False, True): const.Cardinality.UNARY_STREAM,
(True, True): const.Cardinality.STREAM_STREAM,
}


class Method(NamedTuple):
name: str
cardinality: const.Cardinality
request_type: str
reply_type: str


class Service(NamedTuple):
name: str
methods: List[Method]


class Buffer:
def __init__(self) -> None:
self._lines: List[str] = []
self._indent = 0

def add(self, string: str, *args: Any, **kwargs: Any) -> None:
line = " " * self._indent * 4 + string.format(*args, **kwargs)
self._lines.append(line.rstrip(" "))

@contextmanager
def indent(self) -> Iterator[None]:
self._indent += 1
try:
yield
finally:
self._indent -= 1

def content(self) -> str:
return "\n".join(self._lines) + "\n"


def render(
proto_file: str,
imports: Collection[str],
services: Collection[Service],
grpclib_module: str,
) -> str:
buf = Buffer()
buf.add("# Generated by the Modal Protocol Buffers compiler. DO NOT EDIT!")
buf.add("# source: {}", proto_file)
buf.add("# plugin: {}", __name__)
if not services:
return buf.content()

buf.add("")
for mod in imports:
buf.add("import {}", mod)
for service in services:
buf.add("")
buf.add("")
grpclib_stub_name = f"{service.name}Stub"
buf.add("class {}Modal:", service.name)
with buf.indent():
buf.add("")
buf.add("def __init__(self, grpclib_stub: {}.{}) -> None:".format(grpclib_module, grpclib_stub_name))
with buf.indent():
if len(service.methods) == 0:
buf.add("pass")
for method in service.methods:
name, cardinality, request_type, reply_type = method
wrapper_cls: type
if cardinality is const.Cardinality.UNARY_UNARY:
wrapper_cls = "modal._utils.grpc_utils.UnaryUnaryWrapper"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need modal._utils.grpc_utils imported under a TYPE_CHECKING guard to use a forward reference? Never been totally clear on that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case it's not a forward reference - it's a "real" reference in the generated code and it's added as an import on line 175: https://github.com/modal-labs/modal-client/pull/2181/files#diff-d721170dbb2b36a3f20394f9563415ebe89a0916e4957cf99e7b615cfd8c772fR175

In case of forward/str references I think the TYPE_CHECKING-guarded imports are still required for type checkers to know what it's dealing with

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops didn't read carefully enough!

elif cardinality is const.Cardinality.UNARY_STREAM:
wrapper_cls = "modal._utils.grpc_utils.UnaryStreamWrapper"
# elif cardinality is const.Cardinality.STREAM_UNARY:
# wrapper_cls = StreamUnaryWrapper
# elif cardinality is const.Cardinality.STREAM_STREAM:
# wrapper_cls = StreamStreamWrapper
else:
raise TypeError(cardinality)

original_method = f"grpclib_stub.{name}"
buf.add(f"self.{name} = {wrapper_cls}({original_method})")

return buf.content()


def _get_proto(request: CodeGeneratorRequest, name: str) -> FileDescriptorProto:
return next(f for f in request.proto_file if f.name == name)


def _strip_proto(proto_file_path: str) -> str:
for suffix in [".protodevel", ".proto"]:
if proto_file_path.endswith(suffix):
return proto_file_path[: -len(suffix)]

return proto_file_path


def _base_module_name(proto_file_path: str) -> str:
basename = _strip_proto(proto_file_path)
return basename.replace("-", "_").replace("/", ".")


def _proto2pb2_module_name(proto_file_path: str) -> str:
return _base_module_name(proto_file_path) + "_pb2"


def _proto2grpc_module_name(proto_file_path: str) -> str:
return _base_module_name(proto_file_path) + "_grpc"


def _type_names(
proto_file: FileDescriptorProto,
message_type: DescriptorProto,
parents: Optional[Deque[str]] = None,
) -> Iterator[Tuple[str, str]]:
if parents is None:
parents = deque()

proto_name_parts = [""]
if proto_file.package:
proto_name_parts.append(proto_file.package)
proto_name_parts.extend(parents)
proto_name_parts.append(message_type.name)

py_name_parts = [_proto2pb2_module_name(proto_file.name)]
py_name_parts.extend(parents)
py_name_parts.append(message_type.name)

yield ".".join(proto_name_parts), ".".join(py_name_parts)

parents.append(message_type.name)
for nested in message_type.nested_type:
yield from _type_names(proto_file, nested, parents=parents)
parents.pop()


def main() -> None:
with os.fdopen(sys.stdin.fileno(), "rb") as inp:
request = CodeGeneratorRequest.FromString(inp.read())

types_map: Dict[str, str] = {}
for pf in request.proto_file:
for mt in pf.message_type:
types_map.update(_type_names(pf, mt))

response = CodeGeneratorResponse()

# See https://github.com/protocolbuffers/protobuf/blob/v3.12.0/docs/implementing_proto3_presence.md # noqa
if hasattr(CodeGeneratorResponse, "Feature"):
response.supported_features = CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL

for file_to_generate in request.file_to_generate:
proto_file = _get_proto(request, file_to_generate)
module_name = _proto2grpc_module_name(file_to_generate)
grpclib_module_path = Path(module_name.replace(".", "/") + ".py")

imports = ["modal._utils.grpc_utils", module_name]

services = []
for service in proto_file.service:
methods = []
for method in service.method:
cardinality = _CARDINALITY[(method.client_streaming, method.server_streaming)]
methods.append(
Method(
name=method.name,
cardinality=cardinality,
request_type=types_map[method.input_type],
reply_type=types_map[method.output_type],
)
)
services.append(Service(name=service.name, methods=methods))

file = response.file.add()

file.name = str(grpclib_module_path.with_name("modal_" + grpclib_module_path.name))
file.content = render(
proto_file=proto_file.name, imports=imports, services=services, grpclib_module=module_name
)

with os.fdopen(sys.stdout.fileno(), "wb") as out:
out.write(response.SerializeToString())


if __name__ == "__main__":
main()
Loading
Loading