From c0fcaf922086b0bb2945ec9ceb893c623a3d7a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Flemstr=C3=B6m?= Date: Wed, 4 Dec 2024 16:25:53 +0100 Subject: [PATCH] protoc_plugin: Pass along which attr on Client should be used for grpclib_stub --- modal/client.py | 43 ++++++++++++++++++++++++++++++++--------- protoc_plugin/plugin.py | 4 ++-- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/modal/client.py b/modal/client.py index 14ceadee91..0457470870 100644 --- a/modal/client.py +++ b/modal/client.py @@ -129,9 +129,17 @@ async def _open(self): self._cancellation_context_event_loop = asyncio.get_running_loop() await self._cancellation_context.__aenter__() self._grpclib_stub = api_grpc.ModalClientStub(self._channel) - self._stub = modal_api_grpc.ModalClientModal(self._grpclib_stub, client=self) + self._stub = modal_api_grpc.ModalClientModal( + self._grpclib_stub, + client=self, + grpclib_stub_attr="_grpclib_stub" + ) self._grpclib_blobs_stub = blobs_grpc.BlobsStub(self._channel) - self._blobs_stub = modal_blobs_grpc.BlobsModal(self._grpclib_blobs_stub, client=self) + self._blobs_stub = modal_blobs_grpc.BlobsModal( + self._grpclib_blobs_stub, + client=self, + grpclib_stub_attr="_grpclib_blobs_stub" + ) self._owner_pid = os.getpid() async def _close(self, prep_for_restore: bool = False): @@ -333,22 +341,23 @@ async def _reset_on_pid_change(self): await self._open() # intentionally not doing self.hello since we should already be authenticated etc. - async def _get_grpclib_method(self, method_name: str) -> Any: + async def _get_grpclib_method(self, grpclib_stub, method_name: str) -> Any: # safely get grcplib method that is bound to a valid channel # This prevents usage of stale methods across forks of processes await self._reset_on_pid_change() - return getattr(self._grpclib_stub, method_name) + return getattr(grpclib_stub, method_name) @synchronizer.nowrap async def _call_unary( self, + grpclib_stub, method_name: str, request: Any, *, timeout: Optional[float] = None, metadata: Optional[_MetadataLike] = None, ) -> Any: - grpclib_method = await self._get_grpclib_method(method_name) + grpclib_method = await self._get_grpclib_method(grpclib_stub, method_name) coro = grpclib_method(request, timeout=timeout, metadata=metadata) return await self._call_safely(coro, grpclib_method.name) @@ -387,11 +396,17 @@ class UnaryUnaryWrapper(Generic[RequestType, ResponseType]): wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType] client: _Client - def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType], client: _Client): + def __init__( + self, + wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType], + client: _Client, + grpclib_stub_attr: str + ): # we pass in the wrapped_method here to get the correct static types # but don't use the reference directly, see `def wrapped_method` below self._wrapped_full_name = wrapped_method.name self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1] + self._grpclib_stub_attr = grpclib_stub_attr self.client = client @property @@ -408,15 +423,23 @@ async def __call__( if self.client._snapshotted: logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}") self.client = await _Client.from_env() - return await self.client._call_unary(self._wrapped_method_name, req, timeout=timeout, metadata=metadata) + grpclib_stub = getattr(self.client, self._grpclib_stub_attr) + method = self._wrapped_method_name + return await self.client._call_unary(grpclib_stub, method, req, timeout=timeout, metadata=metadata) class UnaryStreamWrapper(Generic[RequestType, ResponseType]): wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType] - def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType], client: _Client): + def __init__( + self, + wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType], + client: _Client, + grpclib_stub_attr: str + ): self._wrapped_full_name = wrapped_method.name self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1] + self._grpclib_stub_attr = grpclib_stub_attr self.client = client @property @@ -431,5 +454,7 @@ async def unary_stream( if self.client._snapshotted: logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}") self.client = await _Client.from_env() - async for response in self.client._call_stream(self._wrapped_method_name, request, metadata=metadata): + grpclib_stub = getattr(self.client, self._grpclib_stub_attr) + method = self._wrapped_method_name + async for response in self.client._call_stream(grpclib_stub, method, request, metadata=metadata): yield response diff --git a/protoc_plugin/plugin.py b/protoc_plugin/plugin.py index 9977d20870..5a904c18d1 100755 --- a/protoc_plugin/plugin.py +++ b/protoc_plugin/plugin.py @@ -86,7 +86,7 @@ def render( buf.add("") buf.add( f"def __init__(self, grpclib_stub: {grpclib_module}.{grpclib_stub_name}, " - + """client: "modal.client._Client") -> None:""" + + """client: "modal.client._Client", grpclib_stub_attr: str) -> None:""" ) with buf.indent(): if len(service.methods) == 0: @@ -106,7 +106,7 @@ def render( raise TypeError(cardinality) original_method = f"grpclib_stub.{name}" - buf.add(f"self.{name} = {wrapper_cls}({original_method}, client)") + buf.add(f"self.{name} = {wrapper_cls}({original_method}, client, grpclib_stub_attr)") return buf.content()