Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propose change for subclass inherited_fields override when include_fk=False #657

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions src/marshmallow_sqlalchemy/schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, cast

import sqlalchemy as sa
from marshmallow.fields import Field
from marshmallow.schema import Schema, SchemaMeta, SchemaOpts
from marshmallow.schema import Schema, SchemaMeta, SchemaOpts, _get_fields

from .convert import ModelConverter
from .exceptions import IncorrectSchemaTypeError
Expand Down Expand Up @@ -118,7 +119,7 @@ def get_declared_fields(
cls_fields,
# Filter out fields generated from foreign key columns
# if include_fk is set to False in the options
mcs._maybe_filter_foreign_keys(inherited_fields, opts=opts),
mcs._maybe_filter_foreign_keys(inherited_fields, opts=opts, klass=klass),
dict_cls,
)
fields.update(mcs.get_declared_sqla_fields(fields, converter, opts, dict_cls))
Expand Down Expand Up @@ -159,6 +160,7 @@ def _maybe_filter_foreign_keys(
fields: list[tuple[str, Field]],
*,
opts: SQLAlchemySchemaOpts,
klass: SchemaMeta,
) -> list[tuple[str, Field]]:
if opts.model is not None or opts.table is not None:
if not hasattr(opts, "include_fk") or opts.include_fk is True:
Expand All @@ -168,7 +170,31 @@ def _maybe_filter_foreign_keys(
for column in sa.inspect(opts.model or opts.table).columns # type: ignore[union-attr]
if column.foreign_keys
}
return [(name, field) for name, field in fields if name not in foreign_keys]

schema_overrides = [
base
for base in inspect.getmro(klass)
if issubclass(base, Schema)
and not issubclass(base, SQLAlchemyAutoSchema)
]

def is_overridden(field: str) -> bool:
return any(
field
in [
name
for name, _ in _get_fields(
getattr(base, "_declared_fields", base.__dict__)
)
]
for base in schema_overrides
)

return [
(name, field)
for name, field in fields
if name not in foreign_keys or is_overridden(name)
]
return fields


Expand Down
29 changes: 29 additions & 0 deletions tests/test_sqlalchemy_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,34 @@ class Meta(TeacherSchema.Meta):
assert "current_school_id" in schema2.fields


def test_auto_schema_with_model_allows_schema_extensions_to_override_include_fk_with_explicit_inherited_field(
models,
):
class OverrideSchema(Schema):
current_school_id = fields.Integer()

class TeacherSchema(SQLAlchemyAutoSchema, OverrideSchema):
class Meta:
model = models.Teacher

schema = TeacherSchema()
assert "current_school_id" in schema.fields


def test_auto_schema_with_table_allows_schema_extensions_to_override_include_fk_with_explicit_inherited_field(
models,
):
class OverrideSchema(Schema):
current_school_id = fields.Integer()

class TeacherSchema(SQLAlchemyAutoSchema, OverrideSchema):
class Meta:
table = models.Teacher.__table__

schema = TeacherSchema()
assert "current_school_id" in schema.fields


def test_auto_field_does_not_accept_arbitrary_kwargs(models):
if int(version("marshmallow")[0]) < 4:
from marshmallow.warnings import RemovedInMarshmallow4Warning
Expand All @@ -695,6 +723,7 @@ class Meta:
model = models.Course

name = auto_field(description="A course name")

else:
with pytest.raises(TypeError, match="unexpected keyword argument"):

Expand Down