From 8f30325e55cdef2453aff865ec656283dc7697a1 Mon Sep 17 00:00:00 2001 From: Evgeny Arshinov Date: Mon, 8 Apr 2024 13:59:06 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8Properly=20support=20inheritance=20of?= =?UTF-8?q?=20Relationship=20attributes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 7 ++- tests/test_relationship_inheritance.py | 62 ++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/test_relationship_inheritance.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9e8330d69d..330bdf374b 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -415,6 +415,8 @@ def __new__( **kwargs: Any, ) -> Any: relationships: Dict[str, RelationshipInfo] = {} + for base in bases: + relationships.update(getattr(base, "__sqlmodel_relationships__", {})) dict_for_pydantic = {} original_annotations = get_annotations(class_dict) pydantic_annotations = {} @@ -471,8 +473,9 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): - col = get_column_from_field(v) - setattr(new_cls, k, col) + if k not in relationships: + col = get_column_from_field(v) + setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. # This could be done by reading new_cls.model_config['table'] in FastAPI, but diff --git a/tests/test_relationship_inheritance.py b/tests/test_relationship_inheritance.py new file mode 100644 index 0000000000..804c4bd741 --- /dev/null +++ b/tests/test_relationship_inheritance.py @@ -0,0 +1,62 @@ +from typing import Optional + +from sqlalchemy.orm import declared_attr, relationship +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select + + +def test_relationship_inheritance() -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class CreatedUpdatedMixin(SQLModel): + # With Pydantic V2, it is also possible to define `created_by` like this: + # + # ```python + # @declared_attr + # def _created_by(cls): + # return relationship(User, foreign_keys=cls.created_by_id) + # + # created_by: Optional[User] = Relationship(sa_relationship=_created_by)) + # ``` + # + # The difference from Pydantic V1 is that Pydantic V2 plucks attributes with names starting with '_' (but not '__') + # from class attributes and stores them separately as instances of `pydantic.ModelPrivateAttr` somewhere in depths of + # Pydantic internals. Under Pydantic V1 this doesn't happen, so SQLAlchemy ends up having two class attributes + # (`_created_by` and `created_by`) corresponding to one database attribute, causing a conflict and unreliable behavior. + # The approach with a lambda always works because it doesn't produce the second class attribute and thus eliminates + # the possibility of a conflict entirely. + # + created_by_id: Optional[int] = Field(default=None, foreign_key="user.id") + created_by: Optional[User] = Relationship( + sa_relationship=declared_attr( + lambda cls: relationship(User, foreign_keys=cls.created_by_id) + ) + ) + + updated_by_id: Optional[int] = Field(default=None, foreign_key="user.id") + updated_by: Optional[User] = Relationship( + sa_relationship=declared_attr( + lambda cls: relationship(User, foreign_keys=cls.updated_by_id) + ) + ) + + class Asset(CreatedUpdatedMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + john = User(name="John") + jane = User(name="Jane") + asset = Asset(created_by=john, updated_by=jane) + + with Session(engine) as session: + session.add(asset) + session.commit() + + with Session(engine) as session: + asset = session.exec(select(Asset)).one() + assert asset.created_by.name == "John" + assert asset.updated_by.name == "Jane"