diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index fcd9477c3..841f92958 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -42,9 +42,9 @@ from ._utils.async_utils import TaskContext, synchronizer from ._utils.function_utils import ( LocalFunctionError, + callable_has_non_self_params, is_async as get_is_async, is_global_object, - method_has_params, ) from .app import App, _App from .client import Client, _Client @@ -684,7 +684,7 @@ def call_lifecycle_functions( for func in funcs: # We are deprecating parameterized exit methods but want to gracefully handle old code. # We can remove this once the deprecation in the actual @exit decorator is enforced. - args = (None, None, None) if method_has_params(func) else () + args = (None, None, None) if callable_has_non_self_params(func) else () # in case func is non-async, it's executed here and sigint will by default # interrupt it using a KeyboardInterrupt exception res = func(*args) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index 44dbcd642..097cf5eec 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -320,16 +320,28 @@ def is_nullary(self): return True -def method_has_params(f: Callable[..., Any]) -> bool: - """Return True if a method (bound or unbound) has parameters other than self. +def callable_has_non_self_params(f: Callable[..., Any]) -> bool: + """Return True if a callable (function, bound method, or unbound method) has parameters other than self. - Used for deprecation of @exit() parameters. + Used to ensure that @exit(), @asgi_app, and @wsgi_app functions don't have parameters. """ - num_params = len(inspect.signature(f).parameters) - if hasattr(f, "__self__"): - return num_params > 0 - else: - return num_params > 1 + return any(param.name != "self" for param in inspect.signature(f).parameters.values()) + + +def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool: + """Return True if a callable (function, bound method, or unbound method) has non-default parameters other than self. + + Used for deprecation of default parameters in @asgi_app and @wsgi_app functions. + """ + for param in inspect.signature(f).parameters.values(): + if param.name == "self": + continue + + if param.default != inspect.Parameter.empty: + continue + + return True + return False async def _stream_function_call_data( diff --git a/modal/partial_function.py b/modal/partial_function.py index 3f17ff37e..181a74f5c 100644 --- a/modal/partial_function.py +++ b/modal/partial_function.py @@ -18,7 +18,7 @@ from modal_proto import api_pb2 from ._utils.async_utils import synchronize_api, synchronizer -from ._utils.function_utils import method_has_params +from ._utils.function_utils import callable_has_non_self_non_default_params, callable_has_non_self_params from .config import logger from .exception import InvalidError, deprecation_error, deprecation_warning from .functions import _Function @@ -191,7 +191,7 @@ def f(self): ... ``` """ - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@method()`.") if keep_warm is not None: @@ -264,7 +264,7 @@ def _web_endpoint( if isinstance(_warn_parentheses_missing, str): # Probably passing the method string as a positional argument. raise InvalidError('Positional arguments are not allowed. Suggestion: `@web_endpoint(method="GET")`.') - elif _warn_parentheses_missing: + elif _warn_parentheses_missing is not None: raise InvalidError( "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@web_endpoint()`." ) @@ -334,12 +334,24 @@ def create_asgi() -> Callable: """ if isinstance(_warn_parentheses_missing, str): raise InvalidError('Positional arguments are not allowed. Suggestion: `@asgi_app(label="foo")`.') - elif _warn_parentheses_missing: + elif _warn_parentheses_missing is not None: raise InvalidError( "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@asgi_app()`." ) def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: + if callable_has_non_self_params(raw_f): + if callable_has_non_self_non_default_params(raw_f): + raise InvalidError( + f"ASGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#asgi." + ) + else: + deprecation_warning( + (2024, 9, 4), + f"ASGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - " + f"Modal will drop support for default parameters in a future release.", + ) + if not wait_for_response: deprecation_warning( (2024, 5, 13), @@ -394,12 +406,24 @@ def create_wsgi() -> Callable: """ if isinstance(_warn_parentheses_missing, str): raise InvalidError('Positional arguments are not allowed. Suggestion: `@wsgi_app(label="foo")`.') - elif _warn_parentheses_missing: + elif _warn_parentheses_missing is not None: raise InvalidError( "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@wsgi_app()`." ) def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: + if callable_has_non_self_params(raw_f): + if callable_has_non_self_non_default_params(raw_f): + raise InvalidError( + f"WSGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#wsgi." + ) + else: + deprecation_warning( + (2024, 9, 4), + f"WSGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - " + f"Modal will drop support for default parameters in a future release.", + ) + if not wait_for_response: deprecation_warning( (2024, 5, 13), @@ -508,7 +532,7 @@ def download_models(self): LlamaTokenizer.from_pretrained(base_model) ``` """ - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@build()`.") def wrapper(f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction: @@ -531,7 +555,7 @@ def _enter( """Decorator for methods which should be executed when a new container is started. See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#enter) for more information.""" - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@enter()`.") if snap: @@ -561,14 +585,14 @@ def _exit(_warn_parentheses_missing=None) -> Callable[[ExitHandlerType], _Partia """Decorator for methods which should be executed when a container is about to exit. See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#exit) for more information.""" - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@exit()`.") def wrapper(f: ExitHandlerType) -> _PartialFunction: if isinstance(f, _PartialFunction): _disallow_wrapping_method(f, "exit") - if method_has_params(f): + if callable_has_non_self_params(f): message = ( "Support for decorating parameterized methods with `@exit` has been deprecated." " Please update your code by removing the parameters." @@ -601,7 +625,7 @@ async def batched_multiply(xs: list[int], ys: list[int]) -> list[int]: See the [dynamic batching guide](https://modal.com/docs/guide/dynamic-batching) for more information. """ - if _warn_parentheses_missing: + if _warn_parentheses_missing is not None: raise InvalidError( "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@batched()`." ) diff --git a/test/function_utils_test.py b/test/function_utils_test.py index 08b520742..17926a5f2 100644 --- a/test/function_utils_test.py +++ b/test/function_utils_test.py @@ -6,7 +6,12 @@ from modal import method, web_endpoint from modal._serialization import serialize_data_format from modal._utils import async_utils -from modal._utils.function_utils import FunctionInfo, _stream_function_call_data, method_has_params +from modal._utils.function_utils import ( + FunctionInfo, + _stream_function_call_data, + callable_has_non_self_non_default_params, + callable_has_non_self_params, +) from modal_proto import api_pb2 @@ -34,23 +39,57 @@ def test_is_nullary(): class Cls: - def foo(self): + def f1(self): pass - def bar(self, x): + def f2(self, x): pass - def buz(self, *args): + def f3(self, *args): pass + def f4(self, x=1): + pass + + +def f5(): + pass + + +def f6(x): + pass + + +def f7(x=1): + pass + + +def test_callable_has_non_self_params(): + assert not callable_has_non_self_params(Cls.f1) + assert not callable_has_non_self_params(Cls().f1) + assert callable_has_non_self_params(Cls.f2) + assert callable_has_non_self_params(Cls().f2) + assert callable_has_non_self_params(Cls.f3) + assert callable_has_non_self_params(Cls().f3) + assert callable_has_non_self_params(Cls.f4) + assert callable_has_non_self_params(Cls().f4) + assert not callable_has_non_self_params(f5) + assert callable_has_non_self_params(f6) + assert callable_has_non_self_params(f7) + -def test_method_has_params(): - assert not method_has_params(Cls.foo) - assert not method_has_params(Cls().foo) - assert method_has_params(Cls.bar) - assert method_has_params(Cls().bar) - assert method_has_params(Cls.buz) - assert method_has_params(Cls().buz) +def test_callable_has_non_self_non_default_params(): + assert not callable_has_non_self_non_default_params(Cls.f1) + assert not callable_has_non_self_non_default_params(Cls().f1) + assert callable_has_non_self_non_default_params(Cls.f2) + assert callable_has_non_self_non_default_params(Cls().f2) + assert callable_has_non_self_non_default_params(Cls.f3) + assert callable_has_non_self_non_default_params(Cls().f3) + assert not callable_has_non_self_non_default_params(Cls.f4) + assert not callable_has_non_self_non_default_params(Cls().f4) + assert not callable_has_non_self_non_default_params(f5) + assert callable_has_non_self_non_default_params(f6) + assert not callable_has_non_self_non_default_params(f7) class Foo: diff --git a/test/webhook_test.py b/test/webhook_test.py index 3fba649c2..60318d55d 100644 --- a/test/webhook_test.py +++ b/test/webhook_test.py @@ -8,7 +8,7 @@ from modal import App, asgi_app, web_endpoint, wsgi_app from modal._asgi import webhook_asgi_app -from modal.exception import InvalidError +from modal.exception import DeprecationError, InvalidError from modal.functions import Function from modal.running_app import RunningApp from modal_proto import api_pb2 @@ -134,20 +134,50 @@ async def test_asgi_wsgi(servicer, client): @app.function(serialized=True) @asgi_app() - async def my_asgi(x): + async def my_asgi(): pass @app.function(serialized=True) @wsgi_app() - async def my_wsgi(x): + async def my_wsgi(): pass + with pytest.raises(InvalidError, match="can't have parameters"): + + @app.function(serialized=True) + @asgi_app() + async def my_invalid_asgi(x): + pass + + with pytest.raises(InvalidError, match="can't have parameters"): + + @app.function(serialized=True) + @wsgi_app() + async def my_invalid_wsgi(x): + pass + + with pytest.warns(DeprecationError, match="default parameters"): + + @app.function(serialized=True) + @asgi_app() + async def my_deprecated_default_params_asgi(x=1): + pass + + with pytest.warns(DeprecationError, match="default parameters"): + + @app.function(serialized=True) + @wsgi_app() + async def my_deprecated_default_params_wsgi(x=1): + pass + async with app.run(client=client): pass - assert len(servicer.app_functions) == 2 + assert len(servicer.app_functions) == 4 assert servicer.app_functions["fu-1"].webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP assert servicer.app_functions["fu-2"].webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP + assert servicer.app_functions["fu-3"].webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP + assert servicer.app_functions["fu-4"].webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP def test_positional_method(servicer, client):