-
-
Notifications
You must be signed in to change notification settings - Fork 427
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* PatchDict utility
- Loading branch information
Showing
5 changed files
with
149 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from typing import TYPE_CHECKING, Any, Dict, Optional, Type | ||
|
||
from pydantic_core import core_schema | ||
from typing_extensions import Annotated | ||
|
||
from ninja import Body | ||
from ninja.utils import is_optional_type | ||
|
||
|
||
class ModelToDict(dict): | ||
_wrapped_model: Any = None | ||
_wrapped_model_dump_params: Dict[str, Any] = {} | ||
|
||
@classmethod | ||
def __get_pydantic_core_schema__(cls, _source: Any, _handler: Any) -> Any: | ||
return core_schema.no_info_after_validator_function( | ||
cls._validate, | ||
cls._wrapped_model.__pydantic_core_schema__, | ||
) | ||
|
||
@classmethod | ||
def _validate(cls, input_value: Any) -> Any: | ||
return input_value.model_dump(**cls._wrapped_model_dump_params) | ||
|
||
|
||
def create_patch_schema(schema_cls: Type[Any]) -> Type[ModelToDict]: | ||
values, annotations = {}, {} | ||
for f in schema_cls.__fields__.keys(): | ||
t = schema_cls.__annotations__[f] | ||
if not is_optional_type(t): | ||
values[f] = getattr(schema_cls, f, None) | ||
annotations[f] = Optional[t] | ||
values["__annotations__"] = annotations | ||
OptionalSchema = type(f"{schema_cls.__name__}Patch", (schema_cls,), values) | ||
|
||
class OptionalDictSchema(ModelToDict): | ||
_wrapped_model = OptionalSchema | ||
_wrapped_model_dump_params = {"exclude_unset": True} | ||
|
||
return OptionalDictSchema | ||
|
||
|
||
class PatchDictUtil: | ||
def __getitem__(self, schema_cls: Any) -> Any: | ||
new_cls = create_patch_schema(schema_cls) | ||
return Body[new_cls] # type: ignore | ||
|
||
|
||
if TYPE_CHECKING: # pragma: nocover | ||
PatchDict = Annotated[dict, "<PatchDict>"] | ||
else: | ||
PatchDict = PatchDictUtil() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from typing import Optional | ||
|
||
import pytest | ||
|
||
from ninja import NinjaAPI, Schema | ||
from ninja.patch_dict import PatchDict | ||
from ninja.testing import TestClient | ||
|
||
api = NinjaAPI() | ||
|
||
client = TestClient(api) | ||
|
||
|
||
class SomeSchema(Schema): | ||
name: str | ||
age: int | ||
category: Optional[str] = None | ||
|
||
|
||
@api.patch("/patch") | ||
def patch(request, payload: PatchDict[SomeSchema]): | ||
return {"payload": payload, "type": str(type(payload))} | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input,output", | ||
[ | ||
({"name": "foo"}, {"name": "foo"}), | ||
({"age": "1"}, {"age": 1}), | ||
({}, {}), | ||
({"wrong_param": 1}, {}), | ||
({"age": None}, {"age": None}), | ||
], | ||
) | ||
def test_patch_calls(input: dict, output: dict): | ||
response = client.patch("/patch", json=input) | ||
assert response.json() == {"payload": output, "type": "<class 'dict'>"} | ||
|
||
|
||
def test_schema(): | ||
"Checking that json schema properties are all optional" | ||
schema = api.get_openapi_schema() | ||
assert schema["components"]["schemas"]["SomeSchemaPatch"] == { | ||
"title": "SomeSchemaPatch", | ||
"type": "object", | ||
"properties": { | ||
"name": { | ||
"anyOf": [{"type": "string"}, {"type": "null"}], | ||
"title": "Name", | ||
}, | ||
"age": { | ||
"anyOf": [{"type": "integer"}, {"type": "null"}], | ||
"title": "Age", | ||
}, | ||
"category": { | ||
"anyOf": [{"type": "string"}, {"type": "null"}], | ||
"title": "Category", | ||
}, | ||
}, | ||
} |