Skip to content

Commit

Permalink
protoc_plugin: Pass along which attr on Client should be used for grp…
Browse files Browse the repository at this point in the history
…clib_stub
  • Loading branch information
dflemstr committed Dec 4, 2024
1 parent 6c6b211 commit 55309a1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
46 changes: 36 additions & 10 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,17 @@ async def _open(self):
self._cancellation_context_event_loop = asyncio.get_running_loop()
await self._cancellation_context.__aenter__()
self._grpclib_stub = api_grpc.ModalClientStub(self._channel)
self._stub = modal_api_grpc.ModalClientModal(self._grpclib_stub, client=self)
self._stub = modal_api_grpc.ModalClientModal(
self._grpclib_stub,
client=self,
grpclib_stub_attr="_grpclib_stub"
)
self._grpclib_blobs_stub = blobs_grpc.BlobsStub(self._channel)
self._blobs_stub = modal_blobs_grpc.BlobsModal(self._grpclib_blobs_stub, client=self)
self._blobs_stub = modal_blobs_grpc.BlobsModal(
self._grpclib_blobs_stub,
client=self,
grpclib_stub_attr="_grpclib_blobs_stub"
)
self._owner_pid = os.getpid()

async def _close(self, prep_for_restore: bool = False):
Expand Down Expand Up @@ -333,34 +341,36 @@ async def _reset_on_pid_change(self):
await self._open()
# intentionally not doing self.hello since we should already be authenticated etc.

async def _get_grpclib_method(self, method_name: str) -> Any:
async def _get_grpclib_method(self, grpclib_stub, method_name: str) -> Any:
# safely get grcplib method that is bound to a valid channel
# This prevents usage of stale methods across forks of processes
await self._reset_on_pid_change()
return getattr(self._grpclib_stub, method_name)
return getattr(grpclib_stub, method_name)

@synchronizer.nowrap
async def _call_unary(
self,
grpclib_stub,
method_name: str,
request: Any,
*,
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> Any:
grpclib_method = await self._get_grpclib_method(method_name)
grpclib_method = await self._get_grpclib_method(grpclib_stub, method_name)
coro = grpclib_method(request, timeout=timeout, metadata=metadata)
return await self._call_safely(coro, grpclib_method.name)

@synchronizer.nowrap
async def _call_stream(
self,
grpclib_stub,
method_name: str,
request: Any,
*,
metadata: Optional[_MetadataLike],
) -> AsyncGenerator[Any, None]:
grpclib_method = await self._get_grpclib_method(method_name)
grpclib_method = await self._get_grpclib_method(grpclib_stub, method_name)
stream_context = grpclib_method.open(metadata=metadata)
stream = await self._call_safely(stream_context.__aenter__(), f"{grpclib_method.name}.open")
try:
Expand All @@ -387,11 +397,17 @@ class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]
client: _Client

def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType], client: _Client):
def __init__(
self,
wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType],
client: _Client,
grpclib_stub_attr: str
):
# we pass in the wrapped_method here to get the correct static types
# but don't use the reference directly, see `def wrapped_method` below
self._wrapped_full_name = wrapped_method.name
self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1]
self._grpclib_stub_attr = grpclib_stub_attr
self.client = client

@property
Expand All @@ -408,15 +424,23 @@ async def __call__(
if self.client._snapshotted:
logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
self.client = await _Client.from_env()
return await self.client._call_unary(self._wrapped_method_name, req, timeout=timeout, metadata=metadata)
grpclib_stub = getattr(self.client, self._grpclib_stub_attr)
method = self._wrapped_method_name
return await self.client._call_unary(grpclib_stub, method, req, timeout=timeout, metadata=metadata)


class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType], client: _Client):
def __init__(
self,
wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType],
client: _Client,
grpclib_stub_attr: str
):
self._wrapped_full_name = wrapped_method.name
self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1]
self._grpclib_stub_attr = grpclib_stub_attr
self.client = client

@property
Expand All @@ -431,5 +455,7 @@ async def unary_stream(
if self.client._snapshotted:
logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
self.client = await _Client.from_env()
async for response in self.client._call_stream(self._wrapped_method_name, request, metadata=metadata):
grpclib_stub = getattr(self.client, self._grpclib_stub_attr)
method = self._wrapped_method_name
async for response in self.client._call_stream(grpclib_stub, method, request, metadata=metadata):
yield response
4 changes: 2 additions & 2 deletions protoc_plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def render(
buf.add("")
buf.add(
f"def __init__(self, grpclib_stub: {grpclib_module}.{grpclib_stub_name}, "
+ """client: "modal.client._Client") -> None:"""
+ """client: "modal.client._Client", grpclib_stub_attr: str) -> None:"""
)
with buf.indent():
if len(service.methods) == 0:
Expand All @@ -106,7 +106,7 @@ def render(
raise TypeError(cardinality)

original_method = f"grpclib_stub.{name}"
buf.add(f"self.{name} = {wrapper_cls}({original_method}, client)")
buf.add(f"self.{name} = {wrapper_cls}({original_method}, client, grpclib_stub_attr)")

return buf.content()

Expand Down

0 comments on commit 55309a1

Please sign in to comment.