Skip to content

Commit

Permalink
Add tunnel support to sandboxes
Browse files Browse the repository at this point in the history
This commit adds an `open_ports` argument to sandboxes, allowing users
to specify a list of ports to expose. It also adds .tunnels() to
sandboxes, allowing users to get the list active tunnels once they've
spun up.
  • Loading branch information
pawalt committed Sep 5, 2024
1 parent a539e82 commit 1ddf57b
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class _Sandbox(_Object, type_prefix="sb"):
_stderr: _StreamReader
_stdin: _StreamWriter
_task_id: Optional[str] = None
_tunnels: Optional[List[api_pb2.TunnelData]] = None

@staticmethod
def _new(
Expand All @@ -66,6 +67,8 @@ def _new(
block_network: bool = False,
volumes: Dict[Union[str, os.PathLike], Union[_Volume, _CloudBucketMount]] = {},
pty_info: Optional[api_pb2.PTYInfo] = None,
encrypted_ports: Sequence[int] = [],
unencrypted_ports: Sequence[int] = [],
_experimental_scheduler_placement: Optional[SchedulerPlacement] = None,
_experimental_gpus: Sequence[GPU_T] = [],
) -> "_Sandbox":
Expand Down Expand Up @@ -109,6 +112,9 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
for path, volume in validated_volumes
]

open_ports = [api_pb2.PortSpec(port=port, unencrypted=False) for port in encrypted_ports]
open_ports.extend([api_pb2.PortSpec(port=port, unencrypted=True) for port in unencrypted_ports])

ephemeral_disk = None # Ephemeral disk requests not supported on Sandboxes.
definition = api_pb2.Sandbox(
entrypoint_args=entrypoint_args,
Expand All @@ -129,6 +135,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
pty_info=pty_info,
scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
worker_id=config.get("worker_id"),
open_ports=api_pb2.PortSpecs(ports=open_ports),
)

# Note - `resolver.app_id` will be `None` for app-less sandboxes
Expand Down Expand Up @@ -165,6 +172,10 @@ async def create(
Union[str, os.PathLike], Union[_Volume, _CloudBucketMount]
] = {}, # Mount points for Modal Volumes and CloudBucketMounts
pty_info: Optional[api_pb2.PTYInfo] = None,
# List of ports to tunnel into the sandbox. Encrypted ports are tunneled with TLS.
encrypted_ports: Sequence[int] = [],
# List of ports to tunnel into the sandbox without encryption.
unencrypted_ports: Sequence[int] = [],
_experimental_scheduler_placement: Optional[
SchedulerPlacement
] = None, # Experimental controls over fine-grained scheduling (alpha).
Expand All @@ -191,6 +202,8 @@ async def create(
block_network=block_network,
volumes=volumes,
pty_info=pty_info,
encrypted_ports=encrypted_ports,
unencrypted_ports=unencrypted_ports,
_experimental_scheduler_placement=_experimental_scheduler_placement,
_experimental_gpus=_experimental_gpus,
)
Expand Down Expand Up @@ -244,6 +257,23 @@ async def wait(self, raise_on_termination: bool = True):
raise SandboxTerminatedError()
break

async def tunnels(self, timeout: int = 50) -> List[api_pb2.TunnelData]:
"""Get tunnel metadata for the sandbox."""

if self._tunnels:
return self._tunnels

req = api_pb2.SandboxGetTunnelsRequest(sandbox_id=self.object_id, timeout=timeout)
resp = await retry_transient_errors(self._client.stub.SandboxGetTunnels, req)

# If we couldn't get the tunnels in time, report the timeout.
if resp.result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT:
raise SandboxTimeoutError()

# Otherwise, we got the tunnels and can report the result.
self._tunnels = resp.tunnels
return resp.tunnels

async def terminate(self):
"""Terminate Sandbox execution.
Expand Down

0 comments on commit 1ddf57b

Please sign in to comment.