Skip to content

Commit

Permalink
Infer sandbox app if spawned from inside container (#2193)
Browse files Browse the repository at this point in the history
  • Loading branch information
aksh-at authored Sep 6, 2024
1 parent ad58ac1 commit 588c85e
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 7 deletions.
1 change: 1 addition & 0 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ async def get_app_objects(self) -> RunningApp:
environment_name=self._environment_name,
tag_to_object_id=tag_to_object_id,
object_handle_metadata=object_handle_metadata,
client=self._client,
)

async def get_serialized_function(self) -> Tuple[Optional[Any], Optional[Callable[..., Any]]]:
Expand Down
3 changes: 3 additions & 0 deletions modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def foo():
"""

_all_apps: ClassVar[Dict[Optional[str], List["_App"]]] = {}
_container_app: ClassVar[Optional[RunningApp]] = None

_name: Optional[str]
_description: Optional[str]
Expand Down Expand Up @@ -429,6 +430,8 @@ def _init_container(self, client: _Client, running_app: RunningApp):
self._running_app = running_app
self._client = client

_App._container_app = running_app

# Hydrate objects on app
for tag, object_id in running_app.tag_to_object_id.items():
if tag in self._indexed_objects:
Expand Down
3 changes: 2 additions & 1 deletion modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def _init_local_app_existing(client: _Client, existing_app_id: str) -> Run
obj_resp = await retry_transient_errors(client.stub.AppGetObjects, obj_req)
app_page_url = f"https://modal.com/apps/{existing_app_id}" # TODO (elias): this should come from the backend
object_ids = {item.tag: item.object.object_id for item in obj_resp.items}
return RunningApp(existing_app_id, app_page_url=app_page_url, tag_to_object_id=object_ids)
return RunningApp(existing_app_id, app_page_url=app_page_url, tag_to_object_id=object_ids, client=client)


async def _init_local_app_new(
Expand All @@ -76,6 +76,7 @@ async def _init_local_app_new(
logger.debug(f"Created new app with id {app_resp.app_id}")
return RunningApp(
app_resp.app_id,
client=client,
app_page_url=app_resp.app_page_url,
app_logs_url=app_resp.app_logs_url,
environment_name=environment_name,
Expand Down
3 changes: 3 additions & 0 deletions modal/running_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from google.protobuf.message import Message

from .client import _Client


@dataclass
class RunningApp:
Expand All @@ -14,3 +16,4 @@ class RunningApp:
tag_to_object_id: Dict[str, str] = field(default_factory=dict)
object_handle_metadata: Dict[str, Optional[Message]] = field(default_factory=dict)
interactive: bool = False
client: Optional[_Client] = None
21 changes: 15 additions & 6 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ async def create(
client: Optional[_Client] = None,
_experimental_gpus: Sequence[GPU_T] = [],
) -> "_Sandbox":
from .app import _App

if environment_name is None:
environment_name = config.get("environment")

Expand All @@ -208,12 +210,19 @@ async def create(
_experimental_scheduler_placement=_experimental_scheduler_placement,
_experimental_gpus=_experimental_gpus,
)
if client is None:
if app and app._client:
client = app._client
else:
client = await _Client.from_env()
app_id: Optional[str] = app.app_id if app else None

app_id: Optional[str] = None
app_client: Optional[_Client] = None

if app is not None:
app_id = app.app_id
app_client = app._client
elif _App._container_app is not None:
app_id = _App._container_app.app_id
app_client = _App._container_app.client

client = client or app_client or await _Client.from_env()

resolver = Resolver(client, environment_name=environment_name, app_id=app_id)
await resolver.load(obj)
return obj
Expand Down
2 changes: 2 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(self, blob_host, blobs):
self.volume_reloads: Dict[str, int] = defaultdict(lambda: 0)

self.sandbox_defs = []
self.sandbox_app_id = None
self.sandbox: asyncio.subprocess.Process = None
self.sandbox_result: Optional[api_pb2.GenericResult] = None

Expand Down Expand Up @@ -1181,6 +1182,7 @@ async def SandboxCreate(self, stream):
stdin=asyncio.subprocess.PIPE,
)

self.sandbox_app_id = request.app_id
self.sandbox_defs.append(request.definition)

await stream.send_message(api_pb2.SandboxCreateResponse(sandbox_id="sb-123"))
Expand Down
6 changes: 6 additions & 0 deletions test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,3 +2065,9 @@ def test_max_concurrency(servicer):

outputs = [deserialize(item.result.data, ret.client) for item in ret.items]
assert n_inputs in outputs


@skip_github_non_linux
def test_sandbox_infers_app(servicer, event_loop):
_run_container(servicer, "test.supports.sandbox", "spawn_sandbox")
assert servicer.sandbox_app_id == "ap-1"
9 changes: 9 additions & 0 deletions test/supports/sandbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright Modal Labs 2024
import modal

app = modal.App()


@app.function()
def spawn_sandbox(x):
modal.Sandbox.create("bash", "-c", "echo bar")

0 comments on commit 588c85e

Please sign in to comment.