diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index f6f40f81a..46f401e55 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -88,9 +88,8 @@ jobs: run: | python -m venv venv source venv/bin/activate - pip install grpcio-tools==1.59.2 grpclib==0.4.7 + pip install grpcio-tools==1.59.2 python -m grpc_tools.protoc --python_out=. --grpclib_python_out=. --grpc_python_out=. -I . modal_proto/api.proto modal_proto/options.proto - python -m grpc_tools.protoc --plugin=protoc-gen-modal-grpclib-python=protoc_plugin/plugin.py --modal-grpclib-python_out=. -I . modal_proto/api.proto modal_proto/options.proto deactivate - name: Check entrypoint import diff --git a/modal/_utils/blob_utils.py b/modal/_utils/blob_utils.py index 1148cfb10..1a2d37025 100644 --- a/modal/_utils/blob_utils.py +++ b/modal/_utils/blob_utils.py @@ -14,7 +14,7 @@ from aiohttp.abc import AbstractStreamWriter from modal_proto import api_pb2 -from modal_proto.modal_api_grpc import ModalClientModal +from modal_proto.api_grpc import ModalClientStub from ..exception import ExecutionError from .async_utils import TaskContext, retry @@ -287,7 +287,7 @@ async def _blob_upload( return blob_id -async def blob_upload(payload: bytes, stub: ModalClientModal) -> str: +async def blob_upload(payload: bytes, stub: ModalClientStub) -> str: if isinstance(payload, str): logger.warning("Blob uploading string, not bytes - auto-encoding as utf8") payload = payload.encode("utf8") @@ -296,7 +296,7 @@ async def blob_upload(payload: bytes, stub: ModalClientModal) -> str: async def blob_upload_file( - file_obj: BinaryIO, stub: ModalClientModal, progress_report_cb: Optional[Callable] = None + file_obj: BinaryIO, stub: ModalClientStub, progress_report_cb: Optional[Callable] = None ) -> str: upload_hashes = get_upload_hashes(file_obj) return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb) @@ -316,7 +316,7 @@ async def _download_from_url(download_url: str) -> bytes: return await s3_resp.read() -async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes: +async def blob_download(blob_id: str, stub: ModalClientStub) -> bytes: # convenience function reading all of the downloaded file into memory req = api_pb2.BlobGetRequest(blob_id=blob_id) resp = await retry_transient_errors(stub.BlobGet, req) @@ -324,7 +324,7 @@ async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes: return await _download_from_url(resp.download_url) -async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]: +async def blob_iter(blob_id: str, stub: ModalClientStub) -> AsyncIterator[bytes]: req = api_pb2.BlobGetRequest(blob_id=blob_id) resp = await retry_transient_errors(stub.BlobGet, req) download_url = resp.download_url diff --git a/modal/_utils/grpc_utils.py b/modal/_utils/grpc_utils.py index 2b2c83b89..769b45173 100644 --- a/modal/_utils/grpc_utils.py +++ b/modal/_utils/grpc_utils.py @@ -9,21 +9,14 @@ from typing import ( Any, AsyncIterator, - Collection, Dict, - Generic, - Mapping, Optional, - Tuple, TypeVar, - Union, ) import grpclib.client import grpclib.config import grpclib.events -import grpclib.protocol -import grpclib.stream from google.protobuf.message import Message from grpclib import GRPCError, Status from grpclib.exceptions import StreamTerminatedError @@ -118,52 +111,11 @@ async def send_request(event: grpclib.events.SendRequest) -> None: return channel -_Value = Union[str, bytes] -_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] - - -class UnaryUnaryWrapper(Generic[RequestType, ResponseType]): - wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType] - - def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]): - self.wrapped_method = wrapped_method - - @property - def name(self) -> str: - return self.wrapped_method.name - - async def __call__( - self, - req: RequestType, - *, - timeout: Optional[float] = None, - metadata: Optional[_MetadataLike] = None, - ) -> ResponseType: - # TODO: implement Client tracking and retries - return await self.wrapped_method(req, timeout=timeout, metadata=metadata) - - -class UnaryStreamWrapper(Generic[RequestType, ResponseType]): - wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType] - - def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]): - self.wrapped_method = wrapped_method - - def open( - self, - *, - timeout: Optional[float] = None, - metadata: Optional[_MetadataLike] = None, - ) -> grpclib.client.Stream[RequestType, ResponseType]: - # TODO: implement Client tracking and unary_stream-wrapper - return self.wrapped_method.open(timeout=timeout, metadata=metadata) - - async def unary_stream( - method: UnaryStreamWrapper[RequestType, ResponseType], - request: RequestType, + method: grpclib.client.UnaryStreamMethod[_SendType, _RecvType], + request: _SendType, metadata: Optional[Any] = None, -) -> AsyncIterator[ResponseType]: +) -> AsyncIterator[_RecvType]: """Helper for making a unary-streaming gRPC request.""" async with method.open(metadata=metadata) as stream: await stream.send_message(request, end=True) @@ -172,7 +124,7 @@ async def unary_stream( async def retry_transient_errors( - fn: UnaryUnaryWrapper[RequestType, ResponseType], + fn: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType], *args, base_delay: float = 0.1, max_delay: float = 1, diff --git a/modal/client.py b/modal/client.py index d5679d516..abc5a7f30 100644 --- a/modal/client.py +++ b/modal/client.py @@ -10,7 +10,7 @@ from grpclib import GRPCError, Status from synchronicity.async_wrap import asynccontextmanager -from modal_proto import api_grpc, api_pb2, modal_api_grpc +from modal_proto import api_grpc, api_pb2 from modal_version import __version__ from ._utils import async_utils @@ -99,11 +99,11 @@ def __init__( self.image_builder_version: Optional[str] = None self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None self._channel: Optional[grpclib.client.Channel] = None - self._stub: Optional[modal_api_grpc.ModalClientModal] = None + self._stub: Optional[api_grpc.ModalClientStub] = None self._snapshotted = False @property - def stub(self) -> modal_api_grpc.ModalClientModal: + def stub(self) -> api_grpc.ModalClientStub: """mdmd:hidden""" assert self._stub return self._stub @@ -125,8 +125,7 @@ async def _open(self): assert self._stub is None metadata = _get_metadata(self.client_type, self._credentials, self.version) self._channel = create_channel(self.server_url, metadata=metadata) - grpclib_stub = api_grpc.ModalClientStub(self._channel) - self._stub = modal_api_grpc.ModalClientModal(grpclib_stub) + self._stub = api_grpc.ModalClientStub(self._channel) # type: ignore async def _close(self, prep_for_restore: bool = False): if self._pre_stop is not None: diff --git a/modal/functions.py b/modal/functions.py index 1d5fffffa..082653f0c 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -30,8 +30,7 @@ from synchronicity.combined_types import MethodWithAio from modal._output import FunctionCreationStatus -from modal_proto import api_pb2 -from modal_proto.modal_api_grpc import ModalClientModal +from modal_proto import api_grpc, api_pb2 from ._location import parse_cloud_provider from ._output import OutputManager @@ -100,9 +99,9 @@ class _Invocation: """Internal client representation of a single-input call to a Modal Function or Generator""" - stub: ModalClientModal + stub: api_grpc.ModalClientStub - def __init__(self, stub: ModalClientModal, function_call_id: str, client: _Client): + def __init__(self, stub: api_grpc.ModalClientStub, function_call_id: str, client: _Client): self.stub = stub self.client = client # Used by the deserializer. self.function_call_id = function_call_id # TODO: remove and use only input_id diff --git a/modal_proto/__init__.py b/modal_proto/__init__.py index 12325a125..60dc7f423 100644 --- a/modal_proto/__init__.py +++ b/modal_proto/__init__.py @@ -1 +1 @@ -# Copyright Modal Labs 2024 +# Copyright Modal Labs 2022 diff --git a/protoc_plugin/plugin.py b/protoc_plugin/plugin.py deleted file mode 100755 index e5cf7816a..000000000 --- a/protoc_plugin/plugin.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python -# Copyright Modal Labs 2024 -# built by modifying grpclib.plugin.main, see https://github.com/vmagamedov/grpclib -# original: Copyright (c) 2019 , Vladimir Magamedov -import os -import sys -from collections import deque -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Collection, Deque, Dict, Iterator, List, NamedTuple, Optional, Tuple - -from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse -from google.protobuf.descriptor_pb2 import DescriptorProto, FileDescriptorProto -from grpclib import const - -_CARDINALITY = { - (False, False): const.Cardinality.UNARY_UNARY, - (True, False): const.Cardinality.STREAM_UNARY, - (False, True): const.Cardinality.UNARY_STREAM, - (True, True): const.Cardinality.STREAM_STREAM, -} - - -class Method(NamedTuple): - name: str - cardinality: const.Cardinality - request_type: str - reply_type: str - - -class Service(NamedTuple): - name: str - methods: List[Method] - - -class Buffer: - def __init__(self) -> None: - self._lines: List[str] = [] - self._indent = 0 - - def add(self, string: str, *args: Any, **kwargs: Any) -> None: - line = " " * self._indent * 4 + string.format(*args, **kwargs) - self._lines.append(line.rstrip(" ")) - - @contextmanager - def indent(self) -> Iterator[None]: - self._indent += 1 - try: - yield - finally: - self._indent -= 1 - - def content(self) -> str: - return "\n".join(self._lines) + "\n" - - -def render( - proto_file: str, - imports: Collection[str], - services: Collection[Service], - grpclib_module: str, -) -> str: - buf = Buffer() - buf.add("# Generated by the Modal Protocol Buffers compiler. DO NOT EDIT!") - buf.add("# source: {}", proto_file) - buf.add("# plugin: {}", __name__) - if not services: - return buf.content() - - buf.add("") - for mod in imports: - buf.add("import {}", mod) - for service in services: - buf.add("") - buf.add("") - grpclib_stub_name = f"{service.name}Stub" - buf.add("class {}Modal:", service.name) - with buf.indent(): - buf.add("") - buf.add("def __init__(self, grpclib_stub: {}.{}) -> None:".format(grpclib_module, grpclib_stub_name)) - with buf.indent(): - if len(service.methods) == 0: - buf.add("pass") - for method in service.methods: - name, cardinality, request_type, reply_type = method - wrapper_cls: str - if cardinality is const.Cardinality.UNARY_UNARY: - wrapper_cls = "modal._utils.grpc_utils.UnaryUnaryWrapper" - elif cardinality is const.Cardinality.UNARY_STREAM: - wrapper_cls = "modal._utils.grpc_utils.UnaryStreamWrapper" - # elif cardinality is const.Cardinality.STREAM_UNARY: - # wrapper_cls = StreamUnaryWrapper - # elif cardinality is const.Cardinality.STREAM_STREAM: - # wrapper_cls = StreamStreamWrapper - else: - raise TypeError(cardinality) - - original_method = f"grpclib_stub.{name}" - buf.add(f"self.{name} = {wrapper_cls}({original_method})") - - return buf.content() - - -def _get_proto(request: CodeGeneratorRequest, name: str) -> FileDescriptorProto: - return next(f for f in request.proto_file if f.name == name) - - -def _strip_proto(proto_file_path: str) -> str: - for suffix in [".protodevel", ".proto"]: - if proto_file_path.endswith(suffix): - return proto_file_path[: -len(suffix)] - - return proto_file_path - - -def _base_module_name(proto_file_path: str) -> str: - basename = _strip_proto(proto_file_path) - return basename.replace("-", "_").replace("/", ".") - - -def _proto2pb2_module_name(proto_file_path: str) -> str: - return _base_module_name(proto_file_path) + "_pb2" - - -def _proto2grpc_module_name(proto_file_path: str) -> str: - return _base_module_name(proto_file_path) + "_grpc" - - -def _type_names( - proto_file: FileDescriptorProto, - message_type: DescriptorProto, - parents: Optional[Deque[str]] = None, -) -> Iterator[Tuple[str, str]]: - if parents is None: - parents = deque() - - proto_name_parts = [""] - if proto_file.package: - proto_name_parts.append(proto_file.package) - proto_name_parts.extend(parents) - proto_name_parts.append(message_type.name) - - py_name_parts = [_proto2pb2_module_name(proto_file.name)] - py_name_parts.extend(parents) - py_name_parts.append(message_type.name) - - yield ".".join(proto_name_parts), ".".join(py_name_parts) - - parents.append(message_type.name) - for nested in message_type.nested_type: - yield from _type_names(proto_file, nested, parents=parents) - parents.pop() - - -def main() -> None: - with os.fdopen(sys.stdin.fileno(), "rb") as inp: - request = CodeGeneratorRequest.FromString(inp.read()) - - types_map: Dict[str, str] = {} - for pf in request.proto_file: - for mt in pf.message_type: - types_map.update(_type_names(pf, mt)) - - response = CodeGeneratorResponse() - - # See https://github.com/protocolbuffers/protobuf/blob/v3.12.0/docs/implementing_proto3_presence.md # noqa - if hasattr(CodeGeneratorResponse, "Feature"): - response.supported_features = CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL - - for file_to_generate in request.file_to_generate: - proto_file = _get_proto(request, file_to_generate) - module_name = _proto2grpc_module_name(file_to_generate) - grpclib_module_path = Path(module_name.replace(".", "/") + ".py") - - imports = ["modal._utils.grpc_utils", module_name] - - services = [] - for service in proto_file.service: - methods = [] - for method in service.method: - cardinality = _CARDINALITY[(method.client_streaming, method.server_streaming)] - methods.append( - Method( - name=method.name, - cardinality=cardinality, - request_type=types_map[method.input_type], - reply_type=types_map[method.output_type], - ) - ) - services.append(Service(name=service.name, methods=methods)) - - file = response.file.add() - - file.name = str(grpclib_module_path.with_name("modal_" + grpclib_module_path.name)) - file.content = render( - proto_file=proto_file.name, imports=imports, services=services, grpclib_module=module_name - ) - - with os.fdopen(sys.stdout.fileno(), "wb") as out: - out.write(response.SerializeToString()) - - -if __name__ == "__main__": - main() diff --git a/tasks.py b/tasks.py index 3f54c46b1..6c6722bf2 100644 --- a/tasks.py +++ b/tasks.py @@ -9,11 +9,9 @@ import re import subprocess import sys -from contextlib import contextmanager from datetime import date from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Generator, List, Optional +from typing import List, Optional import requests from invoke import task @@ -25,43 +23,14 @@ copyright_header_full = f"{copyright_header_start} {year}" -@contextmanager -def python_file_as_executable(path: Path) -> Generator[Path, None, None]: - if sys.platform == "win32": - # windows can't just run shebang:ed python files, so we create a .bat file that calls it - src = f"""@echo off -{sys.executable} {path} -""" - with NamedTemporaryFile(mode="w", suffix=".bat", encoding="ascii", delete=False) as f: - f.write(src) - - try: - yield Path(f.name) - finally: - Path(f.name).unlink() - else: - yield path - - @task def protoc(ctx): - protoc_cmd = f"{sys.executable} -m grpc_tools.protoc" - input_files = "modal_proto/api.proto modal_proto/options.proto" py_protoc = ( - protoc_cmd + " --python_out=. --grpclib_python_out=." + " --grpc_python_out=. --mypy_out=. --mypy_grpc_out=." + f"{sys.executable} -m grpc_tools.protoc" + + " --python_out=. --grpclib_python_out=. --grpc_python_out=. --mypy_out=. --mypy_grpc_out=." ) print(py_protoc) - # generate grpcio and grpclib proto files: - ctx.run(f"{py_protoc} -I . {input_files}") - - # generate modal-specific wrapper around grpclib api stub using custom plugin: - grpc_plugin_pyfile = Path(__file__).parent / "protoc_plugin" / "plugin.py" - - with python_file_as_executable(grpc_plugin_pyfile) as grpc_plugin_executable: - ctx.run( - f"{protoc_cmd} --plugin=protoc-gen-modal-grpclib-python={grpc_plugin_executable}" - + f" --modal-grpclib-python_out=. -I . {input_files}" - ) + ctx.run(f"{py_protoc} -I . " "modal_proto/api.proto " "modal_proto/options.proto ") @task @@ -168,7 +137,6 @@ def check_copyright(ctx, fix=False): and not fn.endswith(".notebook.py") # vendored code has a different copyright and "_vendor" not in root - and "protoc_plugin" not in root # third-party code (i.e., in a local venv) has a different copyright and "/site-packages/" not in root ) diff --git a/test/container_app_test.py b/test/container_app_test.py index 5dee0c106..054b94325 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -15,7 +15,7 @@ from modal.client import _Client from modal.exception import InvalidError from modal.running_app import RunningApp -from modal_proto import api_grpc, api_pb2, modal_api_grpc +from modal_proto import api_grpc, api_pb2 def my_f_1(x): @@ -85,7 +85,7 @@ async def test_container_snapshot_reference_capture(container_client, tmpdir, se from modal.runner import deploy_app channel = create_channel(servicer.client_addr) - client_stub = modal_api_grpc.ModalClientModal(api_grpc.ModalClientStub(channel)) + client_stub = api_grpc.ModalClientStub(channel) app.function()(square) app_name = "my-app" app_id = deploy_app(app, app_name, client=container_client).app_id diff --git a/test/grpc_utils_test.py b/test/grpc_utils_test.py index 128685abe..027d63b93 100644 --- a/test/grpc_utils_test.py +++ b/test/grpc_utils_test.py @@ -6,7 +6,6 @@ from modal._utils.grpc_utils import create_channel, retry_transient_errors from modal_proto import api_grpc, api_pb2 -from modal_proto.modal_api_grpc import ModalClientModal from .supports.skip import skip_windows_unix_socket @@ -41,7 +40,7 @@ async def test_unix_channel(servicer): @pytest.mark.asyncio async def test_retry_transient_errors(servicer): channel = create_channel(servicer.client_addr) - client_stub = ModalClientModal(api_grpc.ModalClientStub(channel)) + client_stub = api_grpc.ModalClientStub(channel) # Use the BlobCreate request for retries req = api_pb2.BlobCreateRequest()