Make rpc calls using an asyncio task context MOD-3632 #2178

Merged: 31 commits merged into main from freider/client-close-stops-rpcs on Sep 19, 2024
Changes from 13 commits

Commits (31):
ff270a9  wip (freider, Aug 30, 2024)
09e5087  Move unary_stream to wrapper (freider, Sep 2, 2024)
04d9597  Fix heartbeat loop shutdown (freider, Sep 2, 2024)
d3c69f6  Add some basic layer of shutdown protection on streaming calls as well (freider, Sep 2, 2024)
f354814  Merge remote-tracking branch 'origin/main' into freider/client-close-… (freider, Sep 5, 2024)
27694fc  Merge remote-tracking branch 'origin/main' into freider/client-close-… (freider, Sep 6, 2024)
6a476e4  Fix deadlocks from not running rpcs within synchronizer loop (freider, Sep 9, 2024)
fd8a761  copyright (freider, Sep 9, 2024)
3481f1f  Fix typing, tests (freider, Sep 9, 2024)
1f95a67  Fix test (freider, Sep 10, 2024)
ff41708  debug wip (freider, Sep 10, 2024)
d66ba96  Revert "debug wip" (freider, Sep 12, 2024)
0985376  Merge remote-tracking branch 'origin/main' into freider/client-close-… (freider, Sep 12, 2024)
253f004  Cleanup (freider, Sep 12, 2024)
66804ba  Merge remote-tracking branch 'origin/main' into freider/client-close-… (freider, Sep 13, 2024)
24e915a  Merge remote-tracking branch 'origin/main' into freider/client-close-… (freider, Sep 17, 2024)
3c3ecd1  Don't raise ClientClosed unless it is (freider, Sep 17, 2024)
9834dd4  Ugly workaround for test that leak pending tasks (freider, Sep 17, 2024)
fb46cce  Revert "remove max_workers from DaemonizedThreadPool (#2238)" (#2242) (gongy, Sep 17, 2024)
c381469  Similar fix to UnaryUnary - only raise ClientClosed if actually closed (freider, Sep 18, 2024)
92eb0ed  Test (freider, Sep 18, 2024)
b451caa  Types (freider, Sep 18, 2024)
2b7515c  Merge branch 'main' into freider/client-close-stops-rpcs (freider, Sep 18, 2024)
017e74c  Wip (freider, Sep 18, 2024)
ca6f2af  Fix test flake (freider, Sep 19, 2024)
0f08b37  Merge remote-tracking branch 'origin/main' into freider/client-close-… (freider, Sep 19, 2024)
8d55b77  Flake? (freider, Sep 19, 2024)
26a82e9  Renames (freider, Sep 19, 2024)
86d75a8  Revert "Flake?" (freider, Sep 19, 2024)
1222711  types (freider, Sep 19, 2024)
a9981de  Skip pdm test on python 3.9 (freider, Sep 19, 2024)
6 changes: 4 additions & 2 deletions modal/_container_io_manager.py
@@ -29,7 +29,7 @@
from ._utils.grpc_utils import get_proto_oneof, retry_transient_errors
from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
from .config import config, logger
from .exception import InputCancellation, InvalidError, SerializationError
from .exception import ClientClosed, InputCancellation, InvalidError, SerializationError
from .running_app import RunningApp

if TYPE_CHECKING:
@@ -333,6 +333,9 @@ async def _run_heartbeat_loop(self):
# two subsequent cancellations on the same task at the moment
await asyncio.sleep(1.0)
continue
except ClientClosed:
logger.info("Stopping heartbeat loop due to client shutdown")
break
except Exception as exc:
# don't stop heartbeat loop if there are transient exceptions!
time_elapsed = time.monotonic() - t0
@@ -912,7 +915,6 @@ async def memory_snapshot(self) -> None:

logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()

# Turn heartbeats back on. This is safe since the snapshot RPC
# and the restore phase has finished.
self._waiting_for_memory_snapshot = False
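The heartbeat change above distinguishes a permanent failure (the client was closed, so every further heartbeat RPC would also fail) from transient errors that should be retried. A minimal sketch of that retry/shutdown split, using a stand-in ClientClosed and a hypothetical send_heartbeat() coroutine:

import asyncio
import logging

logger = logging.getLogger(__name__)

class ClientClosed(Exception):
    pass  # stand-in for modal.exception.ClientClosed

async def heartbeat_loop(send_heartbeat, interval: float = 1.0) -> None:
    while True:
        try:
            await send_heartbeat()  # an RPC that may fail
        except ClientClosed:
            logger.info("Stopping heartbeat loop due to client shutdown")
            break  # permanent: the client is gone, retrying is pointless
        except Exception as exc:
            logger.debug("Transient heartbeat failure, retrying: %r", exc)
        await asyncio.sleep(interval)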
4 changes: 2 additions & 2 deletions modal/_output.py
@@ -33,7 +33,7 @@

from modal_proto import api_pb2

from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
from ._utils.shell_utils import stream_from_stdin
from .client import _Client
from .config import logger
@@ -580,7 +580,7 @@ async def _get_logs():
last_entry_id=last_log_batch_entry_id,
)
log_batch: api_pb2.TaskLogsBatch
async for log_batch in unary_stream(client.stub.AppGetLogs, request):
async for log_batch in client.stub.AppGetLogs.unary_stream(request):
if log_batch.entry_id:
# log_batch entry_id is empty for fd="server" messages from AppGetLogs
last_log_batch_entry_id = log_batch.entry_id
2 changes: 2 additions & 0 deletions modal/_resolver.py
@@ -138,6 +138,8 @@ async def loader():
self._local_uuid_to_future[obj.local_uuid] = cached_future
if deduplication_key is not None:
self._deduplication_cache[deduplication_key] = cached_future

# TODO(elias): print original exception/trace rather than the Resolver-internal trace
return await cached_future

def objects(self) -> List["_Object"]:
4 changes: 2 additions & 2 deletions modal/_utils/function_utils.py
@@ -19,7 +19,7 @@
from ..exception import ExecutionError, FunctionTimeoutError, InvalidError, RemoteError
from ..mount import ROOT_DIR, _is_modal_path, _Mount
from .blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, unary_stream
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES


class FunctionInfoType(Enum):
@@ -364,7 +364,7 @@ async def _stream_function_call_data(
while True:
req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
try:
async for chunk in unary_stream(stub_fn, req):
async for chunk in stub_fn.unary_stream(req):
if chunk.index <= last_index:
continue
if chunk.data_blob_id:
62 changes: 49 additions & 13 deletions modal/_utils/grpc_utils.py
@@ -4,6 +4,7 @@
import platform
import socket
import time
import typing
import urllib.parse
import uuid
from typing import (
@@ -29,13 +30,17 @@
from grpclib.exceptions import StreamTerminatedError
from grpclib.protocol import H2Protocol

from modal.exception import ClientClosed
from modal_version import __version__

from .async_utils import synchronizer
from .logger import logger

RequestType = TypeVar("RequestType", bound=Message)
ResponseType = TypeVar("ResponseType", bound=Message)

if typing.TYPE_CHECKING:
import modal.client

# Monkey patches grpclib to have a Modal User Agent header.
grpclib.client.USER_AGENT = "modal-client/{version} ({sys}; {py}/{py_ver})'".format(
@@ -63,9 +68,6 @@ def connected(self):
return True


_SendType = TypeVar("_SendType")
_RecvType = TypeVar("_RecvType")

RETRYABLE_GRPC_STATUS_CODES = [
Status.DEADLINE_EXCEEDED,
Status.UNAVAILABLE,
@@ -124,9 +126,13 @@ async def send_request(event: grpclib.events.SendRequest) -> None:

class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
Review comment (freider, author): Moving these to the modal.client module since they are now tied to a client instance

wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]
client: "modal.client._Client"

def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]):
def __init__(
self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType], client: "modal.client._Client"
):
self.wrapped_method = wrapped_method
self.client = client

@property
def name(self) -> str:
@@ -139,36 +145,66 @@ async def __call__(
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> ResponseType:
# TODO: implement Client tracking and retries
return await self.wrapped_method(req, timeout=timeout, metadata=metadata)
# it's important that this is run from the same event loop as the rpc context is "bound" to,
# i.e. the synchronicity event loop
assert synchronizer._is_inside_loop()
# TODO: incorporate retry_transient_errors here
if self.client.is_closed():
raise ClientClosed()
try:
return await self.client._rpc_context.create_task(
self.wrapped_method(req, timeout=timeout, metadata=metadata)
)
except asyncio.CancelledError:
raise ClientClosed()


class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]):
def __init__(
self,
wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType],
client: "modal.client._Client",
):
self.wrapped_method = wrapped_method
self.client = client

def open(
self,
*,
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> grpclib.client.Stream[RequestType, ResponseType]:
# TODO: implement Client tracking and unary_stream-wrapper
return self.wrapped_method.open(timeout=timeout, metadata=metadata)

async def unary_stream(
self,
request,
metadata: Optional[Any] = None,
):
"""Helper for making a unary-streaming gRPC request."""
# TODO: would be nice to put the Client.close tracking in `.open()` instead
# TODO: unit test that close triggers ClientClosed for streams
if self.client.is_closed():
raise ClientClosed()
try:
async with self.open(metadata=metadata) as stream:
await stream.send_message(request, end=True)
async for item in stream:
yield item
except asyncio.CancelledError:
raise ClientClosed()


async def unary_stream(
method: UnaryStreamWrapper[RequestType, ResponseType],
request: RequestType,
metadata: Optional[Any] = None,
) -> AsyncIterator[ResponseType]:
"""Helper for making a unary-streaming gRPC request."""
async with method.open(metadata=metadata) as stream:
await stream.send_message(request, end=True)
async for item in stream:
yield item
# TODO: remove this, since we have a method now
async for item in method.unary_stream(request, metadata):
yield item


async def retry_transient_errors(
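The wrapper changes above are the core of the PR: every unary RPC is now created as a task inside the client's TaskContext (bound to the synchronicity event loop), so closing the client cancels in-flight calls and callers see ClientClosed rather than a bare CancelledError. A rough sketch of that pattern in plain asyncio, where RpcRunner, run, and close are illustrative names rather than Modal's actual API:

import asyncio
from typing import Awaitable, TypeVar

T = TypeVar("T")

class ClientClosed(Exception):
    pass  # stand-in for modal.exception.ClientClosed

class RpcRunner:
    """Tracks in-flight RPC tasks so close() can drain or cancel them."""

    def __init__(self) -> None:
        self._closed = False
        self._tasks: set = set()

    def is_closed(self) -> bool:
        return self._closed

    async def run(self, coro: Awaitable[T]) -> T:
        if self._closed:
            raise ClientClosed()
        task = asyncio.ensure_future(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        try:
            return await task
        except asyncio.CancelledError:
            if self._closed:
                raise ClientClosed()  # cancelled because the client shut down
            raise  # cancelled for some unrelated reason, propagate as-is

    async def close(self, grace: float = 0.5) -> None:
        self._closed = True
        if self._tasks:
            # give quick RPCs a moment to finish, then cancel the stragglers
            await asyncio.wait(set(self._tasks), timeout=grace)
        for task in set(self._tasks):
            task.cancel()

The `if self._closed` check in run() mirrors the later commits ("Don't raise ClientClosed unless it is", "only raise ClientClosed if actually closed"), which make sure ordinary cancellation is not misreported as a client shutdown.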
3 changes: 1 addition & 2 deletions modal/app.py
@@ -30,7 +30,6 @@
from ._output import OutputManager
from ._utils.async_utils import synchronize_api
from ._utils.function_utils import FunctionInfo, is_global_object, is_top_level_function
from ._utils.grpc_utils import unary_stream
from ._utils.mount_utils import validate_volumes
from .client import _Client
from .cloud_bucket_mount import _CloudBucketMount
@@ -1028,7 +1027,7 @@ async def _logs(self, client: Optional[_Client] = None) -> AsyncGenerator[str, N
timeout=55,
last_entry_id=last_log_batch_entry_id,
)
async for log_batch in unary_stream(client.stub.AppGetLogs, request):
async for log_batch in client.stub.AppGetLogs.unary_stream(request):
if log_batch.entry_id:
# log_batch entry_id is empty for fd="server" messages from AppGetLogs
last_log_batch_entry_id = log_batch.entry_id
29 changes: 12 additions & 17 deletions modal/client.py
@@ -2,7 +2,7 @@
import asyncio
import platform
import warnings
from typing import AsyncIterator, Awaitable, Callable, ClassVar, Dict, Optional, Tuple
from typing import AsyncIterator, ClassVar, Dict, Optional, Tuple

import grpclib.client
from aiohttp import ClientConnectorError, ClientResponseError
@@ -14,7 +14,7 @@
from modal_version import __version__

from ._utils import async_utils
from ._utils.async_utils import synchronize_api
from ._utils.async_utils import TaskContext, synchronize_api
from ._utils.grpc_utils import create_channel, retry_transient_errors
from ._utils.http_utils import ClientSessionRegistry
from .config import _check_config, config, logger
@@ -82,6 +82,8 @@ async def _grpc_exc_string(exc: GRPCError, method_name: str, server_url: str, ti
class _Client:
_client_from_env: ClassVar[Optional["_Client"]] = None
_client_from_env_lock: ClassVar[Optional[asyncio.Lock]] = None
_rpc_context: TaskContext
_stub: Optional[api_grpc.ModalClientStub]

def __init__(
self,
@@ -99,11 +101,13 @@ def __init__(
self.version = version
self._authenticated = False
self.image_builder_version: Optional[str] = None
self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None
self._channel: Optional[grpclib.client.Channel] = None
self._stub: Optional[modal_api_grpc.ModalClientModal] = None
self._snapshotted = False

def is_closed(self) -> bool:
return self._channel is None

@property
def stub(self) -> modal_api_grpc.ModalClientModal:
"""mdmd:hidden"""
@@ -127,16 +131,17 @@ async def _open(self):
assert self._stub is None
metadata = _get_metadata(self.client_type, self._credentials, self.version)
self._channel = create_channel(self.server_url, metadata=metadata)
self._rpc_context = TaskContext(grace=0.5) # allow running rpcs to finish in 0.5s when closing client
await self._rpc_context.__aenter__()
grpclib_stub = api_grpc.ModalClientStub(self._channel)
self._stub = modal_api_grpc.ModalClientModal(grpclib_stub)
self._stub = modal_api_grpc.ModalClientModal(grpclib_stub, client=self)

async def _close(self, prep_for_restore: bool = False):
if self._pre_stop is not None:
logger.debug("Client: running pre-stop coroutine before shutting down")
await self._pre_stop() # type: ignore
await self._rpc_context.__aexit__(None, None, None) # wait for all rpcs to be finished/cancelled

if self._channel is not None:
self._channel.close()
self._channel = None

if prep_for_restore:
self._credentials = None
@@ -145,16 +150,6 @@ async def _close(self, prep_for_restore: bool = False):
# Remove cached client.
self.set_env_client(None)

def set_pre_stop(self, pre_stop: Callable[[], Awaitable[None]]):
Review comment (freider, author): This appears to not be used anymore 🕺

"""mdmd:hidden"""
# hack: stub.serve() gets into a losing race with the `on_shutdown` client
# teardown when an interrupt signal is received (eg. KeyboardInterrupt).
# By registering a pre-stop fn stub.serve() can have its teardown
# performed before the client is disconnected.
#
# ref: github.com/modal-labs/modal-client/pull/108
self._pre_stop = pre_stop

async def _init(self):
"""Connect to server and retrieve version information; raise appropriate error for various failures."""
logger.debug("Client: Starting")
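In client.py the TaskContext becomes part of the client lifecycle: _open() creates it with a 0.5 s grace period, _close() exits it before the channel is torn down, and is_closed() simply checks whether the channel is gone. Reusing the RpcRunner stand-in from the previous sketch, the shape is roughly as follows (MiniClient is illustrative, not Modal's API):

class MiniClient:
    def __init__(self) -> None:
        self._channel = None
        self._rpc_runner = None  # stands in for TaskContext(grace=0.5)

    def is_closed(self) -> bool:
        return self._channel is None

    async def open(self) -> None:
        self._channel = object()        # stands in for create_channel(...)
        self._rpc_runner = RpcRunner()  # see the previous sketch

    async def close(self) -> None:
        # drain or cancel in-flight RPCs first, then drop the channel,
        # which is what flips is_closed() to True
        await self._rpc_runner.close(grace=0.5)
        self._channel = None

The ordering matters: in-flight RPCs are finished or cancelled while the channel is still open, and only then is the channel dropped.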
8 changes: 4 additions & 4 deletions modal/dict.py
@@ -9,7 +9,7 @@
from ._resolver import Resolver
from ._serialization import deserialize, serialize
from ._utils.async_utils import TaskContext, synchronize_api
from ._utils.grpc_utils import retry_transient_errors, unary_stream
from ._utils.grpc_utils import retry_transient_errors
from ._utils.name_utils import check_object_name
from .client import _Client
from .config import logger
@@ -302,7 +302,7 @@ async def keys(self) -> AsyncIterator[Any]:
and results are unordered.
"""
req = api_pb2.DictContentsRequest(dict_id=self.object_id, keys=True)
async for resp in unary_stream(self._client.stub.DictContents, req):
async for resp in self._client.stub.DictContents.unary_stream(req):
yield deserialize(resp.key, self._client)

@live_method_gen
@@ -313,7 +313,7 @@ async def values(self) -> AsyncIterator[Any]:
and results are unordered.
"""
req = api_pb2.DictContentsRequest(dict_id=self.object_id, values=True)
async for resp in unary_stream(self._client.stub.DictContents, req):
async for resp in self._client.stub.DictContents.unary_stream(req):
yield deserialize(resp.value, self._client)

@live_method_gen
@@ -324,7 +324,7 @@ async def items(self) -> AsyncIterator[Tuple[Any, Any]]:
and results are unordered.
"""
req = api_pb2.DictContentsRequest(dict_id=self.object_id, keys=True, values=True)
async for resp in unary_stream(self._client.stub.DictContents, req):
async for resp in self._client.stub.DictContents.unary_stream(req):
yield (deserialize(resp.key, self._client), deserialize(resp.value, self._client))


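dict.py is one of the first call sites of the new UnaryStreamWrapper.unary_stream method, which applies the same guard to streaming calls. Continuing the RpcRunner sketch, a streaming analogue might look like this (guarded_stream and open_stream are illustrative names):

import asyncio
from typing import AsyncIterator, Callable

async def guarded_stream(runner: "RpcRunner", open_stream: Callable[[], AsyncIterator]) -> AsyncIterator:
    if runner.is_closed():
        raise ClientClosed()  # refuse to start a stream on a closed client
    try:
        async for item in open_stream():
            yield item
    except asyncio.CancelledError:
        # the stream was torn down, presumably because the client closed underneath it
        raise ClientClosed()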
4 changes: 4 additions & 0 deletions modal/exception.py
@@ -210,3 +210,7 @@ class InputCancellation(BaseException):

class ModuleNotMountable(Exception):
pass


class ClientClosed(Error):
pass
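Since ClientClosed subclasses modal's base Error, existing handlers that catch modal errors keep working, and callers that poll in a loop can treat it as a stop signal. A minimal caller-side illustration (do_rpc is a hypothetical stand-in for any stub call):

from modal.exception import ClientClosed  # added in this PR

async def poll_until_closed(do_rpc) -> None:
    while True:
        try:
            await do_rpc()
        except ClientClosed:
            break  # the client was shut down; stop polling instead of retrying forever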
6 changes: 3 additions & 3 deletions modal/image.py
@@ -22,7 +22,7 @@
from ._utils.async_utils import synchronize_api
from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES
from ._utils.function_utils import FunctionInfo
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
from .cloud_bucket_mount import _CloudBucketMount
from .config import config, logger, user_config_path
from .exception import InvalidError, NotFoundError, RemoteError, VersionError, deprecation_error, deprecation_warning
@@ -414,7 +414,7 @@ async def join():
nonlocal last_entry_id, result

request = api_pb2.ImageJoinStreamingRequest(image_id=image_id, timeout=55, last_entry_id=last_entry_id)
async for response in unary_stream(resolver.client.stub.ImageJoinStreaming, request):
async for response in resolver.client.stub.ImageJoinStreaming.unary_stream(request):
if response.entry_id:
last_entry_id = response.entry_id
if response.result.status:
@@ -1702,7 +1702,7 @@ async def _logs(self) -> AsyncGenerator[str, None]:
request = api_pb2.ImageJoinStreamingRequest(
image_id=self._object_id, timeout=55, last_entry_id=last_entry_id, include_logs_for_finished=True
)
async for response in unary_stream(self._client.stub.ImageJoinStreaming, request):
async for response in self._client.stub.ImageJoinStreaming.unary_stream(request):
if response.result.status:
return
if response.entry_id:
6 changes: 3 additions & 3 deletions modal/io_streams.py
@@ -7,7 +7,7 @@
from modal_proto import api_pb2

from ._utils.async_utils import synchronize_api
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
from .client import _Client

if TYPE_CHECKING:
Expand All @@ -23,7 +23,7 @@ async def _sandbox_logs_iterator(
timeout=55,
last_entry_id=last_entry_id,
)
async for log_batch in unary_stream(client.stub.SandboxGetLogs, req):
async for log_batch in client.stub.SandboxGetLogs.unary_stream(req):
last_entry_id = log_batch.entry_id

for message in log_batch.items:
@@ -42,7 +42,7 @@ async def _container_process_logs_iterator(
last_batch_index=last_entry_id or 0,
file_descriptor=file_descriptor,
)
async for batch in unary_stream(client.stub.ContainerExecGetOutput, req):
async for batch in client.stub.ContainerExecGetOutput.unary_stream(req):
if batch.HasField("exit_code"):
yield (None, batch.batch_index)
break