Skip to content

Commit

Permalink
support dict as Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
OmmyZhang committed Nov 28, 2022
1 parent 5ac7296 commit 0deedfd
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
7 changes: 5 additions & 2 deletions flask_smorest/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import wraps
import http

import marshmallow as ma
from webargs.flaskparser import FlaskParser

from .utils import deepupdate
Expand All @@ -28,8 +29,8 @@ def arguments(
):
"""Decorator specifying the schema used to deserialize parameters
:param type|Schema schema: Marshmallow ``Schema`` class or instance
used to deserialize and validate the argument.
:param type|Schema|dict schema: Marshmallow ``Schema`` class or instance
or dict used to deserialize and validate the argument.
:param str location: Location of the argument.
:param str content_type: Content type of the argument.
Should only be used in conjunction with ``json``, ``form`` or
Expand All @@ -56,6 +57,8 @@ def arguments(
See :doc:`Arguments <arguments>`.
"""
if isinstance(schema, dict):
schema = ma.Schema.from_dict(schema)
# At this stage, put schema instance in doc dictionary. Il will be
# replaced later on by $ref or json.
parameters = {
Expand Down
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class ClientErrorSchema(ma.Schema):
error_id = ma.fields.Str()
text = ma.fields.Str()

return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema"))(
DocSchema, QueryArgsSchema, ClientErrorSchema
DictSchema = {
"item_id": ma.fields.Int(dump_only=True),
"field": ma.fields.Int(attribute="db_field"),
}

return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema", "DictSchema"))(
DocSchema, QueryArgsSchema, ClientErrorSchema, DictSchema
)
41 changes: 41 additions & 0 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,47 @@ def func(document, query_args):
"query_args": {"arg1": "test"},
}

@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
def test_blueprint_dict_arguments(self, app, schemas, openapi_version):
app.config["OPENAPI_VERSION"] = openapi_version
api = Api(app)
blp = Blueprint("test", __name__, url_prefix="/test")
client = app.test_client()

@blp.route("/", methods=("POST",))
@blp.arguments(schemas.DictSchema)
def func(document):
return {"document": document}

api.register_blueprint(blp)
spec = api.spec.to_dict()

# Check parameters are documented
if openapi_version == "2.0":
parameters = spec["paths"]["/test/"]["post"]["parameters"]
assert len(parameters) == 1
assert parameters[0]["in"] == "body"
assert "schema" in parameters[0]
else:
assert (
"schema"
in spec["paths"]["/test/"]["post"]["requestBody"]["content"][
"application/json"
]
)

# Check parameters are passed as arguments to view function
item_data = {"field": 12}
response = client.post(
"/test/",
data=json.dumps(item_data),
content_type="application/json",
)
assert response.status_code == 200
assert response.json == {
"document": {"db_field": 12},
}

@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
def test_blueprint_arguments_files_multipart(self, app, schemas, openapi_version):
app.config["OPENAPI_VERSION"] = openapi_version
Expand Down

0 comments on commit 0deedfd

Please sign in to comment.