Skip to content

Commit

Permalink
Merge pull request #9 from codemation/foreign-model-arrays
Browse files Browse the repository at this point in the history
Add Feature - Foreign model arrays
  • Loading branch information
codemation authored Nov 17, 2021
2 parents f9e5ffb + 2b53c6d commit eab0c57
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 34 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,34 @@ Adding cache with Redis is easy with `pydbantic`, and is complete with built in
cache_enabled=True,
redis_url="redis://localhost"
)
```

## Models with arrays of Foreign Objects

`DataBaseModel` models can support arrays of both `BaseModels` and other `DataBaseModel`. Just like single `DataBaseModel` references, data is stored in separate tables, and populated automatically when the child `DataBaseModel` is instantiated.

```python
from uuid import uuid4
from datetime import datetime
from typing import List, Optional
from pydbantic import DataBaseModel, PrimaryKey


def time_now():
return datetime.now().isoformat()
def get_uuid4():
return str(uuid4())

class Coordinate(DataBaseModel):
time: str = PrimaryKey(default=time_now)
latitude: float
longitude: float

class Journey(DataBaseModel):
trip_id: str = PrimaryKey(default=get_uuid4)
waypoints: List[Optional[Coordinate]]




```
40 changes: 34 additions & 6 deletions docs/model-usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from pydbantic import DataBaseModel, PrimaryKey

class Coordinates(BaseModel):

class Department(DataBaseModel):
id: str = PrimaryKey()
name: str
name: str = PrimaryKey()
company: str
is_sensitive: bool = False
location: Optional[str]

class Positions(DataBaseModel):
id: str = PrimaryKey()
name: str
name: str = PrimaryKey()
department: Department

class EmployeeInfo(DataBaseModel):
Expand Down Expand Up @@ -145,4 +145,32 @@ Much like updates, `DataBaseModel` objects can only be deleted by directly calli
```

!!! WARNING
Deleted objects which are depended on by other `DataBaseModel` are <u>NOT</u> deleted, as no strict table relationships exist between `DataBaseModel`. This may be changed later.
Deleted objects which are depended on by other `DataBaseModel` are <u>NOT</u> deleted, as no strict table relationships exist between `DataBaseModel`. This may be changed later.


### Models with arrays of Foreign Objects

`DataBaseModel` models can support arrays of both `BaseModels` and other `DataBaseModel`. Just like single `DataBaseModel` references, data is stored in separate tables, and populated automatically when the child `DataBaseModel` is instantiated.

```python
from uuid import uuid4
from datetime import datetime
from typing import List, Optional
from pydbantic import DataBaseModel, PrimaryKey


def time_now():
return datetime.now().isoformat()
def get_uuid4():
return str(uuid4())

class Coordinate(DataBaseModel):
time: str = PrimaryKey(default=time_now)
latitude: float
longitude: float

class Journey(DataBaseModel):
trip_id: str = PrimaryKey(default=get_uuid4)
waypoints: List[Optional[Coordinate]]

```
97 changes: 71 additions & 26 deletions pydbantic/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from pydantic import BaseModel, Field
import typing
from typing import Optional, Union, List
from pydantic.typing import is_callable_type
import sqlalchemy
from sqlalchemy import select
from pickle import dumps, loads
Expand Down Expand Up @@ -53,6 +53,23 @@ def Default(default=...):
class DataBaseModel(BaseModel):
__metadata__: BaseMeta = BaseMeta()

@classmethod
def check_if_subtype(cls, field):

database_model = None
if isinstance(field['type'], typing._GenericAlias):
breakpoint()
for sub in field['type'].__args__:
if issubclass(sub, DataBaseModel):
if database_model:
raise Exception(f"Cannot Specify two DataBaseModels in Union[] for {field['name']}")
database_model = sub
elif issubclass(field['type'], DataBaseModel):
return field['type']
return database_model



@classmethod
async def refresh_models(cls):
"""
Expand Down Expand Up @@ -117,11 +134,16 @@ def convert_fields_to_columns(
include = [f for f in cls.__fields__]

primary_key = None
array_fields = set()

for property, config in cls.schema()['properties'].items():

if 'primary_key' in config:
if primary_key:
raise Exception(f"Duplicate Primary Key Specified for {cls.__name__}")
primary_key = property
if 'type' in config and config['type'] == 'array':
array_fields.add(property)

if not model_fields:
model_fields_list = [
Expand All @@ -147,23 +169,24 @@ def convert_fields_to_columns(

columns = []
for i, field in enumerate(model_fields):
if issubclass(field['type'], DataBaseModel):
data_base_model = cls.check_if_subtype(field)
if data_base_model:
# ensure DataBaseModel also exists in Database, even if not already
# explicity added

cls.__metadata__.database.add_table(field['type'])
cls.__metadata__.database.add_table(data_base_model)

# create a string or foreign table column to be used to reference
# other table
foreign_table_name = field['type'].__name__
foreign_primary_key_name = field['type'].__metadata__.tables[foreign_table_name]['primary_key']
foreign_key_type = field['type'].__metadata__.tables[foreign_table_name]['column_map'][foreign_primary_key_name][1]
foreign_table_name = data_base_model.__name__
foreign_primary_key_name = data_base_model.__metadata__.tables[foreign_table_name]['primary_key']
foreign_key_type = data_base_model.__metadata__.tables[foreign_table_name]['column_map'][foreign_primary_key_name][1]

serialize = field['name'] in array_fields

cls.__metadata__.tables[name]['column_map'][field['name']] = (
cls.__metadata__.database.get_translated_column_type(foreign_key_type)[0],
field['type'],
False
data_base_model,
serialize
)

# store field name in map to quickly determine attribute is tied to
Expand Down Expand Up @@ -227,22 +250,33 @@ async def serialize(self, data: dict, insert: bool = False, alias=None):
values = {**data}

for k, v in data.items():

name = self.__class__.__name__
serialize = self.__metadata__.tables[name]['column_map'][k][2]

if k in self.__metadata__.tables[name]['foreign_keys']:

# use the foreign DataBaseModel's primary key / value
foreign_type = self.__metadata__.tables[name]['column_map'][k][1]
foreign_primary_key = foreign_type.__metadata__.tables[foreign_type.__name__]['primary_key']
foreign_model = foreign_type(**v)
foreign_primary_key_value = getattr(foreign_model, foreign_primary_key)
values[f'fk_{foreign_type.__name__}_{foreign_primary_key}'.lower()] = foreign_primary_key_value

foreign_values = [v] if not isinstance(v, list) else v
fk_values = []

for v in foreign_values:
foreign_model = foreign_type(**v)
foreign_primary_key_value = getattr(foreign_model, foreign_primary_key)

fk_values.append(foreign_primary_key_value)

if insert:
exists = await foreign_type.exists(**{foreign_primary_key: foreign_primary_key_value})
if not exists:
await foreign_model.insert()
del values[k]
if insert:
exists = await foreign_type.exists(**{foreign_primary_key: foreign_primary_key_value})

if not exists:

await foreign_model.insert()
values[f'fk_{foreign_type.__name__}_{foreign_primary_key}'.lower()] = fk_values[0] if not serialize else dumps(fk_values)

continue

serialize = self.__metadata__.tables[name]['column_map'][k][2]
Expand Down Expand Up @@ -363,17 +397,28 @@ async def select(cls,
for result in cls.normalize(results):
values = {}
for sel, value in zip(selection, result):
serialized = cls.__metadata__.tables[cls.__name__]['column_map'][sel][2]

if sel in cls.__metadata__.tables[cls.__name__]['foreign_keys']:

foreign_type = cls.__metadata__.tables[cls.__name__]['column_map'][sel][1]
foreign_primary_key = foreign_type.__metadata__.tables[foreign_type.__name__]['primary_key']
values[sel] = await foreign_type.select(
'*',
where={foreign_primary_key: result[value]},
)

foreign_primary_key_values = loads(result[value]) if serialized else [result[value]]
values[sel] = []
for foreign_primary_key_value in foreign_primary_key_values:
fk_query_results = await foreign_type.select(
'*',
where={foreign_primary_key: foreign_primary_key_value},
)
values[sel].extend(fk_query_results)
if serialized:
values[sel] = values[sel]
continue

values[sel] = values[sel][0] if values[sel] else None
continue

serialized = cls.__metadata__.tables[cls.__name__]['column_map'][sel][2]
if serialized:
try:
values[sel] = loads(result[value])
Expand Down Expand Up @@ -461,18 +506,18 @@ async def update(self,
if not where_:
where_ = {primary_key: getattr(self, primary_key)}

table = self.__metadata__.tables[self.__class__.__name__]['table']
table = self.__metadata__.tables[table_name]['table']
for column in to_update.copy():
if column in self.__metadata__.tables[table_name]['foreign_keys']:
del to_update[column] # = foreign_pk_value
continue
if column not in table.c:
raise Exception(f"{column} is not a valid column in {table}")

query, _ = self.where(table.update(), where_)
query = query.values(**to_update)

to_update = await self.serialize(to_update, insert=True)

to_update = await self.serialize(to_update)
query = query.values(**to_update)

await self.__metadata__.database.execute(query, to_update)

Expand Down
20 changes: 18 additions & 2 deletions tests/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Optional
from uuid import uuid4
from datetime import datetime
from typing import List, Optional, Union
from pydantic import BaseModel, Field
from pydbantic import DataBaseModel, PrimaryKey, Default

Expand Down Expand Up @@ -30,4 +32,18 @@ class Employee(DataBaseModel):
position: Positions
salary: float
is_employed: bool
date_employed: Optional[str]
date_employed: Optional[str]

def time_now():
return datetime.now().isoformat()
def get_uuid4():
return str(uuid4())

class Coordinate(DataBaseModel):
time: str = PrimaryKey(default=time_now)
latitude: float
longitude: float

class Journey(DataBaseModel):
trip_id: str = PrimaryKey(default=get_uuid4)
waypoints: List[Optional[Coordinate]]
57 changes: 57 additions & 0 deletions tests/test_model_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import time
import pytest
from pydbantic import Database
from tests.models import Journey, Coordinate

@pytest.mark.asyncio
async def test_database(db_url):
await Database.create(
db_url,
tables=[Journey],
cache_enabled=False,
testing=True
)

journey = await Journey.create(
waypoints=[Coordinate(latitude=1.0, longitude=1.0), Coordinate(latitude=1.0, longitude=1.0)]
)

all_coordinates = await Coordinate.all()

assert len(all_coordinates) == 2


all_journeys = await Journey.all()
assert len(all_journeys) ==1
assert len(all_journeys[0].waypoints) == 2

for coordinate in all_coordinates:
await coordinate.delete()

all_journeys = await Journey.all()
assert len(all_journeys[0].waypoints) == 0

all_journeys[0].waypoints=all_coordinates
await all_journeys[0].save()

all_journeys = await Journey.all()
assert len(all_journeys[0].waypoints) == 2

all_journeys[0].waypoints.pop(0)
await all_journeys[0].save()

all_journeys = await Journey.all()
assert len(all_journeys[0].waypoints) == 1

journey = await Journey.create(
waypoints=all_journeys[0].waypoints
)

all_journeys = await Journey.all()
assert len(all_journeys) == 2
assert all_journeys[0].waypoints == all_journeys[1].waypoints

await all_journeys[1].waypoints[0].delete()

all_journeys = await Journey.all()
assert all_journeys[0].waypoints == all_journeys[1].waypoints

0 comments on commit eab0c57

Please sign in to comment.