From 10282b3e6f7f545cd677df09a20af71bd2f7b8bf Mon Sep 17 00:00:00 2001 From: Elias Freider Date: Wed, 4 Sep 2024 15:36:59 +0200 Subject: [PATCH] Custom protoc plugin to generate a grpclib wrapper (#2181) * Adds custom proto generator plugin Generates an additional source file for wrapping the grpclib-generated api stub in a way that can facilitate generic Modal functionality for all calls --- .github/workflows/ci-cd.yml | 3 +- modal/_utils/blob_utils.py | 10 +- modal/_utils/grpc_utils.py | 56 +++++++++- modal/client.py | 9 +- modal/functions.py | 7 +- modal_proto/__init__.py | 2 +- protoc_plugin/plugin.py | 204 ++++++++++++++++++++++++++++++++++++ tasks.py | 40 ++++++- test/container_app_test.py | 4 +- test/grpc_utils_test.py | 3 +- 10 files changed, 313 insertions(+), 25 deletions(-) create mode 100755 protoc_plugin/plugin.py diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 46f401e55..f6f40f81a 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -88,8 +88,9 @@ jobs: run: | python -m venv venv source venv/bin/activate - pip install grpcio-tools==1.59.2 + pip install grpcio-tools==1.59.2 grpclib==0.4.7 python -m grpc_tools.protoc --python_out=. --grpclib_python_out=. --grpc_python_out=. -I . modal_proto/api.proto modal_proto/options.proto + python -m grpc_tools.protoc --plugin=protoc-gen-modal-grpclib-python=protoc_plugin/plugin.py --modal-grpclib-python_out=. -I . modal_proto/api.proto modal_proto/options.proto deactivate - name: Check entrypoint import diff --git a/modal/_utils/blob_utils.py b/modal/_utils/blob_utils.py index 1a2d37025..1148cfb10 100644 --- a/modal/_utils/blob_utils.py +++ b/modal/_utils/blob_utils.py @@ -14,7 +14,7 @@ from aiohttp.abc import AbstractStreamWriter from modal_proto import api_pb2 -from modal_proto.api_grpc import ModalClientStub +from modal_proto.modal_api_grpc import ModalClientModal from ..exception import ExecutionError from .async_utils import TaskContext, retry @@ -287,7 +287,7 @@ async def _blob_upload( return blob_id -async def blob_upload(payload: bytes, stub: ModalClientStub) -> str: +async def blob_upload(payload: bytes, stub: ModalClientModal) -> str: if isinstance(payload, str): logger.warning("Blob uploading string, not bytes - auto-encoding as utf8") payload = payload.encode("utf8") @@ -296,7 +296,7 @@ async def blob_upload(payload: bytes, stub: ModalClientStub) -> str: async def blob_upload_file( - file_obj: BinaryIO, stub: ModalClientStub, progress_report_cb: Optional[Callable] = None + file_obj: BinaryIO, stub: ModalClientModal, progress_report_cb: Optional[Callable] = None ) -> str: upload_hashes = get_upload_hashes(file_obj) return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb) @@ -316,7 +316,7 @@ async def _download_from_url(download_url: str) -> bytes: return await s3_resp.read() -async def blob_download(blob_id: str, stub: ModalClientStub) -> bytes: +async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes: # convenience function reading all of the downloaded file into memory req = api_pb2.BlobGetRequest(blob_id=blob_id) resp = await retry_transient_errors(stub.BlobGet, req) @@ -324,7 +324,7 @@ async def blob_download(blob_id: str, stub: ModalClientStub) -> bytes: return await _download_from_url(resp.download_url) -async def blob_iter(blob_id: str, stub: ModalClientStub) -> AsyncIterator[bytes]: +async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]: req = api_pb2.BlobGetRequest(blob_id=blob_id) resp = await retry_transient_errors(stub.BlobGet, req) download_url = resp.download_url diff --git a/modal/_utils/grpc_utils.py b/modal/_utils/grpc_utils.py index 769b45173..2b2c83b89 100644 --- a/modal/_utils/grpc_utils.py +++ b/modal/_utils/grpc_utils.py @@ -9,14 +9,21 @@ from typing import ( Any, AsyncIterator, + Collection, Dict, + Generic, + Mapping, Optional, + Tuple, TypeVar, + Union, ) import grpclib.client import grpclib.config import grpclib.events +import grpclib.protocol +import grpclib.stream from google.protobuf.message import Message from grpclib import GRPCError, Status from grpclib.exceptions import StreamTerminatedError @@ -111,11 +118,52 @@ async def send_request(event: grpclib.events.SendRequest) -> None: return channel +_Value = Union[str, bytes] +_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] + + +class UnaryUnaryWrapper(Generic[RequestType, ResponseType]): + wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType] + + def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]): + self.wrapped_method = wrapped_method + + @property + def name(self) -> str: + return self.wrapped_method.name + + async def __call__( + self, + req: RequestType, + *, + timeout: Optional[float] = None, + metadata: Optional[_MetadataLike] = None, + ) -> ResponseType: + # TODO: implement Client tracking and retries + return await self.wrapped_method(req, timeout=timeout, metadata=metadata) + + +class UnaryStreamWrapper(Generic[RequestType, ResponseType]): + wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType] + + def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]): + self.wrapped_method = wrapped_method + + def open( + self, + *, + timeout: Optional[float] = None, + metadata: Optional[_MetadataLike] = None, + ) -> grpclib.client.Stream[RequestType, ResponseType]: + # TODO: implement Client tracking and unary_stream-wrapper + return self.wrapped_method.open(timeout=timeout, metadata=metadata) + + async def unary_stream( - method: grpclib.client.UnaryStreamMethod[_SendType, _RecvType], - request: _SendType, + method: UnaryStreamWrapper[RequestType, ResponseType], + request: RequestType, metadata: Optional[Any] = None, -) -> AsyncIterator[_RecvType]: +) -> AsyncIterator[ResponseType]: """Helper for making a unary-streaming gRPC request.""" async with method.open(metadata=metadata) as stream: await stream.send_message(request, end=True) @@ -124,7 +172,7 @@ async def unary_stream( async def retry_transient_errors( - fn: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType], + fn: UnaryUnaryWrapper[RequestType, ResponseType], *args, base_delay: float = 0.1, max_delay: float = 1, diff --git a/modal/client.py b/modal/client.py index abc5a7f30..d5679d516 100644 --- a/modal/client.py +++ b/modal/client.py @@ -10,7 +10,7 @@ from grpclib import GRPCError, Status from synchronicity.async_wrap import asynccontextmanager -from modal_proto import api_grpc, api_pb2 +from modal_proto import api_grpc, api_pb2, modal_api_grpc from modal_version import __version__ from ._utils import async_utils @@ -99,11 +99,11 @@ def __init__( self.image_builder_version: Optional[str] = None self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None self._channel: Optional[grpclib.client.Channel] = None - self._stub: Optional[api_grpc.ModalClientStub] = None + self._stub: Optional[modal_api_grpc.ModalClientModal] = None self._snapshotted = False @property - def stub(self) -> api_grpc.ModalClientStub: + def stub(self) -> modal_api_grpc.ModalClientModal: """mdmd:hidden""" assert self._stub return self._stub @@ -125,7 +125,8 @@ async def _open(self): assert self._stub is None metadata = _get_metadata(self.client_type, self._credentials, self.version) self._channel = create_channel(self.server_url, metadata=metadata) - self._stub = api_grpc.ModalClientStub(self._channel) # type: ignore + grpclib_stub = api_grpc.ModalClientStub(self._channel) + self._stub = modal_api_grpc.ModalClientModal(grpclib_stub) async def _close(self, prep_for_restore: bool = False): if self._pre_stop is not None: diff --git a/modal/functions.py b/modal/functions.py index 082653f0c..1d5fffffa 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -30,7 +30,8 @@ from synchronicity.combined_types import MethodWithAio from modal._output import FunctionCreationStatus -from modal_proto import api_grpc, api_pb2 +from modal_proto import api_pb2 +from modal_proto.modal_api_grpc import ModalClientModal from ._location import parse_cloud_provider from ._output import OutputManager @@ -99,9 +100,9 @@ class _Invocation: """Internal client representation of a single-input call to a Modal Function or Generator""" - stub: api_grpc.ModalClientStub + stub: ModalClientModal - def __init__(self, stub: api_grpc.ModalClientStub, function_call_id: str, client: _Client): + def __init__(self, stub: ModalClientModal, function_call_id: str, client: _Client): self.stub = stub self.client = client # Used by the deserializer. self.function_call_id = function_call_id # TODO: remove and use only input_id diff --git a/modal_proto/__init__.py b/modal_proto/__init__.py index 60dc7f423..12325a125 100644 --- a/modal_proto/__init__.py +++ b/modal_proto/__init__.py @@ -1 +1 @@ -# Copyright Modal Labs 2022 +# Copyright Modal Labs 2024 diff --git a/protoc_plugin/plugin.py b/protoc_plugin/plugin.py new file mode 100755 index 000000000..e5cf7816a --- /dev/null +++ b/protoc_plugin/plugin.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python +# Copyright Modal Labs 2024 +# built by modifying grpclib.plugin.main, see https://github.com/vmagamedov/grpclib +# original: Copyright (c) 2019 , Vladimir Magamedov +import os +import sys +from collections import deque +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Collection, Deque, Dict, Iterator, List, NamedTuple, Optional, Tuple + +from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse +from google.protobuf.descriptor_pb2 import DescriptorProto, FileDescriptorProto +from grpclib import const + +_CARDINALITY = { + (False, False): const.Cardinality.UNARY_UNARY, + (True, False): const.Cardinality.STREAM_UNARY, + (False, True): const.Cardinality.UNARY_STREAM, + (True, True): const.Cardinality.STREAM_STREAM, +} + + +class Method(NamedTuple): + name: str + cardinality: const.Cardinality + request_type: str + reply_type: str + + +class Service(NamedTuple): + name: str + methods: List[Method] + + +class Buffer: + def __init__(self) -> None: + self._lines: List[str] = [] + self._indent = 0 + + def add(self, string: str, *args: Any, **kwargs: Any) -> None: + line = " " * self._indent * 4 + string.format(*args, **kwargs) + self._lines.append(line.rstrip(" ")) + + @contextmanager + def indent(self) -> Iterator[None]: + self._indent += 1 + try: + yield + finally: + self._indent -= 1 + + def content(self) -> str: + return "\n".join(self._lines) + "\n" + + +def render( + proto_file: str, + imports: Collection[str], + services: Collection[Service], + grpclib_module: str, +) -> str: + buf = Buffer() + buf.add("# Generated by the Modal Protocol Buffers compiler. DO NOT EDIT!") + buf.add("# source: {}", proto_file) + buf.add("# plugin: {}", __name__) + if not services: + return buf.content() + + buf.add("") + for mod in imports: + buf.add("import {}", mod) + for service in services: + buf.add("") + buf.add("") + grpclib_stub_name = f"{service.name}Stub" + buf.add("class {}Modal:", service.name) + with buf.indent(): + buf.add("") + buf.add("def __init__(self, grpclib_stub: {}.{}) -> None:".format(grpclib_module, grpclib_stub_name)) + with buf.indent(): + if len(service.methods) == 0: + buf.add("pass") + for method in service.methods: + name, cardinality, request_type, reply_type = method + wrapper_cls: str + if cardinality is const.Cardinality.UNARY_UNARY: + wrapper_cls = "modal._utils.grpc_utils.UnaryUnaryWrapper" + elif cardinality is const.Cardinality.UNARY_STREAM: + wrapper_cls = "modal._utils.grpc_utils.UnaryStreamWrapper" + # elif cardinality is const.Cardinality.STREAM_UNARY: + # wrapper_cls = StreamUnaryWrapper + # elif cardinality is const.Cardinality.STREAM_STREAM: + # wrapper_cls = StreamStreamWrapper + else: + raise TypeError(cardinality) + + original_method = f"grpclib_stub.{name}" + buf.add(f"self.{name} = {wrapper_cls}({original_method})") + + return buf.content() + + +def _get_proto(request: CodeGeneratorRequest, name: str) -> FileDescriptorProto: + return next(f for f in request.proto_file if f.name == name) + + +def _strip_proto(proto_file_path: str) -> str: + for suffix in [".protodevel", ".proto"]: + if proto_file_path.endswith(suffix): + return proto_file_path[: -len(suffix)] + + return proto_file_path + + +def _base_module_name(proto_file_path: str) -> str: + basename = _strip_proto(proto_file_path) + return basename.replace("-", "_").replace("/", ".") + + +def _proto2pb2_module_name(proto_file_path: str) -> str: + return _base_module_name(proto_file_path) + "_pb2" + + +def _proto2grpc_module_name(proto_file_path: str) -> str: + return _base_module_name(proto_file_path) + "_grpc" + + +def _type_names( + proto_file: FileDescriptorProto, + message_type: DescriptorProto, + parents: Optional[Deque[str]] = None, +) -> Iterator[Tuple[str, str]]: + if parents is None: + parents = deque() + + proto_name_parts = [""] + if proto_file.package: + proto_name_parts.append(proto_file.package) + proto_name_parts.extend(parents) + proto_name_parts.append(message_type.name) + + py_name_parts = [_proto2pb2_module_name(proto_file.name)] + py_name_parts.extend(parents) + py_name_parts.append(message_type.name) + + yield ".".join(proto_name_parts), ".".join(py_name_parts) + + parents.append(message_type.name) + for nested in message_type.nested_type: + yield from _type_names(proto_file, nested, parents=parents) + parents.pop() + + +def main() -> None: + with os.fdopen(sys.stdin.fileno(), "rb") as inp: + request = CodeGeneratorRequest.FromString(inp.read()) + + types_map: Dict[str, str] = {} + for pf in request.proto_file: + for mt in pf.message_type: + types_map.update(_type_names(pf, mt)) + + response = CodeGeneratorResponse() + + # See https://github.com/protocolbuffers/protobuf/blob/v3.12.0/docs/implementing_proto3_presence.md # noqa + if hasattr(CodeGeneratorResponse, "Feature"): + response.supported_features = CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL + + for file_to_generate in request.file_to_generate: + proto_file = _get_proto(request, file_to_generate) + module_name = _proto2grpc_module_name(file_to_generate) + grpclib_module_path = Path(module_name.replace(".", "/") + ".py") + + imports = ["modal._utils.grpc_utils", module_name] + + services = [] + for service in proto_file.service: + methods = [] + for method in service.method: + cardinality = _CARDINALITY[(method.client_streaming, method.server_streaming)] + methods.append( + Method( + name=method.name, + cardinality=cardinality, + request_type=types_map[method.input_type], + reply_type=types_map[method.output_type], + ) + ) + services.append(Service(name=service.name, methods=methods)) + + file = response.file.add() + + file.name = str(grpclib_module_path.with_name("modal_" + grpclib_module_path.name)) + file.content = render( + proto_file=proto_file.name, imports=imports, services=services, grpclib_module=module_name + ) + + with os.fdopen(sys.stdout.fileno(), "wb") as out: + out.write(response.SerializeToString()) + + +if __name__ == "__main__": + main() diff --git a/tasks.py b/tasks.py index 6c6722bf2..3f54c46b1 100644 --- a/tasks.py +++ b/tasks.py @@ -9,9 +9,11 @@ import re import subprocess import sys +from contextlib import contextmanager from datetime import date from pathlib import Path -from typing import List, Optional +from tempfile import NamedTemporaryFile +from typing import Generator, List, Optional import requests from invoke import task @@ -23,14 +25,43 @@ copyright_header_full = f"{copyright_header_start} {year}" +@contextmanager +def python_file_as_executable(path: Path) -> Generator[Path, None, None]: + if sys.platform == "win32": + # windows can't just run shebang:ed python files, so we create a .bat file that calls it + src = f"""@echo off +{sys.executable} {path} +""" + with NamedTemporaryFile(mode="w", suffix=".bat", encoding="ascii", delete=False) as f: + f.write(src) + + try: + yield Path(f.name) + finally: + Path(f.name).unlink() + else: + yield path + + @task def protoc(ctx): + protoc_cmd = f"{sys.executable} -m grpc_tools.protoc" + input_files = "modal_proto/api.proto modal_proto/options.proto" py_protoc = ( - f"{sys.executable} -m grpc_tools.protoc" - + " --python_out=. --grpclib_python_out=. --grpc_python_out=. --mypy_out=. --mypy_grpc_out=." + protoc_cmd + " --python_out=. --grpclib_python_out=." + " --grpc_python_out=. --mypy_out=. --mypy_grpc_out=." ) print(py_protoc) - ctx.run(f"{py_protoc} -I . " "modal_proto/api.proto " "modal_proto/options.proto ") + # generate grpcio and grpclib proto files: + ctx.run(f"{py_protoc} -I . {input_files}") + + # generate modal-specific wrapper around grpclib api stub using custom plugin: + grpc_plugin_pyfile = Path(__file__).parent / "protoc_plugin" / "plugin.py" + + with python_file_as_executable(grpc_plugin_pyfile) as grpc_plugin_executable: + ctx.run( + f"{protoc_cmd} --plugin=protoc-gen-modal-grpclib-python={grpc_plugin_executable}" + + f" --modal-grpclib-python_out=. -I . {input_files}" + ) @task @@ -137,6 +168,7 @@ def check_copyright(ctx, fix=False): and not fn.endswith(".notebook.py") # vendored code has a different copyright and "_vendor" not in root + and "protoc_plugin" not in root # third-party code (i.e., in a local venv) has a different copyright and "/site-packages/" not in root ) diff --git a/test/container_app_test.py b/test/container_app_test.py index 054b94325..5dee0c106 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -15,7 +15,7 @@ from modal.client import _Client from modal.exception import InvalidError from modal.running_app import RunningApp -from modal_proto import api_grpc, api_pb2 +from modal_proto import api_grpc, api_pb2, modal_api_grpc def my_f_1(x): @@ -85,7 +85,7 @@ async def test_container_snapshot_reference_capture(container_client, tmpdir, se from modal.runner import deploy_app channel = create_channel(servicer.client_addr) - client_stub = api_grpc.ModalClientStub(channel) + client_stub = modal_api_grpc.ModalClientModal(api_grpc.ModalClientStub(channel)) app.function()(square) app_name = "my-app" app_id = deploy_app(app, app_name, client=container_client).app_id diff --git a/test/grpc_utils_test.py b/test/grpc_utils_test.py index 027d63b93..128685abe 100644 --- a/test/grpc_utils_test.py +++ b/test/grpc_utils_test.py @@ -6,6 +6,7 @@ from modal._utils.grpc_utils import create_channel, retry_transient_errors from modal_proto import api_grpc, api_pb2 +from modal_proto.modal_api_grpc import ModalClientModal from .supports.skip import skip_windows_unix_socket @@ -40,7 +41,7 @@ async def test_unix_channel(servicer): @pytest.mark.asyncio async def test_retry_transient_errors(servicer): channel = create_channel(servicer.client_addr) - client_stub = api_grpc.ModalClientStub(channel) + client_stub = ModalClientModal(api_grpc.ModalClientStub(channel)) # Use the BlobCreate request for retries req = api_pb2.BlobCreateRequest()