diff --git a/.coveragerc b/.coveragerc index 10c3bf499..2c5041d7e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,7 +1,7 @@ [run] branch = True source = falcon -omit = falcon/tests*,falcon/cmd/bench.py,falcon/bench/*,falcon/vendor/* +omit = falcon/tests*,falcon/typing.py,falcon/cmd/bench.py,falcon/bench/*,falcon/vendor/* parallel = True @@ -9,6 +9,7 @@ parallel = True show_missing = True exclude_lines = if TYPE_CHECKING: + if not TYPE_CHECKING: pragma: nocover pragma: no cover pragma: no py39,py310 cover diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b415c0e7f..398fd0046 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -27,6 +27,7 @@ jobs: - "pep8-examples" - "pep8-docstrings" - "mypy" + - "mypy_tests" - "py310" - "py310_sans_msgpack" - "py310_cython" diff --git a/MANIFEST.in b/MANIFEST.in index f39619936..06c1485c7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -9,6 +9,7 @@ include README.rst include AUTHORS include LICENSE include docs/conf.py docs/Makefile +include falcon/py.typed graft docs/_static graft docs/_templates graft requirements diff --git a/docs/_newsfragments/1947.newandimproved.rst b/docs/_newsfragments/1947.newandimproved.rst new file mode 100644 index 000000000..8ea9675c2 --- /dev/null +++ b/docs/_newsfragments/1947.newandimproved.rst @@ -0,0 +1,4 @@ +Basic typing annotations have been added to the most commonly used functions of +Falcon's public interface to the package itself in order to better support +`mypy `_ users without having to install any +third-party typeshed packages. diff --git a/e2e-tests/server/app.py b/e2e-tests/server/app.py index 28d508b42..bde7e399d 100644 --- a/e2e-tests/server/app.py +++ b/e2e-tests/server/app.py @@ -10,7 +10,7 @@ STATIC = HERE.parent / 'static' -def create_app(): +def create_app() -> falcon.asgi.App: app = falcon.asgi.App() hub = Hub() diff --git a/e2e-tests/server/chat.py b/e2e-tests/server/chat.py index df862250c..8cad04a89 100644 --- a/e2e-tests/server/chat.py +++ b/e2e-tests/server/chat.py @@ -1,14 +1,18 @@ import re +from falcon.asgi import Request, WebSocket + +from .hub import Hub + class Chat: ALL = re.compile(r'^/all\s+(.+)$') MSG = re.compile(r'^/msg\s+(\w+)\s+(.+)$') - def __init__(self, hub): + def __init__(self, hub: Hub): self._hub = hub - async def on_websocket(self, req, ws, name): + async def on_websocket(self, req: Request, ws: WebSocket, name: str) -> None: await ws.accept() try: diff --git a/e2e-tests/server/hub.py b/e2e-tests/server/hub.py index 4b4648477..181b2555f 100644 --- a/e2e-tests/server/hub.py +++ b/e2e-tests/server/hub.py @@ -1,7 +1,8 @@ import asyncio +import typing import uuid -from falcon.asgi import SSEvent +from falcon.asgi import Request, Response, SSEvent, WebSocket class Emitter: @@ -11,7 +12,7 @@ def __init__(self): self._done = False self._queue = asyncio.Queue() - async def events(self): + async def events(self) -> typing.AsyncGenerator[typing.Optional[SSEvent], None]: try: yield SSEvent(text='SSE CONNECTED') @@ -28,7 +29,7 @@ async def events(self): # TODO(vytas): Is there a more elegant way to detect a disconnect? self._done = True - async def enqueue(self, message): + async def enqueue(self, message: str) -> None: event = SSEvent(text=message, event_id=str(uuid.uuid4())) await self._queue.put(event) @@ -42,28 +43,28 @@ def __init__(self): self._emitters = set() self._users = {} - def _update_emitters(self): + def _update_emitters(self) -> set: done = {emitter for emitter in self._emitters if emitter.done} self._emitters.difference_update(done) return self._emitters.copy() - def add_user(self, name, ws): + def add_user(self, name: str, ws: WebSocket) -> None: self._users[name] = ws - def remove_user(self, name): + def remove_user(self, name: str) -> None: self._users.pop(name, None) - async def broadcast(self, message): + async def broadcast(self, message: str) -> None: for emitter in self._update_emitters(): await emitter.enqueue(message) - async def message(self, name, text): + async def message(self, name: str, text: str) -> None: ws = self._users.get(name) if ws: # TODO(vytas): What if this overlaps with another ongoing send? await ws.send_text(text) - def events(self): + def events(self) -> typing.AsyncGenerator[typing.Optional[SSEvent], None]: emitter = Emitter() self._update_emitters() self._emitters.add(emitter) @@ -71,8 +72,8 @@ def events(self): class Events: - def __init__(self, hub): + def __init__(self, hub: Hub): self._hub = hub - async def on_get(self, req, resp): + async def on_get(self, req: Request, resp: Response) -> None: resp.sse = self._hub.events() diff --git a/e2e-tests/server/ping.py b/e2e-tests/server/ping.py index 771915ec6..bac2aec68 100644 --- a/e2e-tests/server/ping.py +++ b/e2e-tests/server/ping.py @@ -1,10 +1,11 @@ from http import HTTPStatus import falcon +from falcon.asgi import Request, Response class Pong: - async def on_get(self, req, resp): + async def on_get(self, req: Request, resp: Response) -> None: resp.content_type = falcon.MEDIA_TEXT resp.text = 'PONG\n' resp.status = HTTPStatus.OK diff --git a/falcon/app.py b/falcon/app.py index 4a64cc0d4..442d41676 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -16,8 +16,10 @@ from functools import wraps from inspect import iscoroutinefunction +import pathlib import re import traceback +from typing import Callable, Iterable, Optional, Tuple, Type, Union from falcon import app_helpers as helpers from falcon import constants @@ -34,6 +36,7 @@ from falcon.response import Response from falcon.response import ResponseOptions import falcon.status_codes as status +from falcon.typing import ErrorHandler, ErrorSerializer, SinkPrefix from falcon.util import deprecation from falcon.util import misc from falcon.util.misc import code_to_http_status @@ -226,6 +229,9 @@ def process_response(self, req, resp, resource, req_succeeded) 'resp_options', ) + req_options: RequestOptions + resp_options: ResponseOptions + def __init__( self, media_type=constants.DEFAULT_MEDIA_TYPE, @@ -285,7 +291,9 @@ def __init__( self.add_error_handler(HTTPError, self._http_error_handler) self.add_error_handler(HTTPStatus, self._http_status_handler) - def __call__(self, env, start_response): # noqa: C901 + def __call__( # noqa: C901 + self, env: dict, start_response: Callable + ) -> Iterable[bytes]: """WSGI `app` method. Makes instances of App callable from a WSGI server. May be used to @@ -302,11 +310,11 @@ def __call__(self, env, start_response): # noqa: C901 """ req = self._request_type(env, options=self.req_options) resp = self._response_type(options=self.resp_options) - resource = None - responder = None - params = {} + resource: Optional[object] = None + responder: Optional[Callable] = None + params: dict = {} - dependent_mw_resp_stack = [] + dependent_mw_resp_stack: list = [] mw_req_stack, mw_rsrc_stack, mw_resp_stack = self._middleware req_succeeded = False @@ -361,7 +369,7 @@ def __call__(self, env, start_response): # noqa: C901 break if not resp.complete: - responder(req, resp, **params) + responder(req, resp, **params) # type: ignore req_succeeded = True except Exception as ex: @@ -438,7 +446,7 @@ def __call__(self, env, start_response): # noqa: C901 def router_options(self): return self._router.options - def add_middleware(self, middleware): + def add_middleware(self, middleware: object) -> None: """Add one or more additional middleware components. Arguments: @@ -465,7 +473,7 @@ def add_middleware(self, middleware): independent_middleware=self._independent_middleware, ) - def add_route(self, uri_template, resource, **kwargs): + def add_route(self, uri_template: str, resource: object, **kwargs): """Associate a templatized URI path with a resource. Falcon routes incoming requests to resources based on a set of @@ -572,7 +580,11 @@ def on_get_bar(self, req, resp): self._router.add_route(uri_template, resource, **kwargs) def add_static_route( - self, prefix, directory, downloadable=False, fallback_filename=None + self, + prefix: str, + directory: Union[str, pathlib.Path], + downloadable: bool = False, + fallback_filename: Optional[str] = None, ): """Add a route to a directory of static files. @@ -641,7 +653,7 @@ def add_static_route( self._static_routes.insert(0, (sr, sr, False)) self._update_sink_and_static_routes() - def add_sink(self, sink, prefix=r'/'): + def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/'): """Register a sink method for the App. If no route matches a request, but the path in the requested URI @@ -694,7 +706,11 @@ def add_sink(self, sink, prefix=r'/'): self._sinks.insert(0, (prefix, sink, True)) self._update_sink_and_static_routes() - def add_error_handler(self, exception, handler=None): + def add_error_handler( + self, + exception: Union[Type[BaseException], Iterable[Type[BaseException]]], + handler: Optional[ErrorHandler] = None, + ): """Register a handler for one or more exception types. Error handlers may be registered for any exception type, including @@ -794,7 +810,7 @@ def handler(req, resp, ex, params): if handler is None: try: - handler = exception.handle + handler = exception.handle # type: ignore except AttributeError: raise AttributeError( 'handler must either be specified ' @@ -814,8 +830,9 @@ def handler(req, resp, ex, params): ) or arg_names[1:3] in (('req', 'resp'), ('request', 'response')): handler = wrap_old_handler(handler) + exception_tuple: tuple try: - exception_tuple = tuple(exception) + exception_tuple = tuple(exception) # type: ignore except TypeError: exception_tuple = (exception,) @@ -825,7 +842,7 @@ def handler(req, resp, ex, params): self._error_handlers[exc] = handler - def set_error_serializer(self, serializer): + def set_error_serializer(self, serializer: ErrorSerializer): """Override the default serializer for instances of :class:`~.HTTPError`. When a responder raises an instance of :class:`~.HTTPError`, @@ -882,7 +899,9 @@ def _prepare_middleware(self, middleware=None, independent_middleware=False): middleware=middleware, independent_middleware=independent_middleware ) - def _get_responder(self, req): + def _get_responder( + self, req: Request + ) -> Tuple[Callable, dict, object, Optional[str]]: """Search routes for a matching responder. Args: @@ -953,7 +972,9 @@ def _get_responder(self, req): return (responder, params, resource, uri_template) - def _compose_status_response(self, req, resp, http_status): + def _compose_status_response( + self, req: Request, resp: Response, http_status: HTTPStatus + ) -> None: """Compose a response for the given HTTPStatus instance.""" # PERF(kgriffs): The code to set the status and headers is identical @@ -968,7 +989,9 @@ def _compose_status_response(self, req, resp, http_status): # it's acceptable to set resp.text to None (to indicate no body). resp.text = http_status.text - def _compose_error_response(self, req, resp, error): + def _compose_error_response( + self, req: Request, resp: Response, error: HTTPError + ) -> None: """Compose a response for the given HTTPError instance.""" resp.status = error.status diff --git a/falcon/app_helpers.py b/falcon/app_helpers.py index 3e58ebd3f..38b591914 100644 --- a/falcon/app_helpers.py +++ b/falcon/app_helpers.py @@ -15,11 +15,14 @@ """Utilities for the App class.""" from inspect import iscoroutinefunction +from typing import IO, Iterable, List, Tuple from falcon import util from falcon.constants import MEDIA_JSON from falcon.constants import MEDIA_XML -from falcon.errors import CompatibilityError +from falcon.errors import CompatibilityError, HTTPError +from falcon.request import Request +from falcon.response import Response from falcon.util.sync import _wrap_non_coroutine_unsafe __all__ = ( @@ -30,7 +33,9 @@ ) -def prepare_middleware(middleware, independent_middleware=False, asgi=False): +def prepare_middleware( + middleware: Iterable, independent_middleware: bool = False, asgi: bool = False +) -> Tuple[tuple, tuple, tuple]: """Check middleware interfaces and prepare the methods for request handling. Note: @@ -52,9 +57,9 @@ def prepare_middleware(middleware, independent_middleware=False, asgi=False): # PERF(kgriffs): do getattr calls once, in advance, so we don't # have to do them every time in the request path. - request_mw = [] - resource_mw = [] - response_mw = [] + request_mw: List = [] + resource_mw: List = [] + response_mw: List = [] for component in middleware: # NOTE(kgriffs): Middleware that supports both WSGI and ASGI can @@ -148,7 +153,7 @@ def prepare_middleware(middleware, independent_middleware=False, asgi=False): return (tuple(request_mw), tuple(resource_mw), tuple(response_mw)) -def prepare_middleware_ws(middleware): +def prepare_middleware_ws(middleware: Iterable) -> Tuple[list, list]: """Check middleware interfaces and prepare WebSocket methods for request handling. Note: @@ -196,7 +201,7 @@ def prepare_middleware_ws(middleware): return request_mw, resource_mw -def default_serialize_error(req, resp, exception): +def default_serialize_error(req: Request, resp: Response, exception: HTTPError): """Serialize the given instance of HTTPError. This function determines which of the supported media types, if @@ -275,7 +280,7 @@ class CloseableStreamIterator: block_size (int): Number of bytes to read per iteration. """ - def __init__(self, stream, block_size): + def __init__(self, stream: IO, block_size: int): self._stream = stream self._block_size = block_size diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py index b7dbc94fe..a4e0b9366 100644 --- a/falcon/asgi/app.py +++ b/falcon/asgi/app.py @@ -18,6 +18,7 @@ from inspect import isasyncgenfunction from inspect import iscoroutinefunction import traceback +from typing import Awaitable, Callable, Iterable, Optional, Type, Union import falcon.app from falcon.app_helpers import prepare_middleware @@ -33,6 +34,7 @@ from falcon.http_status import HTTPStatus from falcon.media.multipart import MultipartFormHandler import falcon.routing +from falcon.typing import ErrorHandler, SinkPrefix from falcon.util.misc import is_python_func from falcon.util.sync import _should_wrap_non_coroutines from falcon.util.sync import _wrap_non_coroutine_unsafe @@ -279,7 +281,12 @@ def __init__(self, *args, request_type=Request, response_type=Response, **kwargs ) @_wrap_asgi_coroutine_func - async def __call__(self, scope, receive, send): # noqa: C901 + async def __call__( # noqa: C901 + self, + scope: dict, + receive: Callable[[], Awaitable[dict]], + send: Callable[[dict], Awaitable[None]], + ) -> None: # NOTE(kgriffs): The ASGI spec requires the 'type' key to be present. scope_type = scope['type'] @@ -339,10 +346,10 @@ async def __call__(self, scope, receive, send): # noqa: C901 resp = self._response_type(options=self.resp_options) resource = None - responder = None - params = {} + responder: Optional[Callable] = None + params: dict = {} - dependent_mw_resp_stack = [] + dependent_mw_resp_stack: list = [] mw_req_stack, mw_rsrc_stack, mw_resp_stack = self._middleware req_succeeded = False @@ -402,7 +409,7 @@ async def __call__(self, scope, receive, send): # noqa: C901 break if not resp.complete: - await responder(req, resp, **params) + await responder(req, resp, **params) # type: ignore req_succeeded = True @@ -716,7 +723,7 @@ async def watch_disconnect(): if resp._registered_callbacks: self._schedule_callbacks(resp) - def add_route(self, uri_template, resource, **kwargs): + def add_route(self, uri_template: str, resource: object, **kwargs): # NOTE(kgriffs): Inject an extra kwarg so that the compiled router # will know to validate the responder methods to make sure they # are async coroutines. @@ -725,7 +732,7 @@ def add_route(self, uri_template, resource, **kwargs): add_route.__doc__ = falcon.app.App.add_route.__doc__ - def add_sink(self, sink, prefix=r'/'): + def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/'): if not iscoroutinefunction(sink) and is_python_func(sink): if _should_wrap_non_coroutines(): sink = wrap_sync_to_async(sink) @@ -739,7 +746,11 @@ def add_sink(self, sink, prefix=r'/'): add_sink.__doc__ = falcon.app.App.add_sink.__doc__ - def add_error_handler(self, exception, handler=None): + def add_error_handler( + self, + exception: Union[Type[BaseException], Iterable[Type[BaseException]]], + handler: Optional[ErrorHandler] = None, + ): """Register a handler for one or more exception types. Error handlers may be registered for any exception type, including @@ -840,7 +851,7 @@ async def handle(req, resp, ex, params): if handler is None: try: - handler = exception.handle + handler = exception.handle # type: ignore except AttributeError: raise AttributeError( 'handler must either be specified ' @@ -870,8 +881,9 @@ async def handle(req, resp, ex, params): 'to be used safely with an ASGI app.' ) + exception_tuple: tuple try: - exception_tuple = tuple(exception) + exception_tuple = tuple(exception) # type: ignore except TypeError: exception_tuple = (exception,) diff --git a/falcon/asgi/multipart.py b/falcon/asgi/multipart.py index 2e6ae5289..b268c2b5a 100644 --- a/falcon/asgi/multipart.py +++ b/falcon/asgi/multipart.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 by Vytautas Liuolia. +# Copyright 2019-2023 by Vytautas Liuolia. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/falcon/errors.py b/falcon/errors.py index 63f95f43d..cdad7ffa7 100644 --- a/falcon/errors.py +++ b/falcon/errors.py @@ -35,6 +35,7 @@ def on_get(self, req, resp): """ from datetime import datetime +from typing import Optional from falcon.http_error import HTTPError import falcon.status_codes as status @@ -142,7 +143,7 @@ class WebSocketDisconnected(ConnectionError): code (int): The WebSocket close code, as per the WebSocket spec. """ - def __init__(self, code: int = None): + def __init__(self, code: Optional[int] = None): self.code = code or 1000 # Default to "Normal Closure" diff --git a/falcon/media/base.py b/falcon/media/base.py index 2c435c0bd..ad06b8674 100644 --- a/falcon/media/base.py +++ b/falcon/media/base.py @@ -1,6 +1,6 @@ import abc import io -from typing import Union +from typing import IO, Optional, Union from falcon.constants import MEDIA_JSON @@ -22,7 +22,7 @@ class BaseHandler(metaclass=abc.ABCMeta): """Override to provide a synchronous deserialization method that takes a byte string.""" - def serialize(self, media, content_type) -> bytes: + def serialize(self, media: object, content_type: str) -> bytes: """Serialize the media object on a :any:`falcon.Response`. By default, this method raises an instance of @@ -54,7 +54,7 @@ def serialize(self, media, content_type) -> bytes: else: raise NotImplementedError() - async def serialize_async(self, media, content_type) -> bytes: + async def serialize_async(self, media: object, content_type: str) -> bytes: """Serialize the media object on a :any:`falcon.Response`. This method is similar to :py:meth:`~.BaseHandler.serialize` @@ -81,7 +81,9 @@ async def serialize_async(self, media, content_type) -> bytes: """ return self.serialize(media, content_type) - def deserialize(self, stream, content_type, content_length) -> object: + def deserialize( + self, stream: IO, content_type: str, content_length: Optional[int] + ) -> object: """Deserialize the :any:`falcon.Request` body. By default, this method raises an instance of @@ -114,7 +116,9 @@ def deserialize(self, stream, content_type, content_length) -> object: else: raise NotImplementedError() - async def deserialize_async(self, stream, content_type, content_length) -> object: + async def deserialize_async( + self, stream: IO, content_type: str, content_length: Optional[int] + ) -> object: """Deserialize the :any:`falcon.Request` body. This method is similar to :py:meth:`~.BaseHandler.deserialize` except @@ -134,7 +138,8 @@ async def deserialize_async(self, stream, content_type, content_length) -> objec Args: stream (object): Asynchronous file-like object to deserialize. content_type (str): Type of request content. - content_length (int): Length of request content. + content_length (int): Length of request content, or ``None`` if the + Content-Length header is missing. Returns: object: A deserialized object. diff --git a/falcon/media/multipart.py b/falcon/media/multipart.py index b336526bc..c3fc37d56 100644 --- a/falcon/media/multipart.py +++ b/falcon/media/multipart.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 by Vytautas Liuolia. +# Copyright 2019-2023 by Vytautas Liuolia. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/falcon/py.typed b/falcon/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/falcon/request.py b/falcon/request.py index 82df21c15..cb037369d 100644 --- a/falcon/request.py +++ b/falcon/request.py @@ -2085,6 +2085,13 @@ class RequestOptions: ``multipart/form-data`` media types. """ + keep_black_qs_values: bool + auto_parse_form_urlencoded: bool + auto_parse_qs_csv: bool + strip_url_path_trailing_slash: bool + default_media_type: str + media_handlers: Handlers + __slots__ = ( 'keep_blank_qs_values', 'auto_parse_form_urlencoded', diff --git a/falcon/response.py b/falcon/response.py index f69ff082f..cad529d3c 100644 --- a/falcon/response.py +++ b/falcon/response.py @@ -16,6 +16,7 @@ import functools import mimetypes +from typing import Optional from falcon.constants import _DEFAULT_STATIC_MEDIA_TYPES from falcon.constants import _UNSET @@ -1233,6 +1234,11 @@ class ResponseOptions: after calling ``mimetypes.init()``. """ + secure_cookies_by_default: bool + default_media_type: Optional[str] + media_handlers: Handlers + static_media_types: dict + __slots__ = ( 'secure_cookies_by_default', 'default_media_type', diff --git a/falcon/routing/converters.py b/falcon/routing/converters.py index 0c35ddb1d..8fd28fa32 100644 --- a/falcon/routing/converters.py +++ b/falcon/routing/converters.py @@ -15,6 +15,7 @@ import abc from datetime import datetime from math import isfinite +from typing import Optional import uuid __all__ = ( @@ -125,7 +126,12 @@ class FloatConverter(IntConverter): __slots__ = '_finite' - def __init__(self, min: float = None, max: float = None, finite: bool = True): + def __init__( + self, + min: Optional[float] = None, + max: Optional[float] = None, + finite: bool = True, + ): self._min = min self._max = max self._finite = finite if finite is not None else True @@ -135,15 +141,15 @@ def convert(self, value: str): return None try: - value = float(value) + converted = float(value) - if self._finite and not isfinite(value): + if self._finite and not isfinite(converted): return None except ValueError: return None - return self._validate_min_max_value(value) + return self._validate_min_max_value(converted) class DateTimeConverter(BaseConverter): diff --git a/falcon/testing/helpers.py b/falcon/testing/helpers.py index 8ca48c101..ccebb0d0d 100644 --- a/falcon/testing/helpers.py +++ b/falcon/testing/helpers.py @@ -139,9 +139,9 @@ class ASGIRequestEventEmitter: def __init__( self, - body: Union[str, bytes] = None, - chunk_size: int = None, - disconnect_at: Union[int, float] = None, + body: Optional[Union[str, bytes]] = None, + chunk_size: Optional[int] = None, + disconnect_at: Optional[Union[int, float]] = None, ): if body is None: body = b'' @@ -170,7 +170,7 @@ def __init__( def disconnected(self): return self._disconnected or (self._disconnect_at <= time.time()) - def disconnect(self, exhaust_body: bool = None): + def disconnect(self, exhaust_body: Optional[bool] = None): """Set the client connection state to disconnected. Call this method to simulate an immediate client disconnect and @@ -654,7 +654,7 @@ def _require_accepted(self): # NOTE(kgriffs): This is a coroutine just in case we need it to be # in a future code revision. It also makes it more consistent # with the other methods. - async def _send(self, data: bytes = None, text: str = None): + async def _send(self, data: Optional[bytes] = None, text: Optional[str] = None): self._require_accepted() # NOTE(kgriffs): From the client's perspective, it was a send, diff --git a/falcon/typing.py b/falcon/typing.py new file mode 100644 index 000000000..4acf8ad97 --- /dev/null +++ b/falcon/typing.py @@ -0,0 +1,36 @@ +# Copyright 2021-2023 by Vytautas Liuolia. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shorthand definitions for more complex types.""" + +from typing import Any, Callable, Pattern, Union + +from falcon.request import Request +from falcon.response import Response + + +# Error handlers +ErrorHandler = Callable[[Request, Response, BaseException, dict], Any] + +# Error serializers +ErrorSerializer = Callable[[Request, Response, BaseException], Any] + +# Sinks +SinkPrefix = Union[str, Pattern] + +# TODO(vytas): Is it possible to specify a Callable or a Protocol that defines +# type hints for the two first parameters, but accepts any number of keyword +# arguments afterwords? +# class SinkCallable(Protocol): +# def __call__(sef, req: Request, resp: Response, ): ... diff --git a/falcon/util/misc.py b/falcon/util/misc.py index d20fd4595..1c05d1090 100644 --- a/falcon/util/misc.py +++ b/falcon/util/misc.py @@ -119,7 +119,7 @@ def is_python_func(func): return inspect.isfunction(func) -def http_now(): +def http_now() -> str: """Return the current UTC time as an IMF-fixdate. Returns: @@ -130,7 +130,7 @@ def http_now(): return dt_to_http(utcnow()) -def dt_to_http(dt): +def dt_to_http(dt: datetime.datetime) -> str: """Convert a ``datetime`` instance to an HTTP date string. Args: @@ -145,7 +145,7 @@ def dt_to_http(dt): return dt.strftime('%a, %d %b %Y %H:%M:%S GMT') -def http_date_to_dt(http_date, obs_date=False): +def http_date_to_dt(http_date: str, obs_date: bool = False) -> datetime.datetime: """Convert an HTTP date string to a datetime instance. Args: @@ -191,7 +191,9 @@ def http_date_to_dt(http_date, obs_date=False): raise ValueError('time data %r does not match known formats' % http_date) -def to_query_str(params, comma_delimited_lists=True, prefix=True): +def to_query_str( + params: dict, comma_delimited_lists: bool = True, prefix: bool = True +) -> str: """Convert a dictionary of parameters to a query string. Args: @@ -344,7 +346,7 @@ def get_http_status(status_code, default_reason=_DEFAULT_HTTP_REASON): return str(code) + ' ' + default_reason -def secure_filename(filename): +def secure_filename(filename: str) -> str: """Sanitize the provided `filename` to contain only ASCII characters. Only ASCII alphanumerals, ``'.'``, ``'-'`` and ``'_'`` are allowed for @@ -461,7 +463,7 @@ def code_to_http_status(status): except (ValueError, TypeError): raise ValueError('{!r} is not a valid status code'.format(status)) if not 100 <= code <= 999: - raise ValueError('{} is not a valid status code'.format(status)) + raise ValueError('{!r} is not a valid status code'.format(status)) try: # NOTE(kgriffs): We do this instead of using http.HTTPStatus since @@ -489,7 +491,7 @@ def _encode_items_to_latin1(data): return result -def _isascii(string): +def _isascii(string: str): """Return ``True`` if all characters in the string are ASCII. ASCII characters have code points in the range U+0000-U+007F. diff --git a/falcon/util/time.py b/falcon/util/time.py index 00a69f07c..485f3b78a 100644 --- a/falcon/util/time.py +++ b/falcon/util/time.py @@ -10,6 +10,7 @@ """ import datetime +from typing import Optional __all__ = ['TimezoneGMT'] @@ -20,7 +21,7 @@ class TimezoneGMT(datetime.tzinfo): GMT_ZERO = datetime.timedelta(hours=0) - def utcoffset(self, dt): + def utcoffset(self, dt: Optional[datetime.datetime]) -> datetime.timedelta: """Get the offset from UTC. Args: @@ -33,7 +34,7 @@ def utcoffset(self, dt): return self.GMT_ZERO - def tzname(self, dt): + def tzname(self, dt: Optional[datetime.datetime]) -> str: """Get the name of this timezone. Args: @@ -45,7 +46,7 @@ def tzname(self, dt): return 'GMT' - def dst(self, dt): + def dst(self, dt: Optional[datetime.datetime]) -> datetime.timedelta: """Return the daylight saving time (DST) adjustment. Args: diff --git a/falcon/util/uri.py b/falcon/util/uri.py index 601b5ee72..e6078dfb9 100644 --- a/falcon/util/uri.py +++ b/falcon/util/uri.py @@ -23,6 +23,8 @@ name, port = uri.parse_host('example.org:8080') """ +from typing import Tuple, TYPE_CHECKING + from falcon.constants import PYPY try: @@ -328,7 +330,7 @@ def decode(encoded_uri, unquote_plus=True): return _join_tokens(tokens) -def parse_query_string(query_string, keep_blank=False, csv=True): +def parse_query_string(query_string: str, keep_blank: bool = False, csv: bool = True): """Parse a query string into a dict. Query string parameters are assumed to use standard form-encoding. Only @@ -372,7 +374,7 @@ def parse_query_string(query_string, keep_blank=False, csv=True): """ - params = {} + params: dict = {} is_encoded = '+' in query_string or '%' in query_string @@ -402,15 +404,17 @@ def parse_query_string(query_string, keep_blank=False, csv=True): # assigned to a single param instance. If it turns out that # very few people use this, it can be deprecated at some # point. - v = v.split(',') + values = v.split(',') if not keep_blank: # NOTE(kgriffs): Normalize the result in the case that # some elements are empty strings, such that the result # will be the same for 'foo=1,,3' as 'foo=1&foo=&foo=3'. - additional_values = [decode(element) for element in v if element] + additional_values = [ + decode(element) for element in values if element + ] else: - additional_values = [decode(element) for element in v] + additional_values = [decode(element) for element in values] if isinstance(old_value, list): old_value.extend(additional_values) @@ -434,15 +438,15 @@ def parse_query_string(query_string, keep_blank=False, csv=True): # assigned to a single param instance. If it turns out that # very few people use this, it can be deprecated at some # point. - v = v.split(',') + values = v.split(',') if not keep_blank: # NOTE(kgriffs): Normalize the result in the case that # some elements are empty strings, such that the result # will be the same for 'foo=1,,3' as 'foo=1&foo=&foo=3'. - params[k] = [decode(element) for element in v if element] + params[k] = [decode(element) for element in values if element] else: - params[k] = [decode(element) for element in v] + params[k] = [decode(element) for element in values] elif is_encoded: params[k] = decode(v) else: @@ -451,7 +455,7 @@ def parse_query_string(query_string, keep_blank=False, csv=True): return params -def parse_host(host, default_port=None): +def parse_host(host: str, default_port=None) -> Tuple[str, int]: """Parse a canonical 'host:port' string into parts. Parse a host string (which may or may not contain a port) into @@ -536,8 +540,9 @@ def unquote_string(quoted): # TODO(vytas): Restructure this in favour of a cleaner way to hoist the pure # Cython functions into this module. -decode = _cy_decode or decode # NOQA -parse_query_string = _cy_parse_query_string or parse_query_string # NOQA +if not TYPE_CHECKING: + decode = _cy_decode or decode # NOQA + parse_query_string = _cy_parse_query_string or parse_query_string # NOQA __all__ = [ diff --git a/pyproject.toml b/pyproject.toml index 0e013393d..b386f7a2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,34 @@ "cython>=0.29.21; python_implementation == 'CPython'", # Skip cython when using pypy ] +[tool.mypy] + exclude = "falcon/bench/|falcon/cmd/" + + [[tool.mypy.overrides]] + module = [ + "cbor2", + "cython", + "daphne", + "gunicorn", + "hypercorn", + "meinheld", + "msgpack", + "mujson", + "pyximport", + "testtools", + "uvicorn" + ] + ignore_missing_imports = true + + [[tool.mypy.overrides]] + # Pure Cython modules + module = [ + "falcon.cyutil.misc", + "falcon.cyutil.reader", + "falcon.cyutil.uri" + ] + ignore_missing_imports = true + [tool.towncrier] package = "falcon" package_dir = "" diff --git a/tox.ini b/tox.ini index f60d85dd7..459a37faa 100644 --- a/tox.ini +++ b/tox.ini @@ -16,6 +16,7 @@ envlist = cleanup, blue, pep8, mypy, + mypy_tests, mintest, pytest, pytest_sans_msgpack, @@ -142,16 +143,27 @@ deps = {[with-debug-tools]deps} # -------------------------------------------------------------------- # mypy # -------------------------------------------------------------------- + [testenv:mypy] +skipsdist = True +skip_install = True +deps = mypy + types-jsonschema +commands = python {toxinidir}/tools/clean.py "{toxinidir}/falcon" + mypy falcon + +[testenv:mypy_tests] deps = {[testenv]deps} - pytest-mypy + mypy types-requests types-PyYAML types-ujson types-waitress types-aiofiles + types-jsonschema commands = python "{toxinidir}/tools/clean.py" "{toxinidir}/falcon" - pytest --mypy -m mypy tests [] + mypy e2e-tests/server + mypy tests # -------------------------------------------------------------------- # Cython