Skip to content

Commit

Permalink
Add f key on delete cascading (#64)
Browse files Browse the repository at this point in the history
* Added ondelete argument to ForeignKey to allow cascading deletion if foreign model matching key is removed

* Added SQLiteConnection connection factory to sqlite connections to support pragma f-key constraint handling

* added foreign key delete cascading to test models & tests

* Added Parent, Child tables to database_with_cache fixture

* Added pydantic versioning limit

* added pydantic max version limit

* Added ondelete CASCADE example with ForeignKey usage

---------

Co-authored-by: Joshua (codemation) <[email protected]>
  • Loading branch information
codemation and Joshua (codemation) authored Sep 9, 2023
1 parent 8cc1392 commit 739be0d
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/model-usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ class EmployeeInfo(DataBaseModel):
class Employee(DataBaseModel):
__tablename__ = "employee" # instead of Employee
employee_id: str = PrimaryKey()
emp_ssn: Optional[int] = ForeignKey(EmployeeInfo, 'ssn')
emp_ssn: Optional[int] = ForeignKey(EmployeeInfo, 'ssn', ondelete="CASCADE")
employee_info: Optional[EmployeeInfo] = Relationship("EmployeeInfo", 'employee_id', 'bio_id')
position: List[Optional[Positions]] = []
salary: float
Expand Down
21 changes: 18 additions & 3 deletions pydbantic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,20 @@ def PrimaryKey(


def ForeignKey(
foreign_model: Union[T, str], foreign_model_key: str, default=None
foreign_model: Union[T, str],
foreign_model_key: str,
default=None,
ondelete: Optional[str] = None,
) -> Any:
"""
ondelete:
"CASCADE"
"""
return get_field_config(
foreign_model=foreign_model,
foreign_model_key=foreign_model_key,
default=default,
foreign_model_ondelete=ondelete,
)


Expand Down Expand Up @@ -205,6 +213,7 @@ def get_field_config(
autoincrement: Optional[bool] = None,
foreign_model: Any = None,
foreign_model_key: Optional[str] = None,
foreign_model_ondelete: Optional[str] = None,
relationship_model: Optional[str] = None,
relationship_local_column: Optional[str] = None,
relationship_model_column: Optional[str] = None,
Expand All @@ -227,6 +236,8 @@ def get_field_config(
if foreign_model is not None:
config["foreign_model"] = foreign_model
config["foreign_model_key"] = foreign_model_key
if foreign_model_ondelete is not None:
config["ondelete"] = foreign_model_ondelete
if relationship_model is not None:
config["relationship_model"] = relationship_model
config["relationship_local_column"] = relationship_local_column
Expand Down Expand Up @@ -761,12 +772,16 @@ def convert_fields_to_columns(
if "foreign_model_key" not in config
else config["foreign_model_key"]
)

ondelete_config = {}
if "ondelete" in config:
ondelete_config["ondelete"] = config["ondelete"]
# foreign_model_sqlalchemy_type = cls.__metadata__.tables[foreign_model_name]['column_map'][foreign_model_key][0]
# sqlalchemy_type_config[field_property] = foreign_model_sqlalchemy_type

field_constraints[field_property].append(
sqlalchemy.ForeignKey(f"{foreign_table_name}.{foreign_model_key}")
sqlalchemy.ForeignKey(
f"{foreign_table_name}.{foreign_model_key}", **ondelete_config
)
)
if "relationship_model" in config:
relationship_definitions[config["relationship_model"]] = {
Expand Down
14 changes: 13 additions & 1 deletion pydbantic/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import sqlite3
import time
import uuid
from copy import deepcopy
Expand All @@ -11,13 +12,20 @@
from alembic.migration import MigrationContext
from alembic.operations import Operations
from databases import Database as _Database
from databases import backends
from sqlalchemy import create_engine

from pydbantic.cache import Redis
from pydbantic.core import BaseMeta, DatabaseInit, DataBaseModel, TableMeta
from pydbantic.translations import DEFAULT_TRANSLATIONS


class SQLiteConnection(sqlite3.Connection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.execute("pragma foreign_keys=1")


class Database:
def __init__(
self,
Expand Down Expand Up @@ -679,7 +687,11 @@ async def _migrate(self):
return self

async def db_connection(self):
async with _Database(self.DB_URL) as connection:
conn_factory = (
{"factory": SQLiteConnection} if "sqlite" in self.DB_URL.lower() else {}
)

async with _Database(self.DB_URL, **conn_factory) as connection:
while True:
status = yield connection
if status == "finished":
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
SQLAlchemy==1.4.28
databases==0.6.0
redis>=4.2.0
pydantic>=1.9.1
pydantic>=1.9.1,<2
asyncpg==0.24.0
aiosqlite==0.17.0
alembic==1.8.1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"SQLAlchemy==1.4.28",
"databases==0.5.3",
"redis>=4.2.0",
"pydantic>=1.9.1",
"pydantic>=1.9.1,pydantic<2",
"alembic==1.8.1",
]
MYSQL_REQUIREMENTS = [
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pydbantic import Database
from pydbantic.cache import Redis
from tests.models import Department, Employee, EmployeeInfo, Positions
from tests.models import Child, Department, Employee, EmployeeInfo, Parent, Positions

DB_PATH = {
"sqlite": "sqlite:///test.db",
Expand Down Expand Up @@ -52,7 +52,7 @@ async def database_with_cache(request):

db = await Database.create(
request.param,
tables=[EmployeeInfo, Employee, Positions, Department],
tables=[EmployeeInfo, Employee, Positions, Department, Parent, Child],
cache_enabled=False,
testing=True,
)
Expand Down
10 changes: 10 additions & 0 deletions tests/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ def uuid_str():
return str(uuid4())


class Parent(DataBaseModel):
name: str = PrimaryKey()
sex: str


class Child(DataBaseModel):
name: str = PrimaryKey()
parent: str = ForeignKey(Parent, foreign_model_key="name", ondelete="CASCADE")


class Department(DataBaseModel):
department_id: str = PrimaryKey()
name: str
Expand Down
13 changes: 12 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from pydbantic import Database
from tests.models import Employee
from tests.models import Child, Employee, Parent


@pytest.mark.asyncio
Expand Down Expand Up @@ -53,3 +53,14 @@ async def test_models(database_with_cache):
assert result[0].position[0].name == employee.position[0].name
assert result[0].salary == employee.salary
assert result[0].is_employed == employee.is_employed

parent = await Parent.create(name="bob", sex="MALE")
child = await Child.create(name="joe", parent=parent.name)

assert await Parent.all()
assert await Child.all()

await parent.delete()

assert not await Parent.all()
assert not await Child.all()

0 comments on commit 739be0d

Please sign in to comment.