Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor smartrelashionshipsmixin #47

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@ jobs:

strategy:
matrix:
python-version: [ '3.7', '3.9' ]
python-version: ['3.9']
sqlalchemy-version: [ '1.3', '1.4' ]
include:
- sqlalchemy-version: '1.3'
sqlalchemy-lt-version: '1.4'
flask-sqlalchemy-version: '2.0'
flask-sqlalchemy-lt-version: '3.0'
flask-version: '2'
flask-lt-version: '3'
- sqlalchemy-version: '1.4'
sqlalchemy-lt-version: '2.0'
flask-sqlalchemy-version: '3.0'
flask-sqlalchemy-lt-version: '4.0'
flask-version: '3'
flask-lt-version: '4'

name: Python ${{ matrix.python-version }} - SQLAlchemy ${{ matrix.sqlalchemy-version }}

Expand All @@ -43,7 +47,8 @@ jobs:
python -m pip install --upgrade pip
python -m pip install -e .[tests] pytest-cov \
"sqlalchemy>=${{ matrix.sqlalchemy-version }},<${{ matrix.sqlalchemy-lt-version }}" \
"flask-sqlalchemy>=${{ matrix.flask-sqlalchemy-version }},<${{ matrix.flask-sqlalchemy-lt-version }}"
"flask-sqlalchemy>=${{ matrix.flask-sqlalchemy-version }},<${{ matrix.flask-sqlalchemy-lt-version }}" \
"flask>=${{ matrix.flask-version }},<${{ matrix.flask-lt-version }}"

- name: Test with pytest
run: |
Expand Down
10 changes: 9 additions & 1 deletion src/utils_flask_sqla/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from marshmallow.fields import Nested
from marshmallow_sqlalchemy.fields import RelatedList, Related

# from flask_marshmallow.fields import RelatedList


class SmartRelationshipsMixin:
Expand All @@ -19,7 +22,11 @@ def __init__(self, *args, **kwargs):
# excluded fields at meta level are not even generated by auto-schema
if field is None:
continue
if isinstance(field, Nested):
if (
isinstance(field, Nested)
or isinstance(field, RelatedList)
or isinstance(field, Related)
):
nested_fields.add(name)
elif field.metadata.get("exclude", False):
excluded_fields.add(name)
Expand All @@ -40,6 +47,7 @@ def __init__(self, *args, **kwargs):
exclude = kwargs.pop("exclude", None)
exclude = set(exclude) if exclude is not None else set()
exclude |= (excluded_fields | nested_fields) - firstlevel_only

# If only contains only nested & additional fields, we need to add included_fields to serialize nested, additional & included fields.
# If only does not contains nested or additional fields, we do nothing and marshmallow will serialize only specified fields.
if only and not firstlevel_only - nested_fields - additional_fields:
Expand Down
60 changes: 46 additions & 14 deletions src/utils_flask_sqla/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,35 @@ class Parent(db.Model):
col = Column(String)


cor_hobby_child = db.Table(
"cor_hobby_child",
db.Column("id_child", db.Integer, ForeignKey("child.pk")),
db.Column("id_hobby", db.Integer, ForeignKey("hobby.pk")),
)


class Hobby(db.Model):
__tablename__ = "hobby"
pk = Column(Integer, primary_key=True)
name = Column(Integer)


class Address(db.Model):
__tablename__ = "address"
pk = Column(Integer, primary_key=True)
street = Column(Integer)
city = Column(Integer)


class Child(db.Model):
__tablename__ = "child"
pk = Column(Integer, primary_key=True)
col = Column(String)
parent_pk = Column(Integer, ForeignKey(Parent.pk))
address_pk = Column(Integer, ForeignKey(Address.pk))
parent = relationship("Parent", backref="childs")
hobbies = relationship(Hobby, secondary=cor_hobby_child)
address = relationship(Address)


class ParentSchema(SmartRelationshipsMixin, SQLAlchemyAutoSchema):
Expand All @@ -34,18 +58,34 @@ class Meta:
childs = Nested("ChildSchema", many=True)


class HobbySchema(SQLAlchemyAutoSchema):
class Meta:
model = Hobby


class AdressSchema(SQLAlchemyAutoSchema):
class Meta:
model = Address


class ChildSchema(SmartRelationshipsMixin, SQLAlchemyAutoSchema):
class Meta:
model = Child
include_fk = True

parent = Nested(ParentSchema)
hobbies = (
auto_field()
) # For a n-n relationship a RelatedList field is created by marshmallow_sqalchemy
address = auto_field()


class TestSmartRelationshipsMixin:
def test_only(self):
parent = Parent(pk=1, col="p")
child = Child(pk=1, col="c", parent_pk=1, parent=parent)
child = Child(pk=1, col="c", parent_pk=1, address_pk=1, parent=parent)
child.hobbies = [Hobby(pk=1, name="Tennis"), Hobby(pk=2, name="petanque")]
child.address = Address(pk=1, street="5th avenue", city="New-York")
parent.childs = [child]

TestCase().assertDictEqual(
Expand All @@ -58,20 +98,16 @@ def test_only(self):

TestCase().assertDictEqual(
ChildSchema().dump(child),
{
"pk": 1,
"col": "c",
"parent_pk": 1,
},
{"pk": 1, "col": "c", "parent_pk": 1, "address_pk": 1},
)

TestCase().assertDictEqual(
ParentSchema(only=["childs"]).dump(parent),
ParentSchema(only=["childs", "childs.hobbies"]).dump(parent),
{
"pk": 1,
"col": "p",
"childs": [
{"pk": 1, "col": "c", "parent_pk": 1},
{"pk": 1, "col": "c", "parent_pk": 1, "address_pk": 1, "hobbies": [1, 2]},
],
},
)
Expand All @@ -97,6 +133,7 @@ def test_only(self):
"pk": 1,
"col": "c",
"parent_pk": 1,
"address_pk": 1,
"parent": {
"pk": 1,
"col": "p",
Expand Down Expand Up @@ -176,12 +213,7 @@ def test_null_relationship(self):

TestCase().assertDictEqual(
ChildSchema(only=("parent",)).dump(child),
{
"pk": 1,
"col": None,
"parent_pk": None,
"parent": None,
},
{"pk": 1, "col": None, "parent_pk": None, "parent": None, "address_pk": None},
)

def test_polymorphic_model(self):
Expand Down
Loading