From 78868d9f25d79a2e334b30a7043124291ff45282 Mon Sep 17 00:00:00 2001 From: Keming Date: Sun, 24 Nov 2024 20:53:57 +0800 Subject: [PATCH] feat: skip validation for request & response (#383) * feat: skip validation for request & response Signed-off-by: Keming * fix lint Signed-off-by: Keming * fix mypy Signed-off-by: Keming --------- Signed-off-by: Keming --- README.md | 6 ++- pyproject.toml | 4 +- spectree/__init__.py | 8 +-- spectree/_pydantic.py | 14 ++--- spectree/plugins/__init__.py | 2 +- spectree/plugins/falcon_plugin.py | 52 ++++++++++--------- spectree/plugins/flask_plugin.py | 24 +++++---- spectree/plugins/quart_plugin.py | 30 ++++++----- spectree/plugins/starlette_plugin.py | 44 +++++++++------- spectree/spec.py | 11 ++++ spectree/utils.py | 4 +- ...st_plugin_spec[flask_view][full_spec].json | 5 -- tests/common.py | 4 ++ tests/flask_imports/__init__.py | 18 +++---- tests/quart_imports/__init__.py | 6 +-- tests/test_plugin_falcon.py | 10 ++-- tests/test_plugin_flask.py | 10 ++-- tests/test_plugin_flask_blueprint.py | 10 ++-- tests/test_plugin_flask_view.py | 13 +++-- tests/test_plugin_quart.py | 12 ++--- tests/test_plugin_starlette.py | 12 ++--- 21 files changed, 166 insertions(+), 133 deletions(-) diff --git a/README.md b/README.md index e79f8f01..cede4807 100644 --- a/README.md +++ b/README.md @@ -263,7 +263,11 @@ You can change the `validation_error_status` in SpecTree (global) or a specific > How can I skip the validation? -Add `skip_validation=True` to the decorator. For now, this only skip the response validation. +Add `skip_validation=True` to the decorator. + +Before v1.3.0, this only skip the response validation. + +Starts from v1.3.0, this will skip all the validations. As an result, you won't be able to access the validated data from `context`. ```py @api.validate(json=Profile, resp=Response(HTTP_200=Message, HTTP_403=None), skip_validation=True) diff --git a/pyproject.toml b/pyproject.toml index 7d37b6e0..0bb525d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spectree" -version = "1.2.11" +version = "1.3.0" dynamic = [] description = "generate OpenAPI document and validate request&response with Python annotations." readme = "README.md" @@ -67,7 +67,7 @@ target-version = "py38" line-length = 88 [tool.ruff.lint] select = ["E", "F", "B", "G", "I", "SIM", "TID", "PL", "RUF"] -ignore = ["E501", "PLR2004", "RUF012"] +ignore = ["E501", "PLR2004", "RUF012", "B009"] [tool.ruff.lint.pylint] max-args = 12 max-branches = 15 diff --git a/spectree/__init__.py b/spectree/__init__.py index ccec8c19..5cebac7d 100644 --- a/spectree/__init__.py +++ b/spectree/__init__.py @@ -5,13 +5,13 @@ from .spec import SpecTree __all__ = [ - "SpecTree", - "Response", - "Tag", - "SecurityScheme", "BaseFile", "ExternalDocs", + "Response", + "SecurityScheme", "SecuritySchemeData", + "SpecTree", + "Tag", ] # setup library logging diff --git a/spectree/_pydantic.py b/spectree/_pydantic.py index b263d47e..15194983 100644 --- a/spectree/_pydantic.py +++ b/spectree/_pydantic.py @@ -7,19 +7,19 @@ __all__ = [ - "BaseModel", - "ValidationError", - "Field", - "root_validator", "AnyUrl", + "BaseModel", "BaseSettings", "EmailStr", - "validator", + "Field", + "ValidationError", + "is_base_model", + "is_base_model_instance", "is_root_model", "is_root_model_instance", + "root_validator", "serialize_model_instance", - "is_base_model", - "is_base_model_instance", + "validator", ] if PYDANTIC2: diff --git a/spectree/plugins/__init__.py b/spectree/plugins/__init__.py index 0731b892..a3a753d9 100644 --- a/spectree/plugins/__init__.py +++ b/spectree/plugins/__init__.py @@ -13,4 +13,4 @@ "starlette": Plugin(".starlette_plugin", __name__, "StarlettePlugin"), } -__all__ = ["BasePlugin", "PLUGINS", "Plugin"] +__all__ = ["PLUGINS", "BasePlugin", "Plugin"] diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index 8759ea85..38255fa4 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -208,18 +208,20 @@ def validate( # falcon endpoint method arguments: (self, req, resp) _self, _req, _resp = args[:3] req_validation_error, resp_validation_error = None, None - try: - self.request_validation(_req, query, json, form, headers, cookies) - if self.config.annotations: - annotations = get_type_hints(func) - for name in ("query", "json", "form", "headers", "cookies"): - if annotations.get(name): - kwargs[name] = getattr(_req.context, name) - - except ValidationError as err: - req_validation_error = err - _resp.status = f"{validation_error_status} Validation Error" - _resp.media = err.errors() + if not skip_validation: + try: + self.request_validation(_req, query, json, form, headers, cookies) + + except ValidationError as err: + req_validation_error = err + _resp.status = f"{validation_error_status} Validation Error" + _resp.media = err.errors() + + if self.config.annotations: + annotations = get_type_hints(func) + for name in ("query", "json", "form", "headers", "cookies"): + if annotations.get(name): + kwargs[name] = getattr(_req.context, name, None) before(_req, _resp, req_validation_error, _self) if req_validation_error: @@ -312,18 +314,20 @@ async def validate( # falcon endpoint method arguments: (self, req, resp) _self, _req, _resp = args[:3] req_validation_error, resp_validation_error = None, None - try: - await self.request_validation(_req, query, json, form, headers, cookies) - if self.config.annotations: - annotations = get_type_hints(func) - for name in ("query", "json", "form", "headers", "cookies"): - if annotations.get(name): - kwargs[name] = getattr(_req.context, name) - - except ValidationError as err: - req_validation_error = err - _resp.status = f"{validation_error_status} Validation Error" - _resp.media = err.errors() + if not skip_validation: + try: + await self.request_validation(_req, query, json, form, headers, cookies) + + except ValidationError as err: + req_validation_error = err + _resp.status = f"{validation_error_status} Validation Error" + _resp.media = err.errors() + + if self.config.annotations: + annotations = get_type_hints(func) + for name in ("query", "json", "form", "headers", "cookies"): + if annotations.get(name): + kwargs[name] = getattr(_req.context, name, None) before(_req, _resp, req_validation_error, _self) if req_validation_error: diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index 1b617e77..dddd5651 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -182,16 +182,20 @@ def validate( **kwargs: Any, ): response, req_validation_error, resp_validation_error = None, None, None - try: - self.request_validation(request, query, json, form, headers, cookies) - if self.config.annotations: - annotations = get_type_hints(func) - for name in ("query", "json", "form", "headers", "cookies"): - if annotations.get(name): - kwargs[name] = getattr(request.context, name) - except ValidationError as err: - req_validation_error = err - response = make_response(jsonify(err.errors()), validation_error_status) + if not skip_validation: + try: + self.request_validation(request, query, json, form, headers, cookies) + except ValidationError as err: + req_validation_error = err + response = make_response(jsonify(err.errors()), validation_error_status) + + if self.config.annotations: + annotations = get_type_hints(func) + for name in ("query", "json", "form", "headers", "cookies"): + if annotations.get(name): + kwargs[name] = getattr( + getattr(request, "context", None), name, None + ) before(request, response, req_validation_error, None) if req_validation_error: diff --git a/spectree/plugins/quart_plugin.py b/spectree/plugins/quart_plugin.py index eddb3c68..6ce93312 100644 --- a/spectree/plugins/quart_plugin.py +++ b/spectree/plugins/quart_plugin.py @@ -190,18 +190,24 @@ async def validate( **kwargs: Any, ): response, req_validation_error, resp_validation_error = None, None, None - try: - await self.request_validation(request, query, json, form, headers, cookies) - if self.config.annotations: - annotations = get_type_hints(func) - for name in ("query", "json", "form", "headers", "cookies"): - if annotations.get(name): - kwargs[name] = getattr(request.context, name) - except ValidationError as err: - req_validation_error = err - response = await make_response( - jsonify(err.errors()), validation_error_status - ) + if not skip_validation: + try: + await self.request_validation( + request, query, json, form, headers, cookies + ) + except ValidationError as err: + req_validation_error = err + response = await make_response( + jsonify(err.errors()), validation_error_status + ) + + if self.config.annotations: + annotations = get_type_hints(func) + for name in ("query", "json", "form", "headers", "cookies"): + if annotations.get(name): + kwargs[name] = getattr( + getattr(request, "context", None), name, None + ) before(request, response, req_validation_error, None) if req_validation_error: diff --git a/spectree/plugins/starlette_plugin.py b/spectree/plugins/starlette_plugin.py index d19d1f8a..fb9dc878 100644 --- a/spectree/plugins/starlette_plugin.py +++ b/spectree/plugins/starlette_plugin.py @@ -105,24 +105,32 @@ async def validate( response = None req_validation_error = resp_validation_error = json_decode_error = None - try: - await self.request_validation(request, query, json, form, headers, cookies) - if self.config.annotations: - annotations = get_type_hints(func) - for name in ("query", "json", "form", "headers", "cookies"): - if annotations.get(name): - kwargs[name] = getattr(request.context, name) - except ValidationError as err: - req_validation_error = err - response = JSONResponse(err.errors(), validation_error_status) - except JSONDecodeError as err: - json_decode_error = err - self.logger.info( - "%s Validation Error", - validation_error_status, - extra={"spectree_json_decode_error": str(err)}, - ) - response = JSONResponse({"error_msg": str(err)}, validation_error_status) + if not skip_validation: + try: + await self.request_validation( + request, query, json, form, headers, cookies + ) + except ValidationError as err: + req_validation_error = err + response = JSONResponse(err.errors(), validation_error_status) + except JSONDecodeError as err: + json_decode_error = err + self.logger.info( + "%s Validation Error", + validation_error_status, + extra={"spectree_json_decode_error": str(err)}, + ) + response = JSONResponse( + {"error_msg": str(err)}, validation_error_status + ) + + if self.config.annotations: + annotations = get_type_hints(func) + for name in ("query", "json", "form", "headers", "cookies"): + if annotations.get(name): + kwargs[name] = getattr( + getattr(request, "context", None), name, None + ) before(request, response, req_validation_error, instance) if req_validation_error or json_decode_error: diff --git a/spectree/spec.py b/spectree/spec.py index 19fe6676..b3fa6da7 100644 --- a/spectree/spec.py +++ b/spectree/spec.py @@ -1,3 +1,4 @@ +import warnings from collections import defaultdict from copy import deepcopy from functools import wraps @@ -167,6 +168,8 @@ def validate( # noqa: PLR0913 [too-many-arguments] in :meth:`spectree.spec.SpecTree`. :param path_parameter_descriptions: A dictionary of path parameter names and their description. + :param skip_validation: If set to `True`, the endpoint will skip + request / response validations. :param operation_id: a string override for operationId for the given endpoint """ # If the status code for validation errors is not overridden on the level of @@ -174,6 +177,14 @@ def validate( # noqa: PLR0913 [too-many-arguments] if validation_error_status == 0: validation_error_status = self.validation_error_status + if self.config.annotations and skip_validation: + warnings.warn( + "`skip_validation` cannot be used with `annotations` enabled. The instances" + " of `json`, `headers`, `cookies`, etc. read from function will be `None`.", + UserWarning, + stacklevel=2, + ) + def decorate_validation(func: Callable): # for sync framework @wraps(func) diff --git a/spectree/utils.py b/spectree/utils.py index 76bd48a7..5cd359b1 100644 --- a/spectree/utils.py +++ b/spectree/utils.py @@ -113,7 +113,7 @@ def parse_params( attr_to_spec_key = {"query": "query", "headers": "header", "cookies": "cookie"} route_param_keywords = ("explode", "style", "allowReserved") - for attr in attr_to_spec_key: + for attr, position in attr_to_spec_key.items(): if hasattr(func, attr): model = models[getattr(func, attr)] properties = model.get("properties", {model.get("title"): model}) @@ -125,7 +125,7 @@ def parse_params( params.append( { "name": name, - "in": attr_to_spec_key[attr], + "in": position, "schema": schema, "required": name in model.get("required", []), "description": schema.get("description", ""), diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json index b1531388..539f65cc 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json @@ -892,11 +892,6 @@ "schema": { "$ref": "#/components/schemas/JSON.7068f62" } - }, - "multipart/form-data": { - "schema": { - "$ref": "#/components/schemas/Form.7068f62" - } } } }, diff --git a/tests/common.py b/tests/common.py index 78c1baa6..3d885b9e 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,3 +1,4 @@ +import warnings import xml.etree.ElementTree as ET from dataclasses import dataclass from enum import Enum, IntEnum @@ -7,6 +8,9 @@ from spectree._pydantic import BaseModel, Field, root_validator from spectree.utils import hash_module_path +# suppress warnings +warnings.filterwarnings("ignore", category=UserWarning) + api_tag = Tag( name="API", description="🐱", externalDocs=ExternalDocs(url="https://pypi.org") ) diff --git a/tests/flask_imports/__init__.py b/tests/flask_imports/__init__.py index 77c4de3b..b1a8d1bd 100644 --- a/tests/flask_imports/__init__.py +++ b/tests/flask_imports/__init__.py @@ -14,16 +14,16 @@ ) __all__ = [ - "test_flask_return_model", - "test_flask_skip_validation", - "test_flask_validation_error_response_status_code", "test_flask_doc", - "test_flask_optional_alias_response", - "test_flask_validate_post_data", - "test_flask_no_response", - "test_flask_upload_file", "test_flask_list_json_request", - "test_flask_return_list_request", - "test_flask_make_response_post", "test_flask_make_response_get", + "test_flask_make_response_post", + "test_flask_no_response", + "test_flask_optional_alias_response", + "test_flask_return_list_request", + "test_flask_return_model", + "test_flask_skip_validation", + "test_flask_upload_file", + "test_flask_validate_post_data", + "test_flask_validation_error_response_status_code", ] diff --git a/tests/quart_imports/__init__.py b/tests/quart_imports/__init__.py index 8056bbf8..a5d6c093 100644 --- a/tests/quart_imports/__init__.py +++ b/tests/quart_imports/__init__.py @@ -8,10 +8,10 @@ ) __all__ = [ + "test_quart_doc", + "test_quart_no_response", "test_quart_return_model", "test_quart_skip_validation", - "test_quart_validation_error_response_status_code", - "test_quart_doc", "test_quart_validate", - "test_quart_no_response", + "test_quart_validation_error_response_status_code", ] diff --git a/tests/test_plugin_falcon.py b/tests/test_plugin_falcon.py index 10a9118e..6f76eed2 100644 --- a/tests/test_plugin_falcon.py +++ b/tests/test_plugin_falcon.py @@ -14,6 +14,7 @@ Headers, ListJSON, OptionalAliasResp, + Order, Query, Resp, RootResp, @@ -121,15 +122,14 @@ def on_get(self, req, resp, name): def on_post(self, req, resp, name, query: Query, json: JSON, cookies: Cookies): response_format = req.params.get("response_format") assert response_format in ("json", "xml") - score = [randint(0, req.context.json.limit) for _ in range(5)] - score.sort(reverse=req.context.query.order) - assert req.context.cookies.pub == "abcdefg" + score = [randint(0, req.media.get("limit")) for _ in range(5)] + score.sort(reverse=int(req.params.get("order")) == Order.desc) assert req.cookies["pub"] == "abcdefg" if response_format == "json": - resp.media = {"name": req.context.json.name, "x_score": score} + resp.media = {"name": req.media.get("name"), "x_score": score} else: resp.content_type = falcon.MEDIA_XML - resp.text = UserXmlData(name=req.context.json.name, score=score).dump_xml() + resp.text = UserXmlData(name=req.media.get("name"), score=score).dump_xml() class UserScoreModel: diff --git a/tests/test_plugin_flask.py b/tests/test_plugin_flask.py index 902822f1..63d61f6d 100644 --- a/tests/test_plugin_flask.py +++ b/tests/test_plugin_flask.py @@ -127,15 +127,15 @@ def user_score_annotated(name, query: Query, json: JSON, form: Form, cookies: Co def user_score_skip_validation(name): response_format = request.args.get("response_format") assert response_format in ("json", "xml") - score = [randint(0, request.context.json.limit) for _ in range(5)] - score.sort(reverse=(request.context.query.order == Order.desc)) - assert request.context.cookies.pub == "abcdefg" + json = request.get_json() + score = [randint(0, json.get("limit")) for _ in range(5)] + score.sort(reverse=(int(request.args.get("order")) == Order.desc)) assert request.cookies["pub"] == "abcdefg" if response_format == "json": - return jsonify(name=request.context.json.name, x_score=score) + return jsonify(name=name, x_score=score) else: return app.response_class( - UserXmlData(name=request.context.json.name, score=score).dump_xml(), + UserXmlData(name=name, score=score).dump_xml(), content_type="text/xml", ) diff --git a/tests/test_plugin_flask_blueprint.py b/tests/test_plugin_flask_blueprint.py index c277e844..315dbadf 100644 --- a/tests/test_plugin_flask_blueprint.py +++ b/tests/test_plugin_flask_blueprint.py @@ -114,15 +114,15 @@ def user_score_annotated(name, query: Query, json: JSON, cookies: Cookies, form: def user_score_skip_validation(name): response_format = request.args.get("response_format") assert response_format in ("json", "xml") - score = [randint(0, request.context.json.limit) for _ in range(5)] - score.sort(reverse=request.context.query.order == Order.desc) - assert request.context.cookies.pub == "abcdefg" + json = request.get_json() + score = [randint(0, json.get("limit")) for _ in range(5)] + score.sort(reverse=int(request.args.get("order")) == Order.desc) assert request.cookies["pub"] == "abcdefg" if response_format == "json": - return jsonify(name=request.context.json.name, x_score=score) + return jsonify(name=name, x_score=score) else: return flask.Response( - UserXmlData(name=request.context.json.name, score=score).dump_xml(), + UserXmlData(name=name, score=score).dump_xml(), content_type="text/xml", ) diff --git a/tests/test_plugin_flask_view.py b/tests/test_plugin_flask_view.py index 2f6b0725..82e8c150 100644 --- a/tests/test_plugin_flask_view.py +++ b/tests/test_plugin_flask_view.py @@ -113,19 +113,18 @@ class UserSkip(MethodView): after=api_after_handler, skip_validation=True, ) - def post(self, name, query: Query, json: JSON, form: Form, cookies: Cookies): + def post(self, name): response_format = request.args.get("response_format") assert response_format in ("json", "xml") - data_src = json or form - score = [randint(0, int(data_src.limit)) for _ in range(5)] - score.sort(reverse=(query.order == Order.desc)) - assert cookies.pub == "abcdefg" + data_src = request.get_json() or request.get_data() + score = [randint(0, int(data_src["limit"])) for _ in range(5)] + score.sort(reverse=(int(request.args["order"]) == Order.desc)) assert request.cookies["pub"] == "abcdefg" if response_format == "json": - return jsonify(name=request.context.json.name, x_score=score) + return jsonify(name=name, x_score=score) else: return app.response_class( - UserXmlData(name=request.context.json.name, score=score).dump_xml(), + UserXmlData(name=name, score=score).dump_xml(), content_type="text/xml", ) diff --git a/tests/test_plugin_quart.py b/tests/test_plugin_quart.py index 4a2565ba..3a4b4d62 100644 --- a/tests/test_plugin_quart.py +++ b/tests/test_plugin_quart.py @@ -104,17 +104,15 @@ async def user_score_annotated(name, query: Query, json: JSON, cookies: Cookies) async def user_score_skip_validation(name): response_format = request.args.get("response_format") assert response_format in ("json", "xml") - score = [randint(0, request.context.json.limit) for _ in range(5)] - score.sort(reverse=request.context.query.order == Order.desc) - assert request.context.cookies.pub == "abcdefg" + json = request.get_json() + score = [randint(0, json.get("limit")) for _ in range(5)] + score.sort(reverse=request.args.get("order") == Order.desc) assert request.cookies["pub"] == "abcdefg" if response_format == "json": - return jsonify(name=request.context.json.name, x_score=score) + return jsonify(name=json.get("name"), x_score=score) else: return app.response_class( - response=UserXmlData( - name=request.context.json.name, score=score - ).dump_xml(), + response=UserXmlData(name=json.get("name"), score=score).dump_xml(), content_type="text/xml", ) diff --git a/tests/test_plugin_starlette.py b/tests/test_plugin_starlette.py index fe76c31f..19223582 100644 --- a/tests/test_plugin_starlette.py +++ b/tests/test_plugin_starlette.py @@ -82,7 +82,7 @@ async def file_upload(request): resp=Response(HTTP_200=Resp, HTTP_401=None), tags=[api_tag, "test"], ) -async def user_score(request): +async def user_score(request, json: JSON, query: Query): score = [randint(0, request.context.json.limit) for _ in range(5)] score.sort(reverse=request.context.query.order == Order.desc) assert request.context.cookies.pub == "abcdefg" @@ -112,15 +112,15 @@ async def user_score_annotated(request, query: Query, json: JSON, cookies: Cooki ) async def user_score_skip(request): response_format = request.query_params.get("response_format") - score = [randint(0, request.context.json.limit) for _ in range(5)] - score.sort(reverse=request.context.query.order == Order.desc) - assert request.context.cookies.pub == "abcdefg" + json = await request.json() + score = [randint(0, json.get("limit")) for _ in range(5)] + score.sort(reverse=int(request.query_params.get("order")) == Order.desc) assert request.cookies["pub"] == "abcdefg" if response_format == "json": - return JSONResponse({"name": request.context.json.name, "x_score": score}) + return JSONResponse({"name": json.get("name"), "x_score": score}) else: return StarletteResponse( - UserXmlData(name=request.context.json.name, score=score).dump_xml(), + UserXmlData(name=json.get("name"), score=score).dump_xml(), media_type="text/xml", )