Skip to content

Commit

Permalink
[MOD-3613] Improve client error message for asgi / wsgi apps that hav…
Browse files Browse the repository at this point in the history
…e non-nullary functions (#2190)

* Return a more helpful error to ensure that ASGI app functions get no args.

* Return ASGI app function not nullary error earlier by inspecting the function.

* Require WSGI app functions to also be non nullary and fix a small bug - the decorator accepted 0 as a positional arg.

* Fix the small bug everywhere - the decorators accepted 0 as a positional arg.

* Shuffle around the is_nullary logic and test fix.

* Switch is_nullary for method_has_params and fix bug in logic there.

* Add a bit more testing for method_has_params.

* Missing parens.

* Add a deprecation warning for default args used in asgi / wsgi app functions.

* Clarify a few comments.

* Fix tests to reflect updated warning messages.

---------

Co-authored-by: Daniel Shaar <[email protected]>
  • Loading branch information
danielshaar and danielshaar authored Sep 5, 2024
1 parent 7fcb063 commit 01c3799
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 35 deletions.
4 changes: 2 additions & 2 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
from ._utils.async_utils import TaskContext, synchronizer
from ._utils.function_utils import (
LocalFunctionError,
callable_has_non_self_params,
is_async as get_is_async,
is_global_object,
method_has_params,
)
from .app import App, _App
from .client import Client, _Client
Expand Down Expand Up @@ -684,7 +684,7 @@ def call_lifecycle_functions(
for func in funcs:
# We are deprecating parameterized exit methods but want to gracefully handle old code.
# We can remove this once the deprecation in the actual @exit decorator is enforced.
args = (None, None, None) if method_has_params(func) else ()
args = (None, None, None) if callable_has_non_self_params(func) else ()
# in case func is non-async, it's executed here and sigint will by default
# interrupt it using a KeyboardInterrupt exception
res = func(*args)
Expand Down
28 changes: 20 additions & 8 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,16 +320,28 @@ def is_nullary(self):
return True


def method_has_params(f: Callable[..., Any]) -> bool:
"""Return True if a method (bound or unbound) has parameters other than self.
def callable_has_non_self_params(f: Callable[..., Any]) -> bool:
"""Return True if a callable (function, bound method, or unbound method) has parameters other than self.
Used for deprecation of @exit() parameters.
Used to ensure that @exit(), @asgi_app, and @wsgi_app functions don't have parameters.
"""
num_params = len(inspect.signature(f).parameters)
if hasattr(f, "__self__"):
return num_params > 0
else:
return num_params > 1
return any(param.name != "self" for param in inspect.signature(f).parameters.values())


def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:
"""Return True if a callable (function, bound method, or unbound method) has non-default parameters other than self.
Used for deprecation of default parameters in @asgi_app and @wsgi_app functions.
"""
for param in inspect.signature(f).parameters.values():
if param.name == "self":
continue

if param.default != inspect.Parameter.empty:
continue

return True
return False


async def _stream_function_call_data(
Expand Down
44 changes: 34 additions & 10 deletions modal/partial_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from modal_proto import api_pb2

from ._utils.async_utils import synchronize_api, synchronizer
from ._utils.function_utils import method_has_params
from ._utils.function_utils import callable_has_non_self_non_default_params, callable_has_non_self_params
from .config import logger
from .exception import InvalidError, deprecation_error, deprecation_warning
from .functions import _Function
Expand Down Expand Up @@ -191,7 +191,7 @@ def f(self):
...
```
"""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@method()`.")

if keep_warm is not None:
Expand Down Expand Up @@ -264,7 +264,7 @@ def _web_endpoint(
if isinstance(_warn_parentheses_missing, str):
# Probably passing the method string as a positional argument.
raise InvalidError('Positional arguments are not allowed. Suggestion: `@web_endpoint(method="GET")`.')
elif _warn_parentheses_missing:
elif _warn_parentheses_missing is not None:
raise InvalidError(
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@web_endpoint()`."
)
Expand Down Expand Up @@ -334,12 +334,24 @@ def create_asgi() -> Callable:
"""
if isinstance(_warn_parentheses_missing, str):
raise InvalidError('Positional arguments are not allowed. Suggestion: `@asgi_app(label="foo")`.')
elif _warn_parentheses_missing:
elif _warn_parentheses_missing is not None:
raise InvalidError(
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@asgi_app()`."
)

def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
if callable_has_non_self_params(raw_f):
if callable_has_non_self_non_default_params(raw_f):
raise InvalidError(
f"ASGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#asgi."
)
else:
deprecation_warning(
(2024, 9, 4),
f"ASGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - "
f"Modal will drop support for default parameters in a future release.",
)

if not wait_for_response:
deprecation_warning(
(2024, 5, 13),
Expand Down Expand Up @@ -394,12 +406,24 @@ def create_wsgi() -> Callable:
"""
if isinstance(_warn_parentheses_missing, str):
raise InvalidError('Positional arguments are not allowed. Suggestion: `@wsgi_app(label="foo")`.')
elif _warn_parentheses_missing:
elif _warn_parentheses_missing is not None:
raise InvalidError(
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@wsgi_app()`."
)

def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
if callable_has_non_self_params(raw_f):
if callable_has_non_self_non_default_params(raw_f):
raise InvalidError(
f"WSGI app function {raw_f.__name__} can't have parameters. See https://modal.com/docs/guide/webhooks#wsgi."
)
else:
deprecation_warning(
(2024, 9, 4),
f"WSGI app function {raw_f.__name__} has default parameters, but shouldn't have any parameters - "
f"Modal will drop support for default parameters in a future release.",
)

if not wait_for_response:
deprecation_warning(
(2024, 5, 13),
Expand Down Expand Up @@ -508,7 +532,7 @@ def download_models(self):
LlamaTokenizer.from_pretrained(base_model)
```
"""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@build()`.")

def wrapper(f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction:
Expand All @@ -531,7 +555,7 @@ def _enter(
"""Decorator for methods which should be executed when a new container is started.
See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#enter) for more information."""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@enter()`.")

if snap:
Expand Down Expand Up @@ -561,14 +585,14 @@ def _exit(_warn_parentheses_missing=None) -> Callable[[ExitHandlerType], _Partia
"""Decorator for methods which should be executed when a container is about to exit.
See the [lifeycle function guide](https://modal.com/docs/guide/lifecycle-functions#exit) for more information."""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@exit()`.")

def wrapper(f: ExitHandlerType) -> _PartialFunction:
if isinstance(f, _PartialFunction):
_disallow_wrapping_method(f, "exit")

if method_has_params(f):
if callable_has_non_self_params(f):
message = (
"Support for decorating parameterized methods with `@exit` has been deprecated."
" Please update your code by removing the parameters."
Expand Down Expand Up @@ -601,7 +625,7 @@ async def batched_multiply(xs: list[int], ys: list[int]) -> list[int]:
See the [dynamic batching guide](https://modal.com/docs/guide/dynamic-batching) for more information.
"""
if _warn_parentheses_missing:
if _warn_parentheses_missing is not None:
raise InvalidError(
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@batched()`."
)
Expand Down
61 changes: 50 additions & 11 deletions test/function_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from modal import method, web_endpoint
from modal._serialization import serialize_data_format
from modal._utils import async_utils
from modal._utils.function_utils import FunctionInfo, _stream_function_call_data, method_has_params
from modal._utils.function_utils import (
FunctionInfo,
_stream_function_call_data,
callable_has_non_self_non_default_params,
callable_has_non_self_params,
)
from modal_proto import api_pb2


Expand Down Expand Up @@ -34,23 +39,57 @@ def test_is_nullary():


class Cls:
def foo(self):
def f1(self):
pass

def bar(self, x):
def f2(self, x):
pass

def buz(self, *args):
def f3(self, *args):
pass

def f4(self, x=1):
pass


def f5():
pass


def f6(x):
pass


def f7(x=1):
pass


def test_callable_has_non_self_params():
assert not callable_has_non_self_params(Cls.f1)
assert not callable_has_non_self_params(Cls().f1)
assert callable_has_non_self_params(Cls.f2)
assert callable_has_non_self_params(Cls().f2)
assert callable_has_non_self_params(Cls.f3)
assert callable_has_non_self_params(Cls().f3)
assert callable_has_non_self_params(Cls.f4)
assert callable_has_non_self_params(Cls().f4)
assert not callable_has_non_self_params(f5)
assert callable_has_non_self_params(f6)
assert callable_has_non_self_params(f7)


def test_method_has_params():
assert not method_has_params(Cls.foo)
assert not method_has_params(Cls().foo)
assert method_has_params(Cls.bar)
assert method_has_params(Cls().bar)
assert method_has_params(Cls.buz)
assert method_has_params(Cls().buz)
def test_callable_has_non_self_non_default_params():
assert not callable_has_non_self_non_default_params(Cls.f1)
assert not callable_has_non_self_non_default_params(Cls().f1)
assert callable_has_non_self_non_default_params(Cls.f2)
assert callable_has_non_self_non_default_params(Cls().f2)
assert callable_has_non_self_non_default_params(Cls.f3)
assert callable_has_non_self_non_default_params(Cls().f3)
assert not callable_has_non_self_non_default_params(Cls.f4)
assert not callable_has_non_self_non_default_params(Cls().f4)
assert not callable_has_non_self_non_default_params(f5)
assert callable_has_non_self_non_default_params(f6)
assert not callable_has_non_self_non_default_params(f7)


class Foo:
Expand Down
38 changes: 34 additions & 4 deletions test/webhook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from modal import App, asgi_app, web_endpoint, wsgi_app
from modal._asgi import webhook_asgi_app
from modal.exception import InvalidError
from modal.exception import DeprecationError, InvalidError
from modal.functions import Function
from modal.running_app import RunningApp
from modal_proto import api_pb2
Expand Down Expand Up @@ -134,20 +134,50 @@ async def test_asgi_wsgi(servicer, client):

@app.function(serialized=True)
@asgi_app()
async def my_asgi(x):
async def my_asgi():
pass

@app.function(serialized=True)
@wsgi_app()
async def my_wsgi(x):
async def my_wsgi():
pass

with pytest.raises(InvalidError, match="can't have parameters"):

@app.function(serialized=True)
@asgi_app()
async def my_invalid_asgi(x):
pass

with pytest.raises(InvalidError, match="can't have parameters"):

@app.function(serialized=True)
@wsgi_app()
async def my_invalid_wsgi(x):
pass

with pytest.warns(DeprecationError, match="default parameters"):

@app.function(serialized=True)
@asgi_app()
async def my_deprecated_default_params_asgi(x=1):
pass

with pytest.warns(DeprecationError, match="default parameters"):

@app.function(serialized=True)
@wsgi_app()
async def my_deprecated_default_params_wsgi(x=1):
pass

async with app.run(client=client):
pass

assert len(servicer.app_functions) == 2
assert len(servicer.app_functions) == 4
assert servicer.app_functions["fu-1"].webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP
assert servicer.app_functions["fu-2"].webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP
assert servicer.app_functions["fu-3"].webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP
assert servicer.app_functions["fu-4"].webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP


def test_positional_method(servicer, client):
Expand Down

0 comments on commit 01c3799

Please sign in to comment.