diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 990bf46de..9b8fd9cb5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,7 +68,7 @@ jobs: run: pip install --upgrade "pydantic>=1.10.0,<2.0.0" - name: Install Pydantic v2 if: matrix.pydantic-version == 'pydantic-v2' - run: pip install --upgrade "pydantic>=2.0.2,<3.0.0" + run: pip install --upgrade "pydantic>=2.0.2,<2.7.0" - name: Lint # Do not run on Python 3.7 as mypy behaves differently if: matrix.python-version != '3.7' && matrix.pydantic-version == 'pydantic-v2' diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9e8330d69..8b7a20777 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -106,6 +106,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: sa_column = kwargs.pop("sa_column", Undefined) sa_column_args = kwargs.pop("sa_column_args", Undefined) sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) + sa_foreign_key_args = kwargs.pop("sa_foreign_key_args", Undefined) + sa_foreign_key_kwargs = kwargs.pop("sa_foreign_key_kwargs", Undefined) if sa_column is not Undefined: if sa_column_args is not Undefined: raise RuntimeError( @@ -153,6 +155,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs + self.sa_foreign_key_args = sa_foreign_key_args + self.sa_foreign_key_kwargs = sa_foreign_key_kwargs class RelationshipInfo(Representation): @@ -222,6 +226,8 @@ def Field( sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -303,6 +309,8 @@ def Field( sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} @@ -340,6 +348,8 @@ def Field( sa_column=sa_column, sa_column_args=sa_column_args, sa_column_kwargs=sa_column_kwargs, + sa_foreign_key_args=sa_foreign_key_args, + sa_foreign_key_kwargs=sa_foreign_key_kwargs, **current_schema_extra, ) post_init_field_info(field_info) @@ -638,7 +648,19 @@ def get_column_from_field(field: Any) -> Column: # type: ignore unique = False if foreign_key: assert isinstance(foreign_key, str) - args.append(ForeignKey(foreign_key)) + sa_foreign_key_args = getattr(field_info, "sa_foreign_key_args", Undefined) + fk_args = ( + [] + if sa_foreign_key_args is Undefined + else list(cast(Sequence[Any], sa_foreign_key_args)) + ) + sa_foreign_key_kwargs = getattr(field_info, "sa_foreign_key_kwargs", Undefined) + fk_kwargs = ( + {} + if sa_foreign_key_kwargs is Undefined + else cast(Dict[Any, Any], sa_foreign_key_kwargs) + ) + args.append(ForeignKey(foreign_key, *fk_args, **fk_kwargs)) kwargs = { "primary_key": primary_key, "nullable": nullable, diff --git a/tests/test_field_sa_fk_args_kwargs.py b/tests/test_field_sa_fk_args_kwargs.py new file mode 100644 index 000000000..2cbe26210 --- /dev/null +++ b/tests/test_field_sa_fk_args_kwargs.py @@ -0,0 +1,75 @@ +import contextlib +import re +from typing import Optional + +import pytest +import sqlalchemy.exc +from sqlalchemy import ForeignKey, create_engine +from sqlmodel import Field, SQLModel +from sqlmodel._compat import IS_PYDANTIC_V2 + + +def test_base_model_fk(clear_sqlmodel, caplog) -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Base(SQLModel): + owner_id: Optional[int] = Field( + default=None, sa_column_args=(ForeignKey("user.id", ondelete="SET NULL"),) + ) + + class Asset(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + # Fails in Pydantic v2, but not v1 + with pytest.raises( + sqlalchemy.exc.InvalidRequestError + ) if IS_PYDANTIC_V2 else contextlib.nullcontext() as e: + + class Document(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + if e: + assert "This ForeignKey already has a parent" in str(e.errisinstance) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + + fk_log = [ + message + for message in caplog.messages + if re.search( + r"FOREIGN KEY\s*\(owner_id\)\s*REFERENCES\s*user\s*\(id\)", message + ) + ][0] + assert "ON DELETE SET NULL" in fk_log + + +def test_base_model_fk_args(clear_sqlmodel, caplog) -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Base(SQLModel): + owner_id: Optional[int] = Field( + default=None, + foreign_key="user.id", + sa_foreign_key_kwargs={"ondelete": "SET NULL"}, + ) + + class Asset(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Document(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + + fk_log = [ + message + for message in caplog.messages + if re.search( + r"FOREIGN KEY\s*\(owner_id\)\s*REFERENCES\s*user\s*\(id\)", message + ) + ][0] + assert "ON DELETE SET NULL" in fk_log