diff --git a/CHANGELOG.md b/CHANGELOG.md index 52f2456ba..206a8138d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ We appreciate your patience while we speedily work towards a stable release of t +### 0.64.87 (2024-09-05) + +Sandboxes now support port tunneling. Ports can be exposed via the `open_ports` argument, and a list of active tunnels can be retrieved via the `.tunnels()` method. + + + ### 0.64.67 (2024-08-30) - Fix a regression in `modal launch` behavior not showing progress output when starting the container. diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index fcd9477c3..841f92958 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -42,9 +42,9 @@ from ._utils.async_utils import TaskContext, synchronizer from ._utils.function_utils import ( LocalFunctionError, + callable_has_non_self_params, is_async as get_is_async, is_global_object, - method_has_params, ) from .app import App, _App from .client import Client, _Client @@ -684,7 +684,7 @@ def call_lifecycle_functions( for func in funcs: # We are deprecating parameterized exit methods but want to gracefully handle old code. # We can remove this once the deprecation in the actual @exit decorator is enforced. - args = (None, None, None) if method_has_params(func) else () + args = (None, None, None) if callable_has_non_self_params(func) else () # in case func is non-async, it's executed here and sigint will by default # interrupt it using a KeyboardInterrupt exception res = func(*args) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index 44dbcd642..097cf5eec 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -320,16 +320,28 @@ def is_nullary(self): return True -def method_has_params(f: Callable[..., Any]) -> bool: - """Return True if a method (bound or unbound) has parameters other than self. +def callable_has_non_self_params(f: Callable[..., Any]) -> bool: + """Return True if a callable (function, bound method, or unbound method) has parameters other than self. - Used for deprecation of @exit() parameters. + Used to ensure that @exit(), @asgi_app, and @wsgi_app functions don't have parameters. """ - num_params = len(inspect.signature(f).parameters) - if hasattr(f, "__self__"): - return num_params > 0 - else: - return num_params > 1 + return any(param.name != "self" for param in inspect.signature(f).parameters.values()) + + +def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool: + """Return True if a callable (function, bound method, or unbound method) has non-default parameters other than self. + + Used for deprecation of default parameters in @asgi_app and @wsgi_app functions. + """ + for param in inspect.signature(f).parameters.values(): + if param.name == "self": + continue + + if param.default != inspect.Parameter.empty: + continue + + return True + return False async def _stream_function_call_data( diff --git a/modal/partial_function.py b/modal/partial_function.py index 3f17ff37e..181a74f5c 100644 --- a/modal/partial_function.py +++ b/modal/partial_function.py @@ -18,7 +18,7 @@ from modal_proto import api_pb2 from ._utils.async_utils import synchronize_api, synchronizer -from ._utils.function_utils import method_has_params +from ._utils.function_utils import callable_has_non_self_non_default_params, callable_has_non_self_params from .config import logger from .exception import InvalidError, deprecation_error, deprecation_warning from .functions import _Function @@ -191,7 +191,7 @@ def f(self): ... ``` """ - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@method()`.") if keep_warm is not None: @@ -264,7 +264,7 @@ def _web_endpoint( if isinstance(_warn_parentheses_missing, str): # Probably passing the method string as a positional argument. raise InvalidError('Positional arguments are not allowed. Suggestion: `@web_endpoint(method="GET")`.') - elif _warn_parentheses_missing: + elif _warn_parentheses_missing is not None: raise InvalidError( "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@web_endpoint()`." ) @@ -334,12 +334,24 @@ def create_asgi() -> Callable: """ if isinstance(_warn_parentheses_missing, str): raise InvalidError('Positional arguments are not allowed. Suggestion: `@asgi_app(label="foo")`.') - elif _warn_parentheses_missing: + elif _warn_parentheses_missing is not None: raise InvalidError( "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@asgi_app()`." ) def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: + if callable_has_non_self_params(raw_f): + if callable_has_non_self_non_default_params(raw_f): + raise InvalidError( + f"ASGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#asgi." + ) + else: + deprecation_warning( + (2024, 9, 4), + f"ASGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - " + f"Modal will drop support for default parameters in a future release.", + ) + if not wait_for_response: deprecation_warning( (2024, 5, 13), @@ -394,12 +406,24 @@ def create_wsgi() -> Callable: """ if isinstance(_warn_parentheses_missing, str): raise InvalidError('Positional arguments are not allowed. Suggestion: `@wsgi_app(label="foo")`.') - elif _warn_parentheses_missing: + elif _warn_parentheses_missing is not None: raise InvalidError( "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@wsgi_app()`." ) def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: + if callable_has_non_self_params(raw_f): + if callable_has_non_self_non_default_params(raw_f): + raise InvalidError( + f"WSGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#wsgi." + ) + else: + deprecation_warning( + (2024, 9, 4), + f"WSGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - " + f"Modal will drop support for default parameters in a future release.", + ) + if not wait_for_response: deprecation_warning( (2024, 5, 13), @@ -508,7 +532,7 @@ def download_models(self): LlamaTokenizer.from_pretrained(base_model) ``` """ - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@build()`.") def wrapper(f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction: @@ -531,7 +555,7 @@ def _enter( """Decorator for methods which should be executed when a new container is started. See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#enter) for more information.""" - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@enter()`.") if snap: @@ -561,14 +585,14 @@ def _exit(_warn_parentheses_missing=None) -> Callable[[ExitHandlerType], _Partia """Decorator for methods which should be executed when a container is about to exit. See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#exit) for more information.""" - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@exit()`.") def wrapper(f: ExitHandlerType) -> _PartialFunction: if isinstance(f, _PartialFunction): _disallow_wrapping_method(f, "exit") - if method_has_params(f): + if callable_has_non_self_params(f): message = ( "Support for decorating parameterized methods with `@exit` has been deprecated." " Please update your code by removing the parameters." @@ -601,7 +625,7 @@ async def batched_multiply(xs: list[int], ys: list[int]) -> list[int]: See the [dynamic batching guide](https://modal.com/docs/guide/dynamic-batching) for more information. """ - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError( "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@batched()`." ) diff --git a/modal/runner.py b/modal/runner.py index a4db38bd1..23fcf4c89 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -4,7 +4,7 @@ import os import time from multiprocessing.synchronize import Event -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Dict, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, TypeVar from grpclib import GRPCError, Status from synchronicity.async_wrap import asynccontextmanager @@ -27,7 +27,7 @@ InvalidError, RemoteError, _CliUserExecutionError, - deprecation_warning, + deprecation_error, ) from .execution_context import is_local from .object import _Object @@ -533,20 +533,18 @@ async def _interactive_shell(_app: _App, cmds: List[str], environment_name: str raise -def _run_stub(*args: Any, **kwargs: Any) -> AsyncGenerator[_App, None]: +def _run_stub(*args: Any, **kwargs: Any): """mdmd:hidden `run_stub` has been renamed to `run_app` and is deprecated. Please update your code. """ - deprecation_warning( + deprecation_error( (2024, 5, 1), "`run_stub` has been renamed to `run_app` and is deprecated. Please update your code." ) - return _run_app(*args, **kwargs) -def _deploy_stub(*args: Any, **kwargs: Any) -> Coroutine[Any, Any, DeployResult]: +def _deploy_stub(*args: Any, **kwargs: Any): """`deploy_stub` has been renamed to `deploy_app` and is deprecated. Please update your code.""" - deprecation_warning((2024, 5, 1), str(_deploy_stub.__doc__)) - return _deploy_app(*args, **kwargs) + deprecation_error((2024, 5, 1), str(_deploy_stub.__doc__)) run_app = synchronize_api(_run_app) diff --git a/modal/sandbox.py b/modal/sandbox.py index 4b8e059a0..87591b81a 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -48,6 +48,7 @@ class _Sandbox(_Object, type_prefix="sb"): _stderr: _StreamReader _stdin: _StreamWriter _task_id: Optional[str] = None + _tunnels: Optional[List[api_pb2.TunnelData]] = None @staticmethod def _new( @@ -66,6 +67,8 @@ def _new( block_network: bool = False, volumes: Dict[Union[str, os.PathLike], Union[_Volume, _CloudBucketMount]] = {}, pty_info: Optional[api_pb2.PTYInfo] = None, + encrypted_ports: Sequence[int] = [], + unencrypted_ports: Sequence[int] = [], _experimental_scheduler_placement: Optional[SchedulerPlacement] = None, _experimental_gpus: Sequence[GPU_T] = [], ) -> "_Sandbox": @@ -109,6 +112,9 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona for path, volume in validated_volumes ] + open_ports = [api_pb2.PortSpec(port=port, unencrypted=False) for port in encrypted_ports] + open_ports.extend([api_pb2.PortSpec(port=port, unencrypted=True) for port in unencrypted_ports]) + ephemeral_disk = None # Ephemeral disk requests not supported on Sandboxes. definition = api_pb2.Sandbox( entrypoint_args=entrypoint_args, @@ -130,6 +136,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona scheduler_placement=scheduler_placement.proto if scheduler_placement else None, worker_id=config.get("worker_id"), i6pn_enabled=config.get("i6pn_enabled"), + open_ports=api_pb2.PortSpecs(ports=open_ports), ) # Note - `resolver.app_id` will be `None` for app-less sandboxes @@ -166,6 +173,10 @@ async def create( Union[str, os.PathLike], Union[_Volume, _CloudBucketMount] ] = {}, # Mount points for Modal Volumes and CloudBucketMounts pty_info: Optional[api_pb2.PTYInfo] = None, + # List of ports to tunnel into the sandbox. Encrypted ports are tunneled with TLS. + encrypted_ports: Sequence[int] = [], + # List of ports to tunnel into the sandbox without encryption. + unencrypted_ports: Sequence[int] = [], _experimental_scheduler_placement: Optional[ SchedulerPlacement ] = None, # Experimental controls over fine-grained scheduling (alpha). @@ -192,6 +203,8 @@ async def create( block_network=block_network, volumes=volumes, pty_info=pty_info, + encrypted_ports=encrypted_ports, + unencrypted_ports=unencrypted_ports, _experimental_scheduler_placement=_experimental_scheduler_placement, _experimental_gpus=_experimental_gpus, ) @@ -245,6 +258,23 @@ async def wait(self, raise_on_termination: bool = True): raise SandboxTerminatedError() break + async def tunnels(self, timeout: int = 50) -> List[api_pb2.TunnelData]: + """Get tunnel metadata for the sandbox.""" + + if self._tunnels: + return self._tunnels + + req = api_pb2.SandboxGetTunnelsRequest(sandbox_id=self.object_id, timeout=timeout) + resp = await retry_transient_errors(self._client.stub.SandboxGetTunnels, req) + + # If we couldn't get the tunnels in time, report the timeout. + if resp.result.status == api_pb2.GenericResult.GENERIC_STATUS_TIMEOUT: + raise SandboxTimeoutError() + + # Otherwise, we got the tunnels and can report the result. + self._tunnels = resp.tunnels + return resp.tunnels + async def terminate(self): """Terminate Sandbox execution. diff --git a/modal/serving.py b/modal/serving.py index a0ebe3c63..391318db7 100644 --- a/modal/serving.py +++ b/modal/serving.py @@ -16,7 +16,7 @@ from .cli.import_refs import import_app from .client import _Client from .config import config -from .exception import deprecation_warning +from .exception import deprecation_error from .runner import _run_app, serve_update if TYPE_CHECKING: @@ -123,8 +123,7 @@ async def _serve_app( def _serve_stub(*args, **kwargs): - deprecation_warning((2024, 5, 1), "`serve_stub` is deprecated. Please use `serve_app` instead.") - return _run_app(*args, **kwargs) + deprecation_error((2024, 5, 1), "`serve_stub` is deprecated. Please use `serve_app` instead.") serve_app = synchronize_api(_serve_app) diff --git a/modal_version/_version_generated.py b/modal_version/_version_generated.py index 13acde769..7cf3c378e 100644 --- a/modal_version/_version_generated.py +++ b/modal_version/_version_generated.py @@ -1,4 +1,4 @@ # Copyright Modal Labs 2024 # Note: Reset this value to -1 whenever you make a minor `0.X` release of the client. -build_number = 84 # git: ed5cac4 +build_number = 87 # git: c8cfaf9 diff --git a/test/app_test.py b/test/app_test.py index b0004983d..4cde4b381 100644 --- a/test/app_test.py +++ b/test/app_test.py @@ -338,7 +338,7 @@ def test_stub(): def test_deploy_stub(servicer, client): app = App("xyz") deploy_app(app, client=client) - with pytest.warns(match="deploy_app"): + with pytest.raises(DeprecationError, match="deploy_app"): deploy_stub(app, client=client) diff --git a/test/function_utils_test.py b/test/function_utils_test.py index 08b520742..17926a5f2 100644 --- a/test/function_utils_test.py +++ b/test/function_utils_test.py @@ -6,7 +6,12 @@ from modal import method, web_endpoint from modal._serialization import serialize_data_format from modal._utils import async_utils -from modal._utils.function_utils import FunctionInfo, _stream_function_call_data, method_has_params +from modal._utils.function_utils import ( + FunctionInfo, + _stream_function_call_data, + callable_has_non_self_non_default_params, + callable_has_non_self_params, +) from modal_proto import api_pb2 @@ -34,23 +39,57 @@ def test_is_nullary(): class Cls: - def foo(self): + def f1(self): pass - def bar(self, x): + def f2(self, x): pass - def buz(self, *args): + def f3(self, *args): pass + def f4(self, x=1): + pass + + +def f5(): + pass + + +def f6(x): + pass + + +def f7(x=1): + pass + + +def test_callable_has_non_self_params(): + assert not callable_has_non_self_params(Cls.f1) + assert not callable_has_non_self_params(Cls().f1) + assert callable_has_non_self_params(Cls.f2) + assert callable_has_non_self_params(Cls().f2) + assert callable_has_non_self_params(Cls.f3) + assert callable_has_non_self_params(Cls().f3) + assert callable_has_non_self_params(Cls.f4) + assert callable_has_non_self_params(Cls().f4) + assert not callable_has_non_self_params(f5) + assert callable_has_non_self_params(f6) + assert callable_has_non_self_params(f7) + -def test_method_has_params(): - assert not method_has_params(Cls.foo) - assert not method_has_params(Cls().foo) - assert method_has_params(Cls.bar) - assert method_has_params(Cls().bar) - assert method_has_params(Cls.buz) - assert method_has_params(Cls().buz) +def test_callable_has_non_self_non_default_params(): + assert not callable_has_non_self_non_default_params(Cls.f1) + assert not callable_has_non_self_non_default_params(Cls().f1) + assert callable_has_non_self_non_default_params(Cls.f2) + assert callable_has_non_self_non_default_params(Cls().f2) + assert callable_has_non_self_non_default_params(Cls.f3) + assert callable_has_non_self_non_default_params(Cls().f3) + assert not callable_has_non_self_non_default_params(Cls.f4) + assert not callable_has_non_self_non_default_params(Cls().f4) + assert not callable_has_non_self_non_default_params(f5) + assert callable_has_non_self_non_default_params(f6) + assert not callable_has_non_self_non_default_params(f7) class Foo: diff --git a/test/webhook_test.py b/test/webhook_test.py index 3fba649c2..60318d55d 100644 --- a/test/webhook_test.py +++ b/test/webhook_test.py @@ -8,7 +8,7 @@ from modal import App, asgi_app, web_endpoint, wsgi_app from modal._asgi import webhook_asgi_app -from modal.exception import InvalidError +from modal.exception import DeprecationError, InvalidError from modal.functions import Function from modal.running_app import RunningApp from modal_proto import api_pb2 @@ -134,20 +134,50 @@ async def test_asgi_wsgi(servicer, client): @app.function(serialized=True) @asgi_app() - async def my_asgi(x): + async def my_asgi(): pass @app.function(serialized=True) @wsgi_app() - async def my_wsgi(x): + async def my_wsgi(): pass + with pytest.raises(InvalidError, match="can't have parameters"): + + @app.function(serialized=True) + @asgi_app() + async def my_invalid_asgi(x): + pass + + with pytest.raises(InvalidError, match="can't have parameters"): + + @app.function(serialized=True) + @wsgi_app() + async def my_invalid_wsgi(x): + pass + + with pytest.warns(DeprecationError, match="default parameters"): + + @app.function(serialized=True) + @asgi_app() + async def my_deprecated_default_params_asgi(x=1): + pass + + with pytest.warns(DeprecationError, match="default parameters"): + + @app.function(serialized=True) + @wsgi_app() + async def my_deprecated_default_params_wsgi(x=1): + pass + async with app.run(client=client): pass - assert len(servicer.app_functions) == 2 + assert len(servicer.app_functions) == 4 assert servicer.app_functions["fu-1"].webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP assert servicer.app_functions["fu-2"].webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP + assert servicer.app_functions["fu-3"].webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP + assert servicer.app_functions["fu-4"].webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP def test_positional_method(servicer, client):