Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MOD-3535] snapshots capture references to modal objects which become #2164

2 changes: 1 addition & 1 deletion modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@ async def memory_snapshot(self) -> None:
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id or "")
)

await self._client._close(forget_credentials=True)
await self._client._close(prep_for_restore=True)

logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()
Expand Down
6 changes: 4 additions & 2 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None
self._channel: Optional[grpclib.client.Channel] = None
self._stub: Optional[api_grpc.ModalClientStub] = None
self._snapshotted = False

@property
def stub(self) -> api_grpc.ModalClientStub:
Expand All @@ -126,16 +127,17 @@ async def _open(self):
self._channel = create_channel(self.server_url, metadata=metadata)
self._stub = api_grpc.ModalClientStub(self._channel) # type: ignore

async def _close(self, forget_credentials: bool = False):
async def _close(self, prep_for_restore: bool = False):
if self._pre_stop is not None:
logger.debug("Client: running pre-stop coroutine before shutting down")
await self._pre_stop() # type: ignore

if self._channel is not None:
self._channel.close()

if forget_credentials:
if prep_for_restore:
self._credentials = None
self._snapshotted = True

# Remove cached client.
self.set_env_client(None)
Expand Down
8 changes: 8 additions & 0 deletions modal/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class _Object:
_object_id: str
_client: _Client
_is_hydrated: bool
_is_rehydrating: bool
thundergolfer marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def __init_subclass__(cls, type_prefix: Optional[str] = None):
Expand Down Expand Up @@ -77,6 +78,7 @@ def _init(
self._object_id = None
self._client = None
self._is_hydrated = False
self._is_rehydrated = False

self._initialize_from_empty()

Expand Down Expand Up @@ -214,6 +216,12 @@ def deps(self) -> Callable[..., List["_Object"]]:
async def resolve(self):
"""mdmd:hidden"""
if self._is_hydrated:
# memory snapshots capture references which must be rehydrated
# on restore to handle staleness.
if self._client._snapshotted and not self._is_rehydrated:
thundergolfer marked this conversation as resolved.
Show resolved Hide resolved
self._is_hydrated = False # un-hydrate and re-resolve
resolver = Resolver(await _Client.from_env())
await resolver.load(self)
return
elif not self._hydrate_lazily:
self._validate_is_hydrated()
Expand Down
44 changes: 42 additions & 2 deletions test/container_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

from modal import App, interact
from modal._container_io_manager import ContainerIOManager, _ContainerIOManager
from modal._utils.grpc_utils import create_channel, retry_transient_errors
from modal.client import _Client
from modal.exception import InvalidError
from modal.running_app import RunningApp
from modal_proto import api_pb2
from modal_proto import api_grpc, api_pb2


def my_f_1(x):
Expand Down Expand Up @@ -69,10 +70,49 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer):
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot()
# In-memory Client instance should have update credentials, not old credentials
# In-memory Client instance should have updated credentials, not old credentials
assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored")


def square(x):
pass


@pytest.mark.asyncio
async def test_container_snapshot_reference_capture(container_client, tmpdir, servicer):
app = App()
from modal import Function
from modal.runner import deploy_app

channel = create_channel(servicer.client_addr)
client_stub = api_grpc.ModalClientStub(channel)
app.function()(square)
app_name = "my-app"
app_id = deploy_app(app, app_name, client=container_client).app_id

f = Function.lookup(app_name, "square", client=container_client)
assert f.object_id == "fu-1"
await f.remote.aio()
assert f.object_id == "fu-1"

io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = temp_restore_path(tmpdir)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot()

# Stop the App, invalidating the fu- ID stored in `f`.
assert await retry_transient_errors(client_stub.AppStop, api_pb2.AppStopRequest(app_id=app_id))
# After snapshot-restore the previously looked-up Function should get refreshed and have the
# new fu- ID. ie. the ID should not be stale and invalid.
new_app_id = deploy_app(app, app_name, client=container_client).app_id
assert new_app_id != app_id
await f.remote.aio()
assert f.object_id == "fu-2"
channel.close()


@pytest.mark.asyncio
async def test_container_snapshot_restore_heartbeats(tmpdir, servicer):
client = _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret"))
Expand Down
Loading