From 1ddf57b6d80c0d46ea52e993f4b1825c0884a1b2 Mon Sep 17 00:00:00 2001 From: Peyton Walters Date: Thu, 5 Sep 2024 15:00:40 +0000 Subject: [PATCH] Add tunnel support to sandboxes This commit adds an `open_ports` argument to sandboxes, allowing users to specify a list of ports to expose. It also adds .tunnels() to sandboxes, allowing users to get the list active tunnels once they've spun up. --- modal/sandbox.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/modal/sandbox.py b/modal/sandbox.py index a2a3ae178..359761cb7 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -48,6 +48,7 @@ class _Sandbox(_Object, type_prefix="sb"): _stderr: _StreamReader _stdin: _StreamWriter _task_id: Optional[str] = None + _tunnels: Optional[List[api_pb2.TunnelData]] = None @staticmethod def _new( @@ -66,6 +67,8 @@ def _new( block_network: bool = False, volumes: Dict[Union[str, os.PathLike], Union[_Volume, _CloudBucketMount]] = {}, pty_info: Optional[api_pb2.PTYInfo] = None, + encrypted_ports: Sequence[int] = [], + unencrypted_ports: Sequence[int] = [], _experimental_scheduler_placement: Optional[SchedulerPlacement] = None, _experimental_gpus: Sequence[GPU_T] = [], ) -> "_Sandbox": @@ -109,6 +112,9 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona for path, volume in validated_volumes ] + open_ports = [api_pb2.PortSpec(port=port, unencrypted=False) for port in encrypted_ports] + open_ports.extend([api_pb2.PortSpec(port=port, unencrypted=True) for port in unencrypted_ports]) + ephemeral_disk = None # Ephemeral disk requests not supported on Sandboxes. definition = api_pb2.Sandbox( entrypoint_args=entrypoint_args, @@ -129,6 +135,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona pty_info=pty_info, scheduler_placement=scheduler_placement.proto if scheduler_placement else None, worker_id=config.get("worker_id"), + open_ports=api_pb2.PortSpecs(ports=open_ports), ) # Note - `resolver.app_id` will be `None` for app-less sandboxes @@ -165,6 +172,10 @@ async def create( Union[str, os.PathLike], Union[_Volume, _CloudBucketMount] ] = {}, # Mount points for Modal Volumes and CloudBucketMounts pty_info: Optional[api_pb2.PTYInfo] = None, + # List of ports to tunnel into the sandbox. Encrypted ports are tunneled with TLS. + encrypted_ports: Sequence[int] = [], + # List of ports to tunnel into the sandbox without encryption. + unencrypted_ports: Sequence[int] = [], _experimental_scheduler_placement: Optional[ SchedulerPlacement ] = None, # Experimental controls over fine-grained scheduling (alpha). @@ -191,6 +202,8 @@ async def create( block_network=block_network, volumes=volumes, pty_info=pty_info, + encrypted_ports=encrypted_ports, + unencrypted_ports=unencrypted_ports, _experimental_scheduler_placement=_experimental_scheduler_placement, _experimental_gpus=_experimental_gpus, ) @@ -244,6 +257,23 @@ async def wait(self, raise_on_termination: bool = True): raise SandboxTerminatedError() break + async def tunnels(self, timeout: int = 50) -> List[api_pb2.TunnelData]: + """Get tunnel metadata for the sandbox.""" + + if self._tunnels: + return self._tunnels + + req = api_pb2.SandboxGetTunnelsRequest(sandbox_id=self.object_id, timeout=timeout) + resp = await retry_transient_errors(self._client.stub.SandboxGetTunnels, req) + + # If we couldn't get the tunnels in time, report the timeout. + if resp.result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT: + raise SandboxTimeoutError() + + # Otherwise, we got the tunnels and can report the result. + self._tunnels = resp.tunnels + return resp.tunnels + async def terminate(self): """Terminate Sandbox execution.