Skip to content

Commit

Permalink
Add a deprecation warning for default args used in asgi / wsgi app fu…
Browse files Browse the repository at this point in the history
…nctions.
  • Loading branch information
danielshaar committed Sep 4, 2024
1 parent 14c06bc commit 85e76c3
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 34 deletions.
4 changes: 2 additions & 2 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
from ._utils.async_utils import TaskContext, synchronizer
from ._utils.function_utils import (
LocalFunctionError,
callable_has_non_self_params,
is_async as get_is_async,
is_global_object,
method_has_params,
)
from .app import App, _App
from .client import Client, _Client
Expand Down Expand Up @@ -684,7 +684,7 @@ def call_lifecycle_functions(
for func in funcs:
# We are deprecating parameterized exit methods but want to gracefully handle old code.
# We can remove this once the deprecation in the actual @exit decorator is enforced.
args = (None, None, None) if method_has_params(func) else ()
args = (None, None, None) if callable_has_non_self_params(func) else ()
# in case func is non-async, it's executed here and sigint will by default
# interrupt it using a KeyboardInterrupt exception
res = func(*args)
Expand Down
23 changes: 20 additions & 3 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,31 @@ def is_nullary(self):
return True


def method_has_params(f: Callable[..., Any]) -> bool:
"""Return True if a method (bound or unbound) has parameters other than self.
def callable_has_non_self_params(f: Callable[..., Any]) -> bool:
"""Return True if a callable (function, bound method, or unbound method) has parameters other than self.
Used for deprecation of @exit() parameters and ensuring asgi / wsgi app functions don't have args.
Used for deprecation of @exit() parameters and ensuring @asgi_app / @wsgi_app functions don't have non-default
parameters.
"""
return any(param.name != "self" for param in inspect.signature(f).parameters.values())


def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:
"""Return True if a callable (function, bound method, or unbound method) has non-default parameters other than self.
Used for deprecation of default parameters in asgi / wsgi app functions.
"""
for param in inspect.signature(f).parameters.values():
if param.name == "self":
continue

if param.default != inspect.Parameter.empty:
continue

return True
return False


async def _stream_function_call_data(
client, function_call_id: str, variant: Literal["data_in", "data_out"]
) -> AsyncIterator[Any]:
Expand Down
34 changes: 24 additions & 10 deletions modal/partial_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from modal_proto import api_pb2

from ._utils.async_utils import synchronize_api, synchronizer
from ._utils.function_utils import method_has_params
from ._utils.function_utils import callable_has_non_self_non_default_params, callable_has_non_self_params
from .config import logger
from .exception import InvalidError, deprecation_error, deprecation_warning
from .functions import _Function
Expand Down Expand Up @@ -340,10 +340,17 @@ def create_asgi() -> Callable:
)

def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
if method_has_params(raw_f):
raise InvalidError(
f"ASGI app function {raw_f.__name__} can't have arguments. See https://modal.com/docs/guide/webhooks#asgi."
)
if callable_has_non_self_params(raw_f):
if callable_has_non_self_non_default_params(raw_f):
raise InvalidError(
f"ASGI app function {raw_f.__name__} can't have arguments. See https://modal.com/docs/guide/webhooks#asgi."
)
else:
deprecation_warning(
(2024, 9, 4),
f"ASGI app function {raw_f.__name__} has default arguments - Modal will drop support for this in a"
f" future release.",
)

if not wait_for_response:
deprecation_warning(
Expand Down Expand Up @@ -405,10 +412,17 @@ def create_wsgi() -> Callable:
)

def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
if method_has_params(raw_f):
raise InvalidError(
f"WSGI app function {raw_f.__name__} can't have arguments. See https://modal.com/docs/guide/webhooks#wsgi."
)
if callable_has_non_self_params(raw_f):
if callable_has_non_self_non_default_params(raw_f):
raise InvalidError(
f"WSGI app function {raw_f.__name__} can't have arguments. See https://modal.com/docs/guide/webhooks#wsgi."
)
else:
deprecation_warning(
(2024, 9, 4),
f"WSGI app function {raw_f.__name__} has default arguments - Modal will drop support for this in a"
f"future release.",
)

if not wait_for_response:
deprecation_warning(
Expand Down Expand Up @@ -578,7 +592,7 @@ def wrapper(f: ExitHandlerType) -> _PartialFunction:
if isinstance(f, _PartialFunction):
_disallow_wrapping_method(f, "exit")

if method_has_params(f):
if callable_has_non_self_params(f):
message = (
"Support for decorating parameterized methods with `@exit` has been deprecated."
" Please update your code by removing the parameters."
Expand Down
65 changes: 48 additions & 17 deletions test/function_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from modal import method, web_endpoint
from modal._serialization import serialize_data_format
from modal._utils import async_utils
from modal._utils.function_utils import FunctionInfo, _stream_function_call_data, method_has_params
from modal._utils.function_utils import (
FunctionInfo,
_stream_function_call_data,
callable_has_non_self_non_default_params,
callable_has_non_self_params,
)
from modal_proto import api_pb2


Expand Down Expand Up @@ -34,31 +39,57 @@ def test_is_nullary():


class Cls:
def foo(self):
def f1(self):
pass

def bar(self, x):
def f2(self, x):
pass

def buz(self, *args):
def f3(self, *args):
pass


def test_method_has_params():
def qux():
def f4(self, x=1):
pass

def foobar(baz):
pass

assert not method_has_params(Cls.foo)
assert not method_has_params(Cls().foo)
assert method_has_params(Cls.bar)
assert method_has_params(Cls().bar)
assert method_has_params(Cls.buz)
assert method_has_params(Cls().buz)
assert not method_has_params(qux)
assert method_has_params(foobar)
def f5():
pass


def f6(x):
pass


def f7(x=1):
pass


def test_callable_has_non_self_params():
assert not callable_has_non_self_params(Cls.f1)
assert not callable_has_non_self_params(Cls().f1)
assert callable_has_non_self_params(Cls.f2)
assert callable_has_non_self_params(Cls().f2)
assert callable_has_non_self_params(Cls.f3)
assert callable_has_non_self_params(Cls().f3)
assert callable_has_non_self_params(Cls.f4)
assert callable_has_non_self_params(Cls().f4)
assert not callable_has_non_self_params(f5)
assert callable_has_non_self_params(f6)
assert callable_has_non_self_params(f7)


def test_callable_has_non_self_non_default_params():
assert not callable_has_non_self_non_default_params(Cls.f1)
assert not callable_has_non_self_non_default_params(Cls().f1)
assert callable_has_non_self_non_default_params(Cls.f2)
assert callable_has_non_self_non_default_params(Cls().f2)
assert callable_has_non_self_non_default_params(Cls.f3)
assert callable_has_non_self_non_default_params(Cls().f3)
assert not callable_has_non_self_non_default_params(Cls.f4)
assert not callable_has_non_self_non_default_params(Cls().f4)
assert not callable_has_non_self_non_default_params(f5)
assert callable_has_non_self_non_default_params(f6)
assert not callable_has_non_self_non_default_params(f7)


class Foo:
Expand Down
20 changes: 18 additions & 2 deletions test/webhook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from modal import App, asgi_app, web_endpoint, wsgi_app
from modal._asgi import webhook_asgi_app
from modal.exception import InvalidError
from modal.exception import DeprecationError, InvalidError
from modal.functions import Function
from modal.running_app import RunningApp
from modal_proto import api_pb2
Expand Down Expand Up @@ -156,12 +156,28 @@ async def my_invalid_asgi(x):
async def my_invalid_wsgi(x):
pass

with pytest.warns(DeprecationError, match="default argument"):

@app.function(serialized=True)
@asgi_app()
async def my_deprecated_default_params_asgi(x=1):
pass

with pytest.warns(DeprecationError, match="default arguments"):

@app.function(serialized=True)
@wsgi_app()
async def my_deprecated_default_params_wsgi(x=1):
pass

async with app.run(client=client):
pass

assert len(servicer.app_functions) == 2
assert len(servicer.app_functions) == 4
assert servicer.app_functions["fu-1"].webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP
assert servicer.app_functions["fu-2"].webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP
assert servicer.app_functions["fu-3"].webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP
assert servicer.app_functions["fu-4"].webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP


def test_positional_method(servicer, client):
Expand Down

0 comments on commit 85e76c3

Please sign in to comment.