Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MOD-3535] snapshots capture references to modal objects which become #2164

2 changes: 1 addition & 1 deletion modal/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def preload(self, obj, existing_object_id: Optional[str]):
await obj._preload(obj, self, existing_object_id)

async def load(self, obj: "_Object", existing_object_id: Optional[str] = None):
if obj._is_hydrated and obj._is_another_app:
if not obj._is_rehydrating and obj._is_hydrated and obj._is_another_app:
# No need to reload this, it won't typically change
if obj.local_uuid not in self._local_uuid_to_future:
# a bit dumb - but we still need to store a reference to the object here
Expand Down
2 changes: 2 additions & 0 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
self._pre_stop: Optional[Callable[[], Awaitable[None]]] = None
self._channel: Optional[grpclib.client.Channel] = None
self._stub: Optional[api_grpc.ModalClientStub] = None
self._snapshotted = False

@property
def stub(self) -> api_grpc.ModalClientStub:
Expand Down Expand Up @@ -136,6 +137,7 @@ async def _close(self, forget_credentials: bool = False):

if forget_credentials:
self._credentials = None
self._snapshotted = True
thundergolfer marked this conversation as resolved.
Show resolved Hide resolved

# Remove cached client.
self.set_env_client(None)
Expand Down
11 changes: 11 additions & 0 deletions modal/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class _Object:
_object_id: str
_client: _Client
_is_hydrated: bool
_is_rehydrating: bool
thundergolfer marked this conversation as resolved.
Show resolved Hide resolved
_is_rehydrated: bool
thundergolfer marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def __init_subclass__(cls, type_prefix: Optional[str] = None):
Expand Down Expand Up @@ -77,6 +79,8 @@ def _init(
self._object_id = None
self._client = None
self._is_hydrated = False
self._is_rehydrating = False
self._is_rehydrated = False

self._initialize_from_empty()

Expand Down Expand Up @@ -214,6 +218,13 @@ def deps(self) -> Callable[..., List["_Object"]]:
async def resolve(self):
"""mdmd:hidden"""
if self._is_hydrated:
# memory snapshots capture references which must be rehydrated
# on restore to handle staleness.
if self._client._snapshotted and not self._is_rehydrated:
thundergolfer marked this conversation as resolved.
Show resolved Hide resolved
self._is_rehydrating = True
resolver = Resolver(await _Client.from_env())
await resolver.load(self)
self._is_rehydrating = False
return
elif not self._hydrate_lazily:
self._validate_is_hydrated()
Expand Down
44 changes: 42 additions & 2 deletions test/container_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

from modal import App, interact
from modal._container_io_manager import ContainerIOManager, _ContainerIOManager
from modal._utils.grpc_utils import create_channel, retry_transient_errors
from modal.client import _Client
from modal.exception import InvalidError
from modal.running_app import RunningApp
from modal_proto import api_pb2
from modal_proto import api_grpc, api_pb2


def my_f_1(x):
Expand Down Expand Up @@ -69,10 +70,49 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer):
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot()
# In-memory Client instance should have update credentials, not old credentials
# In-memory Client instance should have updated credentials, not old credentials
assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored")


def square(x):
pass


@pytest.mark.asyncio
async def test_container_snapshot_reference_capture(container_client, tmpdir, servicer):
app = App()
from modal import Function
from modal.runner import deploy_app

channel = create_channel(servicer.client_addr)
client_stub = api_grpc.ModalClientStub(channel)
app.function()(square)
app_name = "my-app"
app_id = deploy_app(app, app_name, client=container_client).app_id

f = Function.lookup(app_name, "square", client=container_client)
assert f.object_id == "fu-1"
await f.remote.aio()
assert f.object_id == "fu-1"

io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = temp_restore_path(tmpdir)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot()

# Stop the App, invalidating the fu- ID stored in `f`.
assert await retry_transient_errors(client_stub.AppStop, api_pb2.AppStopRequest(app_id=app_id))
# After snapshot-restore the previously looked-up Function should get refreshed and have the
# new fu- ID. ie. the ID should not be stale and invalid.
new_app_id = deploy_app(app, app_name, client=container_client).app_id
assert new_app_id != app_id
await f.remote.aio()
assert f.object_id == "fu-2"
channel.close()


@pytest.mark.asyncio
async def test_container_snapshot_restore_heartbeats(tmpdir, servicer):
client = _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret"))
Expand Down
Loading