Skip to content

Commit

Permalink
Merge branch 'main' into ksalahi/i6pn-sandbox2
Browse files Browse the repository at this point in the history
  • Loading branch information
TheQuantumFractal authored Sep 5, 2024
2 parents ef9a090 + 860348a commit 78884df
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 48 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ We appreciate your patience while we speedily work towards a stable release of t

<!-- NEW CONTENT GENERATED BELOW. PLEASE PRESERVE THIS COMMENT. -->

### 0.64.87 (2024-09-05)

Sandboxes now support port tunneling. Ports can be exposed via the `open_ports` argument, and a list of active tunnels can be retrieved via the `.tunnels()` method.



### 0.64.67 (2024-08-30)

- Fix a regression in `modal launch` behavior not showing progress output when starting the container.
Expand Down
4 changes: 2 additions & 2 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
from ._utils.async_utils import TaskContext, synchronizer
from ._utils.function_utils import (
LocalFunctionError,
callable_has_non_self_params,
is_async as get_is_async,
is_global_object,
method_has_params,
)
from .app import App, _App
from .client import Client, _Client
Expand Down Expand Up @@ -684,7 +684,7 @@ def call_lifecycle_functions(
for func in funcs:
# We are deprecating parameterized exit methods but want to gracefully handle old code.
# We can remove this once the deprecation in the actual @exit decorator is enforced.
args = (None, None, None) if method_has_params(func) else ()
args = (None, None, None) if callable_has_non_self_params(func) else ()
# in case func is non-async, it's executed here and sigint will by default
# interrupt it using a KeyboardInterrupt exception
res = func(*args)
Expand Down
28 changes: 20 additions & 8 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,16 +320,28 @@ def is_nullary(self):
return True


def method_has_params(f: Callable[..., Any]) -> bool:
"""Return True if a method (bound or unbound) has parameters other than self.
def callable_has_non_self_params(f: Callable[..., Any]) -> bool:
"""Return True if a callable (function, bound method, or unbound method) has parameters other than self.
Used for deprecation of @exit() parameters.
Used to ensure that @exit(), @asgi_app, and @wsgi_app functions don't have parameters.
"""
num_params = len(inspect.signature(f).parameters)
if hasattr(f, "__self__"):
return num_params > 0
else:
return num_params > 1
return any(param.name != "self" for param in inspect.signature(f).parameters.values())


def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:
"""Return True if a callable (function, bound method, or unbound method) has non-default parameters other than self.
Used for deprecation of default parameters in @asgi_app and @wsgi_app functions.
"""
for param in inspect.signature(f).parameters.values():
if param.name == "self":
continue

if param.default != inspect.Parameter.empty:
continue

return True
return False


async def _stream_function_call_data(
Expand Down
44 changes: 34 additions & 10 deletions modal/partial_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from modal_proto import api_pb2

from ._utils.async_utils import synchronize_api, synchronizer
from ._utils.function_utils import method_has_params
from ._utils.function_utils import callable_has_non_self_non_default_params, callable_has_non_self_params
from .config import logger
from .exception import InvalidError, deprecation_error, deprecation_warning
from .functions import _Function
Expand Down Expand Up @@ -191,7 +191,7 @@ def f(self):
...
```
"""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@method()`.")

if keep_warm is not None:
Expand Down Expand Up @@ -264,7 +264,7 @@ def _web_endpoint(
if isinstance(_warn_parentheses_missing, str):
# Probably passing the method string as a positional argument.
raise InvalidError('Positional arguments are not allowed. Suggestion: `@web_endpoint(method="GET")`.')
elif _warn_parentheses_missing:
elif _warn_parentheses_missing is not None:
raise InvalidError(
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@web_endpoint()`."
)
Expand Down Expand Up @@ -334,12 +334,24 @@ def create_asgi() -> Callable:
"""
if isinstance(_warn_parentheses_missing, str):
raise InvalidError('Positional arguments are not allowed. Suggestion: `@asgi_app(label="foo")`.')
elif _warn_parentheses_missing:
elif _warn_parentheses_missing is not None:
raise InvalidError(
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@asgi_app()`."
)

def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
if callable_has_non_self_params(raw_f):
if callable_has_non_self_non_default_params(raw_f):
raise InvalidError(
f"ASGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#asgi."
)
else:
deprecation_warning(
(2024, 9, 4),
f"ASGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - "
f"Modal will drop support for default parameters in a future release.",
)

if not wait_for_response:
deprecation_warning(
(2024, 5, 13),
Expand Down Expand Up @@ -394,12 +406,24 @@ def create_wsgi() -> Callable:
"""
if isinstance(_warn_parentheses_missing, str):
raise InvalidError('Positional arguments are not allowed. Suggestion: `@wsgi_app(label="foo")`.')
elif _warn_parentheses_missing:
elif _warn_parentheses_missing is not None:
raise InvalidError(
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@wsgi_app()`."
)

def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
if callable_has_non_self_params(raw_f):
if callable_has_non_self_non_default_params(raw_f):
raise InvalidError(
f"WSGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#wsgi."
)
else:
deprecation_warning(
(2024, 9, 4),
f"WSGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - "
f"Modal will drop support for default parameters in a future release.",
)

if not wait_for_response:
deprecation_warning(
(2024, 5, 13),
Expand Down Expand Up @@ -508,7 +532,7 @@ def download_models(self):
LlamaTokenizer.from_pretrained(base_model)
```
"""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@build()`.")

def wrapper(f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction:
Expand All @@ -531,7 +555,7 @@ def _enter(
"""Decorator for methods which should be executed when a new container is started.
See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#enter) for more information."""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@enter()`.")

if snap:
Expand Down Expand Up @@ -561,14 +585,14 @@ def _exit(_warn_parentheses_missing=None) -> Callable[[ExitHandlerType], _Partia
"""Decorator for methods which should be executed when a container is about to exit.
See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#exit) for more information."""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@exit()`.")

def wrapper(f: ExitHandlerType) -> _PartialFunction:
if isinstance(f, _PartialFunction):
_disallow_wrapping_method(f, "exit")

if method_has_params(f):
if callable_has_non_self_params(f):
message = (
"Support for decorating parameterized methods with `@exit` has been deprecated."
" Please update your code by removing the parameters."
Expand Down Expand Up @@ -601,7 +625,7 @@ async def batched_multiply(xs: list[int], ys: list[int]) -> list[int]:
See the [dynamic batching guide](https://modal.com/docs/guide/dynamic-batching) for more information.
"""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError(
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@batched()`."
)
Expand Down
14 changes: 6 additions & 8 deletions modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import time
from multiprocessing.synchronize import Event
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Dict, List, Optional, TypeVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, TypeVar

from grpclib import GRPCError, Status
from synchronicity.async_wrap import asynccontextmanager
Expand All @@ -27,7 +27,7 @@
InvalidError,
RemoteError,
_CliUserExecutionError,
deprecation_warning,
deprecation_error,
)
from .execution_context import is_local
from .object import _Object
Expand Down Expand Up @@ -533,20 +533,18 @@ async def _interactive_shell(_app: _App, cmds: List[str], environment_name: str
raise


def _run_stub(*args: Any, **kwargs: Any) -> AsyncGenerator[_App, None]:
def _run_stub(*args: Any, **kwargs: Any):
"""mdmd:hidden
`run_stub` has been renamed to `run_app` and is deprecated. Please update your code.
"""
deprecation_warning(
deprecation_error(
(2024, 5, 1), "`run_stub` has been renamed to `run_app` and is deprecated. Please update your code."
)
return _run_app(*args, **kwargs)


def _deploy_stub(*args: Any, **kwargs: Any) -> Coroutine[Any, Any, DeployResult]:
def _deploy_stub(*args: Any, **kwargs: Any):
"""`deploy_stub` has been renamed to `deploy_app` and is deprecated. Please update your code."""
deprecation_warning((2024, 5, 1), str(_deploy_stub.__doc__))
return _deploy_app(*args, **kwargs)
deprecation_error((2024, 5, 1), str(_deploy_stub.__doc__))


run_app = synchronize_api(_run_app)
Expand Down
30 changes: 30 additions & 0 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class _Sandbox(_Object, type_prefix="sb"):
_stderr: _StreamReader
_stdin: _StreamWriter
_task_id: Optional[str] = None
_tunnels: Optional[List[api_pb2.TunnelData]] = None

@staticmethod
def _new(
Expand All @@ -66,6 +67,8 @@ def _new(
block_network: bool = False,
volumes: Dict[Union[str, os.PathLike], Union[_Volume, _CloudBucketMount]] = {},
pty_info: Optional[api_pb2.PTYInfo] = None,
encrypted_ports: Sequence[int] = [],
unencrypted_ports: Sequence[int] = [],
_experimental_scheduler_placement: Optional[SchedulerPlacement] = None,
_experimental_gpus: Sequence[GPU_T] = [],
) -> "_Sandbox":
Expand Down Expand Up @@ -109,6 +112,9 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
for path, volume in validated_volumes
]

open_ports = [api_pb2.PortSpec(port=port, unencrypted=False) for port in encrypted_ports]
open_ports.extend([api_pb2.PortSpec(port=port, unencrypted=True) for port in unencrypted_ports])

ephemeral_disk = None # Ephemeral disk requests not supported on Sandboxes.
definition = api_pb2.Sandbox(
entrypoint_args=entrypoint_args,
Expand All @@ -130,6 +136,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
worker_id=config.get("worker_id"),
i6pn_enabled=config.get("i6pn_enabled"),
open_ports=api_pb2.PortSpecs(ports=open_ports),
)

# Note - `resolver.app_id` will be `None` for app-less sandboxes
Expand Down Expand Up @@ -166,6 +173,10 @@ async def create(
Union[str, os.PathLike], Union[_Volume, _CloudBucketMount]
] = {}, # Mount points for Modal Volumes and CloudBucketMounts
pty_info: Optional[api_pb2.PTYInfo] = None,
# List of ports to tunnel into the sandbox. Encrypted ports are tunneled with TLS.
encrypted_ports: Sequence[int] = [],
# List of ports to tunnel into the sandbox without encryption.
unencrypted_ports: Sequence[int] = [],
_experimental_scheduler_placement: Optional[
SchedulerPlacement
] = None, # Experimental controls over fine-grained scheduling (alpha).
Expand All @@ -192,6 +203,8 @@ async def create(
block_network=block_network,
volumes=volumes,
pty_info=pty_info,
encrypted_ports=encrypted_ports,
unencrypted_ports=unencrypted_ports,
_experimental_scheduler_placement=_experimental_scheduler_placement,
_experimental_gpus=_experimental_gpus,
)
Expand Down Expand Up @@ -245,6 +258,23 @@ async def wait(self, raise_on_termination: bool = True):
raise SandboxTerminatedError()
break

async def tunnels(self, timeout: int = 50) -> List[api_pb2.TunnelData]:
"""Get tunnel metadata for the sandbox."""

if self._tunnels:
return self._tunnels

req = api_pb2.SandboxGetTunnelsRequest(sandbox_id=self.object_id, timeout=timeout)
resp = await retry_transient_errors(self._client.stub.SandboxGetTunnels, req)

# If we couldn't get the tunnels in time, report the timeout.
if resp.result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT:
raise SandboxTimeoutError()

# Otherwise, we got the tunnels and can report the result.
self._tunnels = resp.tunnels
return resp.tunnels

async def terminate(self):
"""Terminate Sandbox execution.
Expand Down
5 changes: 2 additions & 3 deletions modal/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .cli.import_refs import import_app
from .client import _Client
from .config import config
from .exception import deprecation_warning
from .exception import deprecation_error
from .runner import _run_app, serve_update

if TYPE_CHECKING:
Expand Down Expand Up @@ -123,8 +123,7 @@ async def _serve_app(


def _serve_stub(*args, **kwargs):
deprecation_warning((2024, 5, 1), "`serve_stub` is deprecated. Please use `serve_app` instead.")
return _run_app(*args, **kwargs)
deprecation_error((2024, 5, 1), "`serve_stub` is deprecated. Please use `serve_app` instead.")


serve_app = synchronize_api(_serve_app)
Expand Down
2 changes: 1 addition & 1 deletion modal_version/_version_generated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright Modal Labs 2024

# Note: Reset this value to -1 whenever you make a minor `0.X` release of the client.
build_number = 84 # git: ed5cac4
build_number = 87 # git: c8cfaf9
2 changes: 1 addition & 1 deletion test/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def test_stub():
def test_deploy_stub(servicer, client):
app = App("xyz")
deploy_app(app, client=client)
with pytest.warns(match="deploy_app"):
with pytest.raises(DeprecationError, match="deploy_app"):
deploy_stub(app, client=client)


Expand Down
Loading

0 comments on commit 78884df

Please sign in to comment.