Skip to content

Commit

Permalink
Improve Pydantic model detection robustness
Browse files Browse the repository at this point in the history
It was previously too easy to hit some false positives
  • Loading branch information
Viicos committed May 6, 2024
1 parent bf0572d commit c8ff814
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 14 deletions.
78 changes: 67 additions & 11 deletions src/flake8_pydantic/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_decorator_names(decorator_list: list[ast.expr]) -> set[str]:
return names


def _has_pydantic_model_base(node: ast.ClassDef, include_root_model: bool) -> bool:
def _has_pydantic_model_base(node: ast.ClassDef, *, include_root_model: bool) -> bool:
model_class_names = {"BaseModel"}
if include_root_model:
model_class_names.add("RootModel")
Expand All @@ -42,15 +42,55 @@ def _has_model_config(node: ast.ClassDef) -> bool:
return False


PYDANTIC_FIELD_ARGUMENTS = {
"default",
"default_factory",
"alias",
"alias_priority",
"validation_alias",
"title",
"description",
"examples",
"exclude",
"discriminator",
"json_schema_extra",
"frozen",
"validate_default",
"repr",
"init",
"init_var",
"kw_only",
"pattern",
"strict",
"gt",
"ge",
"lt",
"le",
"multiple_of",
"allow_inf_nan",
"max_digits",
"decimal_places",
"min_length",
"max_length",
"union_mode",
}


def _has_field_function(node: ast.ClassDef) -> bool:
for stmt in node.body:
if isinstance(stmt, (ast.Assign, ast.AnnAssign)) and isinstance(stmt.value, ast.Call):
if isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == "Field":
# f = Field(...)
return True
if isinstance(stmt.value.func, ast.Attribute) and stmt.value.func.attr == "Field":
# f = pydantic.Field(...)
return True
if (
isinstance(stmt, (ast.Assign, ast.AnnAssign))
and isinstance(stmt.value, ast.Call)
and (
(isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == "Field") # f = Field(...)
or (
isinstance(stmt.value.func, ast.Attribute) and stmt.value.func.attr == "Field"
) # f = pydantic.Field(...)
)
and all(kw.arg in PYDANTIC_FIELD_ARGUMENTS for kw in stmt.value.keywords if kw.arg is not None)
):
return True

return False


Expand Down Expand Up @@ -84,14 +124,30 @@ def _has_pydantic_decorator(node: ast.ClassDef) -> bool:
return False


PYDANTIC_METHODS = {
"model_construct",
"model_copy",
"model_dump",
"model_dump_json",
"model_json_schema",
"model_parametrized_name",
"model_rebuild",
"model_validate",
"model_validate_json",
"model_validate_strings",
}


def _has_pydantic_method(node: ast.ClassDef) -> bool:
for stmt in node.body:
if isinstance(stmt, ast.FunctionDef) and stmt.name.startswith(("model_", "__pydantic_")):
if isinstance(stmt, ast.FunctionDef) and (
stmt.name.startswith(("__pydantic_", "__get_pydantic_")) or stmt.name in PYDANTIC_METHODS
):
return True
return False


def is_pydantic_model(node: ast.ClassDef, include_root_model: bool = True) -> bool:
def is_pydantic_model(node: ast.ClassDef, *, include_root_model: bool = True) -> bool:
"""Determine if a class definition is a Pydantic model.
Multiple heuristics are use to determine if this is the case:
Expand All @@ -106,7 +162,7 @@ def is_pydantic_model(node: ast.ClassDef, include_root_model: bool = True) -> bo
return False

return (
_has_pydantic_model_base(node, include_root_model)
_has_pydantic_model_base(node, include_root_model=include_root_model)
or _has_model_config(node)
or _has_field_function(node)
or _has_annotated_field(node)
Expand Down
36 changes: 33 additions & 3 deletions tests/test_is_pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,34 @@ class SubModel(ParentModel):

HAS_FIELD_FUNCTION_1 = """
class SubModel(ParentModel):
a = Field()
a = Field(title="A")
"""

HAS_FIELD_FUNCTION_2 = """
class SubModel(ParentModel):
a: int = Field()
a: int = Field(gt=1)
"""

HAS_FIELD_FUNCTION_3 = """
class SubModel(ParentModel):
a = pydantic.Field()
a = pydantic.Field(alias="b")
"""

HAS_FIELD_FUNCTION_4 = """
class SubModel(ParentModel):
a: int = pydantic.Field(repr=True)
"""

HAS_FIELD_FUNCTION_5 = """
class SubModel(ParentModel):
a: int = pydantic.Field()
"""

HAS_FIELD_FUNCTION_6 = """
class SubModel(ParentModel):
a: int = pydantic.Field(1)
"""

USES_ANNOTATED_1 = """
class SubModel(ParentModel):
a: Annotated[int, ""]
Expand Down Expand Up @@ -86,12 +96,27 @@ class SubModel(ParentModel):
def __pydantic_some_method__(self): pass
"""

HAS_PYDANTIC_METHOD_3 = """
class SubModel(ParentModel):
def __get_pydantic_core_schema__(self): pass
"""

# Negative cases:
NO_BASES = """
class Model:
a = Field()
"""

UNRELATED_FIELD_ARG = """
class SubModel(ParentModel):
a: int = Field(some_arg=1)
"""

UNRELATED_MODEL_METHOD = """
class SubModel(ParentModel):
def model_unrelated(): pass
"""


@pytest.mark.parametrize(
["source", "expected"],
Expand All @@ -105,13 +130,18 @@ class Model:
(HAS_FIELD_FUNCTION_2, True),
(HAS_FIELD_FUNCTION_3, True),
(HAS_FIELD_FUNCTION_4, True),
(HAS_FIELD_FUNCTION_5, True),
(HAS_FIELD_FUNCTION_6, True),
(USES_ANNOTATED_1, True),
(USES_ANNOTATED_2, True),
(HAS_PYDANTIC_DECORATOR_1, True),
(HAS_PYDANTIC_DECORATOR_2, True),
(HAS_PYDANTIC_METHOD_1, True),
(HAS_PYDANTIC_METHOD_2, True),
(HAS_PYDANTIC_METHOD_3, True),
(NO_BASES, False),
(UNRELATED_FIELD_ARG, False),
(UNRELATED_MODEL_METHOD, False),
],
)
def test_is_pydantic_model(source: str, expected: bool) -> None:
Expand Down

0 comments on commit c8ff814

Please sign in to comment.