Skip to content

Commit

Permalink
Support skipping validation and returning models (#212)
Browse files Browse the repository at this point in the history
* Support skipping validation and returning models

* Format

* Install deps in makefile

* Linter

* Starlette fixes

* Revert makefile

* Remove comment

* Fix on_post params

* Add model type validation

When returning a pydantic model directly (or for Starlette a
PydanticResponse) we will validate that the returned model type is one
of the models configured in the `Response` configuration.

For Flask we will also look up the status code for the provided model if it's not explicitly
stated.

* Clean up flask logic

* Fix flask tuple handling

* Lazily load JSONResponse

* Update linted files

* Fix flask validator
  • Loading branch information
danstewart authored Apr 24, 2022
1 parent 9147598 commit 744f0c9
Show file tree
Hide file tree
Showing 12 changed files with 439 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ lint:
isort --check --diff --project=spectree ${SOURCE_FILES}
black --check --diff ${SOURCE_FILES}
flake8 ${SOURCE_FILES} --count --show-source --statistics
mypy --install-types --non-interactive ${SOURCE_FILES}
mypy --install-types --non-interactive spectree

.PHONY: test doc
.PHONY: test doc
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,25 @@ This library provides `before` and `after` hooks to do these. Check the [doc](ht
You can change the `validation_error_status` in SpecTree (global) or a specific endpoint (local). This also takes effect in the OpenAPI documentation.

> How can I skip the validation?
Add `skip_validation=True` to the decorator.

```py
@api.validate(json=Profile, resp=Response(HTTP_200=Message, HTTP_403=None), skip_validation=True)
```

> How can I return my model directly?
Yes, returning an instance of `BaseModel` will assume the model is valid and bypass spectree's validation and automatically call `.dict()` on the model.

For starlette you should return a `PydanticResponse`:
```py
from spectree.plugins.starlette_plugin import PydanticResponse

return PydanticResponse(MyModel)
```

## Demo

Try it with `http post :8000/api/user name=alice age=18`. (if you are using `httpie`)
Expand Down
15 changes: 13 additions & 2 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def validate(
before,
after,
validation_error_status,
skip_validation,
*args,
**kwargs,
):
Expand All @@ -223,9 +224,14 @@ def validate(
return

func(*args, **kwargs)

if resp and resp.has_model():
if isinstance(_resp.media, resp.find_model(_resp.status[:3])):
_resp.media = _resp.media.dict()
skip_validation = True

model = resp.find_model(_resp.status[:3])
if model:
if model and not skip_validation:
try:
model.parse_obj(_resp.media)
except ValidationError as err:
Expand Down Expand Up @@ -275,6 +281,7 @@ async def validate(
before,
after,
validation_error_status,
skip_validation,
*args,
**kwargs,
):
Expand All @@ -300,8 +307,12 @@ async def validate(
await func(*args, **kwargs)

if resp and resp.has_model():
if resp and isinstance(_resp.media, resp.find_model(_resp.http_status[:3])):
_resp.media = _resp.media.dict()
skip_validation = True

model = resp.find_model(_resp.status[:3])
if model:
if model and not skip_validation:
try:
model.parse_obj(_resp.media)
except ValidationError as err:
Expand Down
24 changes: 21 additions & 3 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError

from ..utils import get_multidict_items
from .base import BasePlugin, Context
Expand Down Expand Up @@ -166,6 +166,7 @@ def validate(
before,
after,
validation_error_status,
skip_validation,
*args,
**kwargs,
):
Expand All @@ -187,11 +188,28 @@ def validate(
after(request, response, req_validation_error, None)
abort(response)

response = make_response(func(*args, **kwargs))
result = func(*args, **kwargs)

status = 200
rest = []
if resp and isinstance(result, tuple) and isinstance(result[0], BaseModel):
if len(result) > 1:
model, status, *rest = result
else:
model = result[0]
else:
model = result

if isinstance(model, resp.find_model(status)):
skip_validation = True
result = (model.dict(), status, *rest)

response = make_response(result)

if resp and resp.has_model():

model = resp.find_model(response.status_code)
if model:
if model and not skip_validation:
try:
model.parse_obj(response.get_json())
except ValidationError as err:
Expand Down
21 changes: 20 additions & 1 deletion spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@
Route = namedtuple("Route", ["path", "methods", "func"])


def PydanticResponse(content):
from starlette.responses import JSONResponse

class _PydanticResponse(JSONResponse):
def render(self, content) -> bytes:
self._model_class = content.__class__
return super().render(content.dict())

return _PydanticResponse(content)


class StarlettePlugin(BasePlugin):
ASYNC = True

Expand Down Expand Up @@ -60,6 +71,7 @@ async def validate(
before,
after,
validation_error_status,
skip_validation,
*args,
**kwargs,
):
Expand Down Expand Up @@ -102,8 +114,15 @@ async def validate(
response = func(*args, **kwargs)

if resp:
if (
isinstance(response, JSONResponse)
and hasattr(response, "_model_class")
and response._model_class == resp.find_model(response.status_code)
):
skip_validation = True

model = resp.find_model(response.status_code)
if model:
if model and not skip_validation:
try:
model.parse_raw(response.body)
except ValidationError as err:
Expand Down
3 changes: 3 additions & 0 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def validate(
after: Callable = None,
validation_error_status: int = 0,
path_parameter_descriptions: Mapping[str, str] = None,
skip_validation: bool = False,
) -> Callable:
"""
- validate query, json, headers in request
Expand Down Expand Up @@ -159,6 +160,7 @@ def sync_validate(*args: Any, **kwargs: Any):
before or self.before,
after or self.after,
validation_error_status,
skip_validation,
*args,
**kwargs,
)
Expand All @@ -176,6 +178,7 @@ async def async_validate(*args: Any, **kwargs: Any):
before or self.before,
after or self.after,
validation_error_status,
skip_validation,
*args,
**kwargs,
)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def test_plugin_spec(api):
"/api/user/{name}",
"/api/user/{name}/address/{address_id}",
"/api/user_annotated/{name}",
"/api/user_model/{name}",
"/api/user_skip/{name}",
"/ping",
]

Expand Down
81 changes: 80 additions & 1 deletion tests/test_plugin_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def on_get(self, req, resp, name):
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=[api_tag, "test"],
)
def on_post(self, req, resp, name):
def on_post(self, req, resp, name, query: Query, json: JSON, cookies: Cookies):
score = [randint(0, req.context.json.limit) for _ in range(5)]
score.sort(reverse=req.context.query.order)
assert req.context.cookies.pub == "abcdefg"
Expand Down Expand Up @@ -88,6 +88,59 @@ def on_post(self, req, resp, name, query: Query, json: JSON, cookies: Cookies):
resp.media = {"name": req.context.json.name, "score": score}


class UserScoreSkip:
name = "sorted random score"

def extra_method(self):
pass

@api.validate(resp=Response(HTTP_200=StrDict))
def on_get(self, req, resp, name):
self.extra_method()
resp.media = {"name": name}

@api.validate(
query=Query,
json=JSON,
cookies=Cookies,
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=[api_tag, "test"],
skip_validation=True,
)
def on_post(self, req, resp, name, query: Query, json: JSON, cookies: Cookies):
score = [randint(0, req.context.json.limit) for _ in range(5)]
score.sort(reverse=req.context.query.order)
assert req.context.cookies.pub == "abcdefg"
assert req.cookies["pub"] == "abcdefg"
resp.media = {"name": req.context.json.name, "x_score": score}


class UserScoreModel:
name = "sorted random score"

def extra_method(self):
pass

@api.validate(resp=Response(HTTP_200=StrDict))
def on_get(self, req, resp, name):
self.extra_method()
resp.media = {"name": name}

@api.validate(
query=Query,
json=JSON,
cookies=Cookies,
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=[api_tag, "test"],
)
def on_post(self, req, resp, name, query: Query, json: JSON, cookies: Cookies):
score = [randint(0, req.context.json.limit) for _ in range(5)]
score.sort(reverse=req.context.query.order)
assert req.context.cookies.pub == "abcdefg"
assert req.cookies["pub"] == "abcdefg"
resp.media = Resp(name=req.context.json.name, score=score)


class UserAddress:
name = "user's address"

Expand All @@ -107,6 +160,8 @@ def on_get(self, req, resp, name, address_id):
app.add_route("/api/user/{name}", UserScore())
app.add_route("/api/user_annotated/{name}", UserScoreAnnotated())
app.add_route("/api/user/{name}/address/{address_id}", UserAddress())
app.add_route("/api/user_skip/{name}", UserScoreSkip())
app.add_route("/api/user_model/{name}", UserScoreModel())
api.register(app)


Expand Down Expand Up @@ -160,6 +215,30 @@ def test_falcon_validate(client):
assert resp.headers.get("X-Name") == "sorted random score"


def test_falcon_skip_validation(client):
resp = client.simulate_request(
"POST",
"/api/user_skip/falcon?order=1",
json=dict(name="falcon", limit=10),
headers={"Cookie": "pub=abcdefg"},
)
assert resp.json["name"] == "falcon"
assert resp.json["x_score"] == sorted(resp.json["x_score"], reverse=True)
assert resp.headers.get("X-Name") == "sorted random score"


def test_falcon_return_model(client):
resp = client.simulate_request(
"POST",
"/api/user_model/falcon?order=1",
json=dict(name="falcon", limit=10),
headers={"Cookie": "pub=abcdefg"},
)
assert resp.json["name"] == "falcon"
assert resp.json["score"] == sorted(resp.json["score"], reverse=True)
assert resp.headers.get("X-Name") == "sorted random score"


@pytest.fixture
def test_client_and_api(request):
api_args = ["falcon"]
Expand Down
65 changes: 65 additions & 0 deletions tests/test_plugin_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,41 @@ def user_score_annotated(name, query: Query, json: JSON, cookies: Cookies):
return jsonify(name=json.name, score=score)


@app.route("/api/user_skip/<name>", methods=["POST"])
@api.validate(
query=Query,
json=JSON,
cookies=Cookies,
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=[api_tag, "test"],
after=api_after_handler,
skip_validation=True,
)
def user_score_skip_validation(name):
score = [randint(0, request.context.json.limit) for _ in range(5)]
score.sort(reverse=True if request.context.query.order == Order.desc else False)
assert request.context.cookies.pub == "abcdefg"
assert request.cookies["pub"] == "abcdefg"
return jsonify(name=request.context.json.name, x_score=score)


@app.route("/api/user_model/<name>", methods=["POST"])
@api.validate(
query=Query,
json=JSON,
cookies=Cookies,
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=[api_tag, "test"],
after=api_after_handler,
)
def user_score_model(name):
score = [randint(0, request.context.json.limit) for _ in range(5)]
score.sort(reverse=True if request.context.query.order == Order.desc else False)
assert request.context.cookies.pub == "abcdefg"
assert request.cookies["pub"] == "abcdefg"
return Resp(name=request.context.json.name, score=score), 200


@app.route("/api/user/<name>/address/<address_id>", methods=["GET"])
@api.validate(
query=Query,
Expand Down Expand Up @@ -157,6 +192,36 @@ def test_flask_validate(client):
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)


def test_flask_skip_validation(client):
client.set_cookie("flask", "pub", "abcdefg")

resp = client.post(
"/api/user_skip/flask?order=1",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.status_code == 200, resp.json
assert resp.headers.get("X-Validation") is None
assert resp.headers.get("X-API") == "OK"
assert resp.json["name"] == "flask"
assert resp.json["x_score"] == sorted(resp.json["x_score"], reverse=True)


def test_flask_return_model(client):
client.set_cookie("flask", "pub", "abcdefg")

resp = client.post(
"/api/user_model/flask?order=1",
data=json.dumps(dict(name="flask", limit=10)),
content_type="application/json",
)
assert resp.status_code == 200, resp.json
assert resp.headers.get("X-Validation") is None
assert resp.headers.get("X-API") == "OK"
assert resp.json["name"] == "flask"
assert resp.json["score"] == sorted(resp.json["score"], reverse=True)


@pytest.fixture
def test_client_and_api(request):
api_args = ["flask"]
Expand Down
Loading

0 comments on commit 744f0c9

Please sign in to comment.