Skip to content

Commit

Permalink
Merge pull request #944 from sirosen/tighten-mypy
Browse files Browse the repository at this point in the history
Improve type annotations for the pyramid and tornado parsers
  • Loading branch information
sirosen authored May 31, 2024
2 parents f3de9c7 + a553245 commit e46313a
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 45 deletions.
6 changes: 0 additions & 6 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,3 @@ disallow_untyped_defs = false

[mypy-webargs.falconparser]
disallow_untyped_defs = false

[mypy-webargs.pyramidparser]
disallow_untyped_defs = false

[mypy-webargs.tornadoparser]
disallow_untyped_defs = false
2 changes: 1 addition & 1 deletion src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
)
return func(*args, **kwargs)

wrapper.__wrapped__ = func # type: ignore
wrapper.__wrapped__ = func
_record_arg_name(wrapper, arg_name)
return wrapper

Expand Down
61 changes: 39 additions & 22 deletions src/webargs/pyramidparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def hello_world(request, args):
from __future__ import annotations

import functools
import typing
from collections.abc import Mapping

import marshmallow as ma
Expand All @@ -38,6 +39,8 @@ def hello_world(request, args):
from webargs import core
from webargs.core import json

F = typing.TypeVar("F", bound=typing.Callable)


def is_json_request(req: Request) -> bool:
return core.is_json(req.headers.get("content-type"))
Expand All @@ -57,7 +60,7 @@ class PyramidParser(core.Parser[Request]):
**core.Parser.__location_map__,
)

def _raw_load_json(self, req: Request):
def _raw_load_json(self, req: Request) -> typing.Any:
"""Return a json payload from the request for the core parser's load_json
Checks the input mimetype and may return 'missing' if the mimetype is
Expand All @@ -67,34 +70,40 @@ def _raw_load_json(self, req: Request):

return core.parse_json(req.body, encoding=req.charset)

def load_querystring(self, req: Request, schema):
def load_querystring(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return query params from the request as a MultiDictProxy."""
return self._makeproxy(req.GET, schema)

def load_form(self, req: Request, schema):
def load_form(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return form values from the request as a MultiDictProxy."""
return self._makeproxy(req.POST, schema)

def load_cookies(self, req: Request, schema):
def load_cookies(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return cookies from the request as a MultiDictProxy."""
return self._makeproxy(req.cookies, schema)

def load_headers(self, req: Request, schema):
def load_headers(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return headers from the request as a MultiDictProxy."""
return self._makeproxy(req.headers, schema)

def load_files(self, req: Request, schema):
def load_files(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return files from the request as a MultiDictProxy."""
files = ((k, v) for k, v in req.POST.items() if hasattr(v, "file"))
return self._makeproxy(MultiDict(files), schema)

def load_matchdict(self, req: Request, schema):
def load_matchdict(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return the request's ``matchdict`` as a MultiDictProxy."""
return self._makeproxy(req.matchdict, schema)

def handle_error(
self, error, req: Request, schema, *, error_status_code, error_headers
):
self,
error: ma.ValidationError,
req: Request,
schema: ma.Schema,
*,
error_status_code: int | None,
error_headers: typing.Mapping[str, str] | None,
) -> typing.NoReturn:
"""Handles errors during parsing. Aborts the current HTTP request and
responds with a 400 error.
"""
Expand All @@ -109,7 +118,13 @@ def handle_error(
response.body = body.encode("utf-8") if isinstance(body, str) else body
raise response

def _handle_invalid_json_error(self, error, req: Request, *args, **kwargs):
def _handle_invalid_json_error(
self,
error: json.JSONDecodeError | UnicodeDecodeError,
req: Request,
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.NoReturn:
messages = {"json": ["Invalid JSON body."]}
response = exception_response(
400, detail=str(messages), content_type="application/json"
Expand All @@ -120,17 +135,17 @@ def _handle_invalid_json_error(self, error, req: Request, *args, **kwargs):

def use_args(
self,
argmap,
argmap: core.ArgMap,
req: Request | None = None,
*,
location=core.Parser.DEFAULT_LOCATION,
unknown=None,
as_kwargs=False,
arg_name=None,
validate=None,
error_status_code=None,
error_headers=None,
):
location: str | None = core.Parser.DEFAULT_LOCATION,
unknown: str | None = None,
as_kwargs: bool = False,
arg_name: str | None = None,
validate: core.ValidateArg = None,
error_status_code: int | None = None,
error_headers: typing.Mapping[str, str] | None = None,
) -> typing.Callable[..., typing.Callable]:
"""Decorator that injects parsed arguments into a view callable.
Supports the *Class-based View* pattern where `request` is saved as an instance
attribute on a view class.
Expand Down Expand Up @@ -167,9 +182,11 @@ def use_args(
argmap = dict(argmap)
argmap = self.schema_class.from_dict(argmap)()

def decorator(func):
def decorator(func: F) -> F:
@functools.wraps(func)
def wrapper(obj, *args, **kwargs):
def wrapper(
obj: typing.Any, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
# The first argument is either `self` or `request`
try: # get self.request
request = req or obj.request
Expand All @@ -191,7 +208,7 @@ def wrapper(obj, *args, **kwargs):
return func(obj, *args, **kwargs)

wrapper.__wrapped__ = func
return wrapper
return wrapper # type: ignore[return-value]

return decorator

Expand Down
51 changes: 36 additions & 15 deletions src/webargs/tornadoparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def get(self, args):
self.write(response)
"""

from __future__ import annotations

import json
import typing

import marshmallow as ma
import tornado.concurrent
import tornado.web
from tornado.escape import _unicode
Expand All @@ -26,13 +32,13 @@ def get(self, args):
class HTTPError(tornado.web.HTTPError):
"""`tornado.web.HTTPError` that stores validation errors."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.messages = kwargs.pop("messages", {})
self.headers = kwargs.pop("headers", None)
super().__init__(*args, **kwargs)


def is_json_request(req: HTTPServerRequest):
def is_json_request(req: HTTPServerRequest) -> bool:
content_type = req.headers.get("Content-Type")
return content_type is not None and core.is_json(content_type)

Expand All @@ -43,7 +49,7 @@ class WebArgsTornadoMultiDictProxy(MultiDictProxy):
requirements.
"""

def __getitem__(self, key):
def __getitem__(self, key: str) -> typing.Any:
try:
value = self.data.get(key, core.missing)
if value is core.missing:
Expand All @@ -70,7 +76,7 @@ class WebArgsTornadoCookiesMultiDictProxy(MultiDictProxy):
Also, does not use the `_unicode` decoding step
"""

def __getitem__(self, key):
def __getitem__(self, key: str) -> typing.Any:
cookie = self.data.get(key, core.missing)
if cookie is core.missing:
return core.missing
Expand All @@ -82,7 +88,7 @@ def __getitem__(self, key):
class TornadoParser(core.Parser[HTTPServerRequest]):
"""Tornado request argument parser."""

def _raw_load_json(self, req: HTTPServerRequest):
def _raw_load_json(self, req: HTTPServerRequest) -> typing.Any:
"""Return a json payload from the request for the core parser's load_json
Checks the input mimetype and may return 'missing' if the mimetype is
Expand All @@ -97,37 +103,43 @@ def _raw_load_json(self, req: HTTPServerRequest):

return core.parse_json(req.body)

def load_querystring(self, req: HTTPServerRequest, schema):
def load_querystring(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return query params from the request as a MultiDictProxy."""
return self._makeproxy(
req.query_arguments, schema, cls=WebArgsTornadoMultiDictProxy
)

def load_form(self, req: HTTPServerRequest, schema):
def load_form(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return form values from the request as a MultiDictProxy."""
return self._makeproxy(
req.body_arguments, schema, cls=WebArgsTornadoMultiDictProxy
)

def load_headers(self, req: HTTPServerRequest, schema):
def load_headers(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return headers from the request as a MultiDictProxy."""
return self._makeproxy(req.headers, schema, cls=WebArgsTornadoMultiDictProxy)

def load_cookies(self, req: HTTPServerRequest, schema):
def load_cookies(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return cookies from the request as a MultiDictProxy."""
# use the specialized subclass specifically for handling Tornado
# cookies
return self._makeproxy(
req.cookies, schema, cls=WebArgsTornadoCookiesMultiDictProxy
)

def load_files(self, req: HTTPServerRequest, schema):
def load_files(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return files from the request as a MultiDictProxy."""
return self._makeproxy(req.files, schema, cls=WebArgsTornadoMultiDictProxy)

def handle_error(
self, error, req: HTTPServerRequest, schema, *, error_status_code, error_headers
):
self,
error: ma.ValidationError,
req: HTTPServerRequest,
schema: ma.Schema,
*,
error_status_code: int | None,
error_headers: typing.Mapping[str, str] | None,
) -> typing.NoReturn:
"""Handles errors during parsing. Raises a `tornado.web.HTTPError`
with a 400 error.
"""
Expand All @@ -145,16 +157,25 @@ def handle_error(
)

def _handle_invalid_json_error(
self, error, req: HTTPServerRequest, *args, **kwargs
):
self,
error: json.JSONDecodeError | UnicodeDecodeError,
req: HTTPServerRequest,
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.NoReturn:
raise HTTPError(
400,
log_message="Invalid JSON body.",
reason="Bad Request",
messages={"json": ["Invalid JSON body."]},
)

def get_request_from_view_args(self, view, args, kwargs):
def get_request_from_view_args(
self,
view: typing.Any,
args: tuple[typing.Any, ...],
kwargs: typing.Mapping[str, typing.Any],
) -> HTTPServerRequest:
return args[0].request


Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ commands = pre-commit run --all-files
# `webargs` and `marshmallow` both installed is a valuable safeguard against
# issues in which `mypy` running on every file standalone won't catch things
[testenv:mypy]
deps = mypy==1.8.0
deps = mypy==1.10.0
extras = frameworks
commands = mypy src/ {posargs}

Expand Down

0 comments on commit e46313a

Please sign in to comment.