Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WRK-200] memory snapshot causes clientclosed error for webapp #2367

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ async def _open(self):
self._owner_pid = os.getpid()

async def _close(self, prep_for_restore: bool = False):
logger.debug(f"Client ({id(self)}): closing")
self._closed = True
await self._cancellation_context.__aexit__(None, None, None) # wait for all rpcs to be finished/cancelled
if self._channel is not None:
Expand All @@ -156,7 +157,7 @@ async def _close(self, prep_for_restore: bool = False):

async def _init(self):
"""Connect to server and retrieve version information; raise appropriate error for various failures."""
logger.debug("Client: Starting")
logger.debug(f"Client ({id(self)}): Starting")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the object ID logged helps track the usage of fresh vs. stale clients when debugging snapshot issues.

Instead of having to catch stale client objects and refresh them, it'd be better if the client object itself could detect that it was stale and refresh itself. If this could work, then we could remove all current (and future) `if self.client._snapshotted`-type checks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree: we should make the client itself detect this. That should be doable through its `_call_unary` and `_call_stream` methods, which all RPC methods should now be going through.

_check_config()
try:
req = empty_pb2.Empty()
Expand Down Expand Up @@ -302,7 +303,7 @@ async def _call_safely(self, coro, readable_method: str):

if self.is_closed():
coro.close() # prevent "was never awaited"
raise ClientClosed()
raise ClientClosed(id(self))

current_event_loop = asyncio.get_running_loop()
if current_event_loop == self._cancellation_context_event_loop:
Expand All @@ -312,7 +313,7 @@ async def _call_safely(self, coro, readable_method: str):
return await self._cancellation_context.create_task(coro)
except asyncio.CancelledError:
if self.is_closed():
raise ClientClosed() from None
raise ClientClosed(id(self)) from None
raise # if the task is cancelled as part of synchronizer shutdown or similar, don't raise ClientClosed
else:
# this should be rare - mostly used in tests where rpc requests sometimes are triggered
Expand Down Expand Up @@ -406,6 +407,9 @@ async def __call__(
timeout: Optional[float] = None,
metadata: Optional[_MetadataLike] = None,
) -> ResponseType:
if self.client._snapshotted:
logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
self.client = await _Client.from_env()
return await self.client._call_unary(self._wrapped_method_name, req, timeout=timeout, metadata=metadata)


Expand All @@ -426,5 +430,8 @@ async def unary_stream(
request,
metadata: Optional[Any] = None,
):
if self.client._snapshotted:
logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
self.client = await _Client.from_env
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`from_env` is a method that needs to be called (with `()`), so this probably crashes 😬 - can we add a test that covers this case?

async for response in self.client._call_stream(self._wrapped_method_name, request, metadata=metadata):
yield response
7 changes: 5 additions & 2 deletions modal/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._resolver import Resolver
from ._utils.async_utils import synchronize_api
from .client import _Client
from .config import config
from .config import config, logger
from .exception import ExecutionError, InvalidError

O = TypeVar("O", bound="_Object")
Expand Down Expand Up @@ -223,10 +223,13 @@ async def resolve(self):
# memory snapshots capture references which must be rehydrated
# on restore to handle staleness.
if self._client._snapshotted and not self._is_rehydrated:
logger.debug(f"rehydrating {self} after snapshot")
self._is_hydrated = False # un-hydrate and re-resolve
resolver = Resolver(await _Client.from_env())
c = await _Client.from_env()
resolver = Resolver(c)
await resolver.load(self)
self._is_rehydrated = True
logger.debug(f"rehydrated {self} with client {id(c)}")
return
elif not self._hydrate_lazily:
self._validate_is_hydrated()
Expand Down
Loading