Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make rpc calls using an asyncio task context MOD-3632 #2178

Merged
merged 31 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ff270a9
wip
freider Aug 30, 2024
09e5087
Move unary_stream to wrapper
freider Sep 2, 2024
04d9597
Fix heartbeat loop shutdown
freider Sep 2, 2024
d3c69f6
Add some basic layer of shutdown protection on streaming calls as well
freider Sep 2, 2024
f354814
Merge remote-tracking branch 'origin/main' into freider/client-close-…
freider Sep 5, 2024
27694fc
Merge remote-tracking branch 'origin/main' into freider/client-close-…
freider Sep 6, 2024
6a476e4
Fix deadlocks from not running rpcs within synchronizer loop
freider Sep 9, 2024
fd8a761
copyright
freider Sep 9, 2024
3481f1f
Fix typing, tests
freider Sep 9, 2024
1f95a67
Fix test
freider Sep 10, 2024
ff41708
debug wip
freider Sep 10, 2024
d66ba96
Revert "debug wip"
freider Sep 12, 2024
0985376
Merge remote-tracking branch 'origin/main' into freider/client-close-…
freider Sep 12, 2024
253f004
Cleanup
freider Sep 12, 2024
66804ba
Merge remote-tracking branch 'origin/main' into freider/client-close-…
freider Sep 13, 2024
24e915a
Merge remote-tracking branch 'origin/main' into freider/client-close-…
freider Sep 17, 2024
3c3ecd1
Don't raise ClientClosed unless it is
freider Sep 17, 2024
9834dd4
Ugly workaround for test that leak pending tasks
freider Sep 17, 2024
fb46cce
Revert "remove max_workers from DaemonizedThreadPool (#2238)" (#2242)
gongy Sep 17, 2024
c381469
Similar fix to UnaryUnary - only raise ClientClosed if actually closed
freider Sep 18, 2024
92eb0ed
Test
freider Sep 18, 2024
b451caa
Types
freider Sep 18, 2024
2b7515c
Merge branch 'main' into freider/client-close-stops-rpcs
freider Sep 18, 2024
017e74c
Wip
freider Sep 18, 2024
ca6f2af
Fix test flake
freider Sep 19, 2024
0f08b37
Merge remote-tracking branch 'origin/main' into freider/client-close-…
freider Sep 19, 2024
8d55b77
Flake?
freider Sep 19, 2024
26a82e9
Renames
freider Sep 19, 2024
86d75a8
Revert "Flake?"
freider Sep 19, 2024
1222711
types
freider Sep 19, 2024
a9981de
Skip pdm test on python 3.9
freider Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ._utils.package_utils import parse_major_minor_version
from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
from .config import config, logger
from .exception import InputCancellation, InvalidError, SerializationError
from .exception import ClientClosed, InputCancellation, InvalidError, SerializationError
from .running_app import RunningApp

if TYPE_CHECKING:
Expand Down Expand Up @@ -347,6 +347,9 @@ async def _run_heartbeat_loop(self):
# two subsequent cancellations on the same task at the moment
await asyncio.sleep(1.0)
continue
except ClientClosed:
logger.info("Stopping heartbeat loop due to client shutdown")
break
except Exception as exc:
# don't stop heartbeat loop if there are transient exceptions!
time_elapsed = time.monotonic() - t0
Expand Down Expand Up @@ -935,7 +938,6 @@ async def memory_snapshot(self) -> None:

logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()

# Turn heartbeats back on. This is safe since the snapshot RPC
# and the restore phase has finished.
self._waiting_for_memory_snapshot = False
Expand Down
4 changes: 2 additions & 2 deletions modal/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from modal_proto import api_pb2

from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
from ._utils.shell_utils import stream_from_stdin
from .client import _Client
from .config import logger
Expand Down Expand Up @@ -580,7 +580,7 @@ async def _get_logs():
last_entry_id=last_log_batch_entry_id,
)
log_batch: api_pb2.TaskLogsBatch
async for log_batch in unary_stream(client.stub.AppGetLogs, request):
async for log_batch in client.stub.AppGetLogs.unary_stream(request):
if log_batch.entry_id:
# log_batch entry_id is empty for fd="server" messages from AppGetLogs
last_log_batch_entry_id = log_batch.entry_id
Expand Down
2 changes: 2 additions & 0 deletions modal/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ async def loader():
self._local_uuid_to_future[obj.local_uuid] = cached_future
if deduplication_key is not None:
self._deduplication_cache[deduplication_key] = cached_future

# TODO(elias): print original exception/trace rather than the Resolver-internal trace
return await cached_future

def objects(self) -> List["_Object"]:
Expand Down
1 change: 1 addition & 0 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ async def stop(self):

# Cancel any remaining unfinished tasks.
task.cancel()
await asyncio.sleep(0) # wake up coroutines waiting for cancellations

async def __aexit__(self, exc_type, value, tb):
await self.stop()
Expand Down
4 changes: 2 additions & 2 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..exception import ExecutionError, FunctionTimeoutError, InvalidError, RemoteError
from ..mount import ROOT_DIR, _is_modal_path, _Mount
from .blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES, unary_stream
from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES


class FunctionInfoType(Enum):
Expand Down Expand Up @@ -373,7 +373,7 @@ async def _stream_function_call_data(
while True:
req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
try:
async for chunk in unary_stream(stub_fn, req):
async for chunk in stub_fn.unary_stream(req):
if chunk.index <= last_index:
continue
if chunk.data_blob_id:
Expand Down
64 changes: 10 additions & 54 deletions modal/_utils/grpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,15 @@
import platform
import socket
import time
import typing
import urllib.parse
import uuid
from typing import (
Any,
AsyncIterator,
Collection,
Dict,
Generic,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

import grpclib.client
Expand All @@ -36,6 +32,8 @@
RequestType = TypeVar("RequestType", bound=Message)
ResponseType = TypeVar("ResponseType", bound=Message)

if typing.TYPE_CHECKING:
import modal.client

# Monkey patches grpclib to have a Modal User Agent header.
grpclib.client.USER_AGENT = "modal-client/{version} ({sys}; {py}/{py_ver})'".format(
Expand Down Expand Up @@ -63,9 +61,6 @@ def connected(self):
return True


_SendType = TypeVar("_SendType")
_RecvType = TypeVar("_RecvType")

RETRYABLE_GRPC_STATUS_CODES = [
Status.DEADLINE_EXCEEDED,
Status.UNAVAILABLE,
Expand Down Expand Up @@ -118,61 +113,22 @@ async def send_request(event: grpclib.events.SendRequest) -> None:
return channel


_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]


class UnaryUnaryWrapper(Generic[RequestType, ResponseType]):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving these to the modal.client module since they are now tied to a client instance

wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]):
self.wrapped_method = wrapped_method

@property
def name(self) -> str:
return self.wrapped_method.name

async def __call__(
self,
req: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> ResponseType:
# TODO: implement Client tracking and retries
return await self.wrapped_method(req, timeout=timeout, metadata=metadata)


class UnaryStreamWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.UnaryStreamMethod[RequestType, ResponseType]):
self.wrapped_method = wrapped_method

def open(
self,
*,
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> grpclib.client.Stream[RequestType, ResponseType]:
# TODO: implement Client tracking and unary_stream-wrapper
return self.wrapped_method.open(timeout=timeout, metadata=metadata)
if typing.TYPE_CHECKING:
import modal.client


async def unary_stream(
method: UnaryStreamWrapper[RequestType, ResponseType],
method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
request: RequestType,
metadata: Optional[Any] = None,
) -> AsyncIterator[ResponseType]:
"""Helper for making a unary-streaming gRPC request."""
async with method.open(metadata=metadata) as stream:
await stream.send_message(request, end=True)
async for item in stream:
yield item
# TODO: remove this, since we have a method now
async for item in method.unary_stream(request, metadata):
yield item


async def retry_transient_errors(
fn: UnaryUnaryWrapper[RequestType, ResponseType],
fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
*args,
base_delay: float = 0.1,
max_delay: float = 1,
Expand Down
4 changes: 2 additions & 2 deletions modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ._output import OutputManager
from ._utils.async_utils import synchronize_api
from ._utils.function_utils import FunctionInfo, is_global_object, is_method_fn
from ._utils.grpc_utils import retry_transient_errors, unary_stream
from ._utils.grpc_utils import retry_transient_errors
from ._utils.mount_utils import validate_volumes
from .client import _Client
from .cloud_bucket_mount import _CloudBucketMount
Expand Down Expand Up @@ -1111,7 +1111,7 @@ async def _logs(self, client: Optional[_Client] = None) -> AsyncGenerator[str, N
timeout=55,
last_entry_id=last_log_batch_entry_id,
)
async for log_batch in unary_stream(client.stub.AppGetLogs, request):
async for log_batch in client.stub.AppGetLogs.unary_stream(request):
if log_batch.entry_id:
# log_batch entry_id is empty for fd="server" messages from AppGetLogs
last_log_batch_entry_id = log_batch.entry_id
Expand Down
Loading
Loading