From 6025aea4a7ac8514d7caa49ad708a1b3e31b0843 Mon Sep 17 00:00:00 2001 From: Keming Date: Mon, 21 Feb 2022 23:32:05 +0800 Subject: [PATCH] fix flask ImmutableMultiDict and EnvironHeaders parser (#205) * fix parse json error * fix flask multidict and headers * release 0.7.5 --- setup.py | 2 +- spectree/plugins/flask_plugin.py | 31 ++++++++++++++++++++----------- spectree/utils.py | 14 ++++++++++++++ tests/test_plugin_flask.py | 2 ++ 4 files changed, 37 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 4c88ed44..61786ee7 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ setup( name="spectree", - version="0.7.4", + version="0.7.5", license="Apache-2.0", author="Keming Yang", author_email="kemingy94@gmail.com", diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index 755b6209..59ba6c14 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -1,5 +1,6 @@ from pydantic import ValidationError +from ..utils import get_multidict_items from .base import BasePlugin, Context @@ -128,22 +129,30 @@ def parse_path(self, route, path_parameter_descriptions): return "".join(subs), parameters def request_validation(self, request, query, json, headers, cookies): - req_query = request.args or {} + """ + req_query: werkzeug.datastructures.ImmutableMultiDict + req_json: dict + req_headers: werkzeug.datastructures.EnvironHeaders + req_cookies: werkzeug.datastructures.ImmutableMultiDict + """ + req_query = get_multidict_items(request.args) or {} if request.mimetype in self.FORM_MIMETYPE: - req_json = request.form or {} + req_json = get_multidict_items(request.form) or {} if request.files: - req_json = dict( - list(request.form.items()) + list(request.files.items()) - ) + req_json = { + **req_json, + **get_multidict_items(request.files), + } else: req_json = request.get_json(silent=True) or {} - req_headers = request.headers or {} - req_cookies = request.cookies or {} + req_headers = dict(iter(request.headers)) or {} + req_cookies = get_multidict_items(request.cookies) or {} + request.context = Context( - query.parse_obj(req_query.items()) if query else None, - json.parse_obj(req_json.items()) if json else None, - headers.parse_obj(req_headers.items()) if headers else None, - cookies.parse_obj(req_cookies.items()) if cookies else None, + query.parse_obj(req_query) if query else None, + json.parse_obj(req_json) if json else None, + headers.parse_obj(req_headers) if headers else None, + cookies.parse_obj(req_cookies) if cookies else None, ) def validate( diff --git a/spectree/utils.py b/spectree/utils.py index c2c724d8..c1fc37cf 100644 --- a/spectree/utils.py +++ b/spectree/utils.py @@ -245,3 +245,17 @@ def get_security(security): security = [security] return security + + +def get_multidict_items(multidict): + """ + return the items of a :class:`werkzeug.datastructures.ImmutableMultiDict` + """ + res = {} + for key in multidict: + if len(multidict.getlist(key)) > 1: + res[key] = multidict.getlist(key) + else: + res[key] = multidict.get(key) + + return res diff --git a/tests/test_plugin_flask.py b/tests/test_plugin_flask.py index d64beaa1..4b9b162e 100644 --- a/tests/test_plugin_flask.py +++ b/tests/test_plugin_flask.py @@ -144,6 +144,7 @@ def test_flask_validate(client): data=json.dumps(dict(name="flask", limit=10)), content_type="application/json", ) + assert resp.status_code == 200, resp.json assert resp.json["score"] == sorted(resp.json["score"], reverse=False) resp = client.post( @@ -151,6 +152,7 @@ def test_flask_validate(client): data="name=flask&limit=10", content_type="application/x-www-form-urlencoded", ) + assert resp.status_code == 200, resp.json assert resp.json["score"] == sorted(resp.json["score"], reverse=False)