Skip to content

Commit

Permalink
Make rpc calls using a task/cancellation context MOD-3632 (#2178)
Browse files Browse the repository at this point in the history
Wraps every method exposed on client.stub in a client-tied wrapper, so that all RPCs run inside a cancellation context that is exited when the client is closed. The CancelledError raised inside each RPC call at that point is translated into a ClientClosed exception.
  • Loading branch information
freider authored Sep 19, 2024
1 parent eb9d6fd commit 0a17fa3
Show file tree
Hide file tree
Showing 23 changed files with 343 additions and 132 deletions.
6 changes: 4 additions & 2 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ._utils.package_utils import parse_major_minor_version
from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
from .config import config, logger
from .exception import InputCancellation, InvalidError, SerializationError
from .exception import ClientClosed, InputCancellation, InvalidError, SerializationError
from .running_app import RunningApp

if TYPE_CHECKING:
Expand Down Expand Up @@ -347,6 +347,9 @@ async def _run_heartbeat_loop(self):
# two subsequent cancellations on the same task at the moment
await asyncio.sleep(1.0)
continue
except ClientClosed:
logger.info("Stopping heartbeat loop due to client shutdown")
break
except Exception as exc:
# don't stop heartbeat loop if there are transient exceptions!
time_elapsed = time.monotonic() - t0
Expand Down Expand Up @@ -935,7 +938,6 @@ async def memory_snapshot(self) -> None:

logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()

# Turn heartbeats back on. This is safe since the snapshot RPC
# and the restore phase have both finished.
self._waiting_for_memory_snapshot = False
Expand Down
4 changes: 2 additions & 2 deletions modal/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from modal_proto import api_pb2

from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
from ._utils.shell_utils import stream_from_stdin
from .client import _Client
from .config import logger
Expand Down Expand Up @@ -580,7 +580,7 @@ async def _get_logs():
last_entry_id=last_log_batch_entry_id,
)
log_batch: api_pb2.TaskLogsBatch
async for log_batch in unary_stream(client.stub.AppGetLogs, request):
async for log_batch in client.stub.AppGetLogs.unary_stream(request):
if log_batch.entry_id:
# log_batch entry_id is empty for fd="server" messages from AppGetLogs
last_log_batch_entry_id = log_batch.entry_id
Expand Down
2 changes: 2 additions & 0 deletions modal/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ async def loader():
self._local_uuid_to_future[obj.local_uuid] = cached_future
if deduplication_key is not None:
self._deduplication_cache[deduplication_key] = cached_future

# TODO(elias): print original exception/trace rather than the Resolver-internal trace
return await cached_future

def objects(self) -> List["_Object"]:
Expand Down
1 change: 1 addition & 0 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ async def stop(self):

# Cancel any remaining unfinished tasks.
task.cancel()
await asyncio.sleep(0) # wake up coroutines waiting for cancellations

async def __aexit__(self, exc_type, value, tb):
await self.stop()
Expand Down
4 changes: 2 additions & 2 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..exception import ExecutionError, FunctionTimeoutError, InvalidError, RemoteError
from ..mount import ROOT_DIR, _is_modal_path, _Mount
from .blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, unary_stream
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES


class FunctionInfoType(Enum):
Expand Down Expand Up @@ -373,7 +373,7 @@ async def _stream_function_call_data(
while True:
req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
try:
async for chunk in unary_stream(stub_fn, req):
async for chunk in stub_fn.unary_stream(req):
if chunk.index <= last_index:
continue
if chunk.data_blob_id:
Expand Down
64 changes: 10 additions & 54 deletions modal/_utils/grpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,15 @@
import platform
import socket
import time
import typing
import urllib.parse
import uuid
from typing import (
Any,
AsyncIterator,
Collection,
Dict,
Generic,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

import grpclib.client
Expand All @@ -36,6 +32,8 @@
RequestType = TypeVar("RequestType", bound=Message)
ResponseType = TypeVar("ResponseType", bound=Message)

if typing.TYPE_CHECKING:
import modal.client

# Monkey patches grpclib to have a Modal User Agent header.
grpclib.client.USER_AGENT = "modal-client/{version} ({sys}; {py}/{py_ver})'".format(
Expand Down Expand Up @@ -63,9 +61,6 @@ def connected(self):
return True


_SendType = TypeVar("_SendType")
_RecvType = TypeVar("_RecvType")

RETRYABLE_GRPC_STATUS_CODES = [
Status.DEADLINE_EXCEEDED,
Status.UNAVAILABLE,
Expand Down Expand Up @@ -118,61 +113,22 @@ async def send_request(event: grpclib.events.SendRequest) -> None:
return channel


_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]


class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
    """Pass-through wrapper around a grpclib unary-unary method.

    Exposes the wrapped method's name and forwards each call to it
    unchanged; no client tracking or retry logic is applied here yet.
    """

    # The underlying grpclib method that every call is delegated to.
    wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]

    def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]):
        self.wrapped_method = wrapped_method

    @property
    def name(self) -> str:
        """Name of the wrapped grpclib method."""
        inner = self.wrapped_method
        return inner.name

    async def __call__(
        self,
        req: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[_MetadataLike] = None,
    ) -> ResponseType:
        """Send `req` through the wrapped method and return its response.

        `timeout` and `metadata` are keyword-only and are handed to the
        wrapped method as-is.
        """
        # TODO: implement Client tracking and retries
        response = await self.wrapped_method(req, timeout=timeout, metadata=metadata)
        return response


class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
    """Pass-through wrapper around a grpclib unary-stream method.

    Only opening a stream is supported, and the call is delegated to the
    wrapped method unchanged; no client tracking is applied here yet.
    """

    # The underlying grpclib method that `open` delegates to.
    wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]

    def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]):
        self.wrapped_method = wrapped_method

    def open(
        self,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[_MetadataLike] = None,
    ) -> grpclib.client.Stream[RequestType, ResponseType]:
        """Open a stream on the wrapped method and return it.

        `timeout` and `metadata` are keyword-only and are handed to the
        wrapped method's `open` as-is.
        """
        # TODO: implement Client tracking and unary_stream-wrapper
        stream = self.wrapped_method.open(timeout=timeout, metadata=metadata)
        return stream
if typing.TYPE_CHECKING:
import modal.client


async def unary_stream(
method: UnaryStreamWrapper[RequestType, ResponseType],
method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
request: RequestType,
metadata: Optional[Any] = None,
) -> AsyncIterator[ResponseType]:
"""Helper for making a unary-streaming gRPC request."""
async with method.open(metadata=metadata) as stream:
await stream.send_message(request, end=True)
async for item in stream:
yield item
# TODO: remove this helper, since the wrapper object now provides its own unary_stream method
async for item in method.unary_stream(request, metadata):
yield item


async def retry_transient_errors(
fn: UnaryUnaryWrapper[RequestType, ResponseType],
fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
*args,
base_delay: float = 0.1,
max_delay: float = 1,
Expand Down
4 changes: 2 additions & 2 deletions modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ._output import OutputManager
from ._utils.async_utils import synchronize_api
from ._utils.function_utils import FunctionInfo, is_global_object, is_method_fn
from ._utils.grpc_utils import retry_transient_errors, unary_stream
from ._utils.grpc_utils import retry_transient_errors
from ._utils.mount_utils import validate_volumes
from .client import _Client
from .cloud_bucket_mount import _CloudBucketMount
Expand Down Expand Up @@ -1111,7 +1111,7 @@ async def _logs(self, client: Optional[_Client] = None) -> AsyncGenerator[str, N
timeout=55,
last_entry_id=last_log_batch_entry_id,
)
async for log_batch in unary_stream(client.stub.AppGetLogs, request):
async for log_batch in client.stub.AppGetLogs.unary_stream(request):
if log_batch.entry_id:
# log_batch entry_id is empty for fd="server" messages from AppGetLogs
last_log_batch_entry_id = log_batch.entry_id
Expand Down
Loading

0 comments on commit 0a17fa3

Please sign in to comment.