Skip to content

Commit

Permalink
Revert "Make rpc calls using a task/cancellation context MOD-3632 (#2178)"
Browse files Browse the repository at this point in the history

This reverts commit 0a17fa3.
  • Loading branch information
irfansharif committed Sep 23, 2024
1 parent 4a6f36c commit 78f667e
Show file tree
Hide file tree
Showing 23 changed files with 132 additions and 343 deletions.
6 changes: 2 additions & 4 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ._utils.package_utils import parse_major_minor_version
from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
from .config import config, logger
from .exception import ClientClosed, InputCancellation, InvalidError, SerializationError
from .exception import InputCancellation, InvalidError, SerializationError
from .running_app import RunningApp

if TYPE_CHECKING:
Expand Down Expand Up @@ -347,9 +347,6 @@ async def _run_heartbeat_loop(self):
# two subsequent cancellations on the same task at the moment
await asyncio.sleep(1.0)
continue
except ClientClosed:
logger.info("Stopping heartbeat loop due to client shutdown")
break
except Exception as exc:
# don't stop heartbeat loop if there are transient exceptions!
time_elapsed = time.monotonic() - t0
Expand Down Expand Up @@ -943,6 +940,7 @@ async def memory_snapshot(self) -> None:

logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()

# Turn heartbeats back on. This is safe since the snapshot RPC
# and the restore phase have finished.
self._waiting_for_memory_snapshot = False
Expand Down
4 changes: 2 additions & 2 deletions modal/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from modal_proto import api_pb2

from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
from ._utils.shell_utils import stream_from_stdin
from .client import _Client
from .config import logger
Expand Down Expand Up @@ -580,7 +580,7 @@ async def _get_logs():
last_entry_id=last_log_batch_entry_id,
)
log_batch: api_pb2.TaskLogsBatch
async for log_batch in client.stub.AppGetLogs.unary_stream(request):
async for log_batch in unary_stream(client.stub.AppGetLogs, request):
if log_batch.entry_id:
# log_batch entry_id is empty for fd="server" messages from AppGetLogs
last_log_batch_entry_id = log_batch.entry_id
Expand Down
2 changes: 0 additions & 2 deletions modal/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ async def loader():
self._local_uuid_to_future[obj.local_uuid] = cached_future
if deduplication_key is not None:
self._deduplication_cache[deduplication_key] = cached_future

# TODO(elias): print original exception/trace rather than the Resolver-internal trace
return await cached_future

def objects(self) -> List["_Object"]:
Expand Down
1 change: 0 additions & 1 deletion modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ async def stop(self):

# Cancel any remaining unfinished tasks.
task.cancel()
await asyncio.sleep(0) # wake up coroutines waiting for cancellations

async def __aexit__(self, exc_type, value, tb):
await self.stop()
Expand Down
4 changes: 2 additions & 2 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..exception import ExecutionError, FunctionTimeoutError, InvalidError, RemoteError
from ..mount import ROOT_DIR, _is_modal_path, _Mount
from .blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, unary_stream


class FunctionInfoType(Enum):
Expand Down Expand Up @@ -373,7 +373,7 @@ async def _stream_function_call_data(
while True:
req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
try:
async for chunk in stub_fn.unary_stream(req):
async for chunk in unary_stream(stub_fn, req):
if chunk.index <= last_index:
continue
if chunk.data_blob_id:
Expand Down
64 changes: 54 additions & 10 deletions modal/_utils/grpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
import platform
import socket
import time
import typing
import urllib.parse
import uuid
from typing import (
Any,
AsyncIterator,
Collection,
Dict,
Generic,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

import grpclib.client
Expand All @@ -32,8 +36,6 @@
RequestType = TypeVar("RequestType", bound=Message)
ResponseType = TypeVar("ResponseType", bound=Message)

if typing.TYPE_CHECKING:
import modal.client

# Monkey patches grpclib to have a Modal User Agent header.
grpclib.client.USER_AGENT = "modal-client/{version} ({sys}; {py}/{py_ver})'".format(
Expand Down Expand Up @@ -61,6 +63,9 @@ def connected(self):
return True


_SendType = TypeVar("_SendType")
_RecvType = TypeVar("_RecvType")

RETRYABLE_GRPC_STATUS_CODES = [
Status.DEADLINE_EXCEEDED,
Status.UNAVAILABLE,
Expand Down Expand Up @@ -113,22 +118,61 @@ async def send_request(event: grpclib.events.SendRequest) -> None:
return channel


if typing.TYPE_CHECKING:
import modal.client
_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]


class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
    """Thin pass-through wrapper around a grpclib unary-unary method.

    Every call is delegated to ``wrapped_method`` with the same request,
    timeout and metadata; no additional behavior is layered on top yet
    (see the TODO in ``__call__``).
    """

    # The underlying grpclib method object that calls are delegated to.
    wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]

    def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]):
        """Store the grpclib method that this wrapper delegates to."""
        self.wrapped_method = wrapped_method

    @property
    def name(self) -> str:
        """Return the ``name`` attribute of the wrapped method."""
        return self.wrapped_method.name

    async def __call__(
        self,
        req: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[_MetadataLike] = None,
    ) -> ResponseType:
        """Await the wrapped method with ``req`` and return its response.

        ``timeout`` and ``metadata`` are keyword-only and are forwarded
        unchanged to the wrapped method.
        """
        # TODO: implement Client tracking and retries
        response = await self.wrapped_method(req, timeout=timeout, metadata=metadata)
        return response


class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
    """Thin pass-through wrapper around a grpclib unary-stream method.

    ``open`` is delegated to ``wrapped_method`` with the same timeout and
    metadata; no additional behavior is layered on top yet (see the TODO
    in ``open``).
    """

    # The underlying grpclib method object that ``open`` is delegated to.
    wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]

    def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]):
        """Store the grpclib method that this wrapper delegates to."""
        self.wrapped_method = wrapped_method

    def open(
        self,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[_MetadataLike] = None,
    ) -> grpclib.client.Stream[RequestType, ResponseType]:
        """Return whatever the wrapped method's ``open`` returns.

        ``timeout`` and ``metadata`` are keyword-only and are forwarded
        unchanged to ``wrapped_method.open``.
        """
        # TODO: implement Client tracking and unary_stream-wrapper
        stream = self.wrapped_method.open(timeout=timeout, metadata=metadata)
        return stream


async def unary_stream(
method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
method: UnaryStreamWrapper[RequestType, ResponseType],
request: RequestType,
metadata: Optional[Any] = None,
) -> AsyncIterator[ResponseType]:
# TODO: remove this, since we have a method now
async for item in method.unary_stream(request, metadata):
yield item
"""Helper for making a unary-streaming gRPC request."""
async with method.open(metadata=metadata) as stream:
await stream.send_message(request, end=True)
async for item in stream:
yield item


async def retry_transient_errors(
fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
fn: UnaryUnaryWrapper[RequestType, ResponseType],
*args,
base_delay: float = 0.1,
max_delay: float = 1,
Expand Down
4 changes: 2 additions & 2 deletions modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ._output import OutputManager
from ._utils.async_utils import synchronize_api
from ._utils.function_utils import FunctionInfo, is_global_object, is_method_fn
from ._utils.grpc_utils import retry_transient_errors
from ._utils.grpc_utils import retry_transient_errors, unary_stream
from ._utils.mount_utils import validate_volumes
from .client import _Client
from .cloud_bucket_mount import _CloudBucketMount
Expand Down Expand Up @@ -1114,7 +1114,7 @@ async def _logs(self, client: Optional[_Client] = None) -> AsyncGenerator[str, N
timeout=55,
last_entry_id=last_log_batch_entry_id,
)
async for log_batch in client.stub.AppGetLogs.unary_stream(request):
async for log_batch in unary_stream(client.stub.AppGetLogs, request):
if log_batch.entry_id:
# log_batch entry_id is empty for fd="server" messages from AppGetLogs
last_log_batch_entry_id = log_batch.entry_id
Expand Down
Loading

0 comments on commit 78f667e

Please sign in to comment.