diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 4cd83d052..034f109bb 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -442,6 +442,7 @@ async def get_app_objects(self) -> RunningApp: environment_name=self._environment_name, tag_to_object_id=tag_to_object_id, object_handle_metadata=object_handle_metadata, + client=self._client, ) async def get_serialized_function(self) -> Tuple[Optional[Any], Optional[Callable[..., Any]]]: diff --git a/modal/app.py b/modal/app.py index 057911675..7f3d754fa 100644 --- a/modal/app.py +++ b/modal/app.py @@ -124,6 +124,7 @@ def foo(): """ _all_apps: ClassVar[Dict[Optional[str], List["_App"]]] = {} + _container_app: ClassVar[Optional[RunningApp]] = None _name: Optional[str] _description: Optional[str] @@ -429,6 +430,8 @@ def _init_container(self, client: _Client, running_app: RunningApp): self._running_app = running_app self._client = client + _App._container_app = running_app + # Hydrate objects on app for tag, object_id in running_app.tag_to_object_id.items(): if tag in self._indexed_objects: diff --git a/modal/runner.py b/modal/runner.py index 23fcf4c89..69ad01d15 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -57,7 +57,7 @@ async def _init_local_app_existing(client: _Client, existing_app_id: str) -> Run obj_resp = await retry_transient_errors(client.stub.AppGetObjects, obj_req) app_page_url = f"https://modal.com/apps/{existing_app_id}" # TODO (elias): this should come from the backend object_ids = {item.tag: item.object.object_id for item in obj_resp.items} - return RunningApp(existing_app_id, app_page_url=app_page_url, tag_to_object_id=object_ids) + return RunningApp(existing_app_id, app_page_url=app_page_url, tag_to_object_id=object_ids, client=client) async def _init_local_app_new( @@ -76,6 +76,7 @@ async def _init_local_app_new( logger.debug(f"Created new app with id {app_resp.app_id}") return RunningApp( app_resp.app_id, + client=client, app_page_url=app_resp.app_page_url, app_logs_url=app_resp.app_logs_url, environment_name=environment_name, diff --git a/modal/running_app.py b/modal/running_app.py index 22f76d9ed..7bbb46e34 100644 --- a/modal/running_app.py +++ b/modal/running_app.py @@ -4,6 +4,8 @@ from google.protobuf.message import Message +from .client import _Client + @dataclass class RunningApp: @@ -14,3 +16,4 @@ class RunningApp: tag_to_object_id: Dict[str, str] = field(default_factory=dict) object_handle_metadata: Dict[str, Optional[Message]] = field(default_factory=dict) interactive: bool = False + client: Optional[_Client] = None diff --git a/modal/sandbox.py b/modal/sandbox.py index 87591b81a..7590463bd 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -183,6 +183,8 @@ async def create( client: Optional[_Client] = None, _experimental_gpus: Sequence[GPU_T] = [], ) -> "_Sandbox": + from .app import _App + if environment_name is None: environment_name = config.get("environment") @@ -208,12 +210,19 @@ async def create( _experimental_scheduler_placement=_experimental_scheduler_placement, _experimental_gpus=_experimental_gpus, ) - if client is None: - if app and app._client: - client = app._client - else: - client = await _Client.from_env() - app_id: Optional[str] = app.app_id if app else None + + app_id: Optional[str] = None + app_client: Optional[_Client] = None + + if app is not None: + app_id = app.app_id + app_client = app._client + elif _App._container_app is not None: + app_id = _App._container_app.app_id + app_client = _App._container_app.client + + client = client or app_client or await _Client.from_env() + resolver = Resolver(client, environment_name=environment_name, app_id=app_id) await resolver.load(obj) return obj diff --git a/test/conftest.py b/test/conftest.py index cd7e8f4fc..3761fccfe 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -192,6 +192,7 @@ def __init__(self, blob_host, blobs): self.volume_reloads: Dict[str, int] = defaultdict(lambda: 0) self.sandbox_defs = [] + self.sandbox_app_id = None self.sandbox: asyncio.subprocess.Process = None self.sandbox_result: Optional[api_pb2.GenericResult] = None @@ -1181,6 +1182,7 @@ async def SandboxCreate(self, stream): stdin=asyncio.subprocess.PIPE, ) + self.sandbox_app_id = request.app_id self.sandbox_defs.append(request.definition) await stream.send_message(api_pb2.SandboxCreateResponse(sandbox_id="sb-123")) diff --git a/test/container_test.py b/test/container_test.py index b0571c2a1..ae905f99e 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -2065,3 +2065,9 @@ def test_max_concurrency(servicer): outputs = [deserialize(item.result.data, ret.client) for item in ret.items] assert n_inputs in outputs + + +@skip_github_non_linux +def test_sandbox_infers_app(servicer, event_loop): + _run_container(servicer, "test.supports.sandbox", "spawn_sandbox") + assert servicer.sandbox_app_id == "ap-1" diff --git a/test/supports/sandbox.py b/test/supports/sandbox.py new file mode 100644 index 000000000..0318d24d0 --- /dev/null +++ b/test/supports/sandbox.py @@ -0,0 +1,9 @@ +# Copyright Modal Labs 2024 +import modal + +app = modal.App() + + +@app.function() +def spawn_sandbox(x): + modal.Sandbox.create("bash", "-c", "echo bar")