Skip to content

Commit

Permalink
Handle serializing inherited models. #459
Browse files Browse the repository at this point in the history
  • Loading branch information
adamghill committed Nov 10, 2022
1 parent db90c03 commit 4c851b7
Show file tree
Hide file tree
Showing 11 changed files with 602 additions and 113 deletions.
96 changes: 83 additions & 13 deletions django_unicorn/serializer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from datetime import datetime, timedelta
from decimal import Decimal
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Tuple

from django.core.serializers import serialize
from django.core.serializers.json import DjangoJSONEncoder
from django.db.models import (
DateField,
DateTimeField,
Expand All @@ -18,6 +20,7 @@
parse_duration,
parse_time,
)
from django.utils.duration import duration_string

import orjson

Expand All @@ -32,6 +35,8 @@

logger = logging.getLogger(__name__)

django_json_encoder = DjangoJSONEncoder()


class JSONDecodeError(Exception):
    """Error type for JSON decoding failures."""
Expand Down Expand Up @@ -107,6 +112,69 @@ def _get_many_to_many_field_related_names_from_meta(meta):
return _get_many_to_many_field_related_names_from_meta(model._meta)


def _get_m2m_field_serialized(model: Model, field_name) -> List:
    """
    Return the primary keys of every object related to `model` through the
    many-to-many field named `field_name`.

    An empty list is returned when accessing the relation raises `ValueError`.
    """

    try:
        related_manager = getattr(model, field_name)

        # Iterate over `all()` so prefetched (cached) m-2-m data gets re-used;
        # `values_list("pk", flat=True)` or `only()` would not use the prefetched values
        return [related.pk for related in related_manager.all()]
    except ValueError:
        # ValueError is thrown when the model doesn't have an id already set
        return []


def _handle_inherited_models(model: Model, model_json: Dict):
    """
    Handle if the model has a parent (i.e. the model is a subclass of another model).

    Subclassed model's fields don't get serialized
    (https://docs.djangoproject.com/en/stable/topics/serialization/#inherited-models)
    so those fields need to be retrieved manually.

    Args:
        model: The model instance that is being serialized.
        model_json: The already serialized fields of `model`. It is updated in place
            with every non-primary-key field that is not present yet.
    """

    if not model._meta.get_parent_list():
        return

    for field in model._meta.get_fields():
        if field.name in model_json:
            continue

        # Only handle fields that expose `primary_key` and that are not the primary key itself
        if not hasattr(field, "primary_key") or field.primary_key:
            continue

        if field.is_relation:
            # m2m fields are serialized separately (see `_get_m2m_field_serialized`),
            # so only foreign keys need to be handled here
            if not field.many_to_many:
                related_object = getattr(model, field.name)
                model_json[field.name] = getattr(
                    related_object,
                    "pk",
                    getattr(related_object, "id", None),
                )

            continue

        value = getattr(model, field.name)

        # Explicitly handle `timedelta`, but use the DjangoJSONEncoder for everything else
        if isinstance(value, timedelta):
            value = duration_string(value)
        else:
            # Round-trip through JSON instead of stripping the surrounding quotes by hand:
            # numbers, booleans and `None` keep their native type and strings are
            # unescaped correctly (quotes, backslashes and non-ASCII characters), while
            # dates, decimals, UUIDs, etc. still end up as their string representation
            value = orjson.loads(django_json_encoder.encode(value))

        model_json[field.name] = value


def _get_model_dict(model: Model) -> dict:
"""
Serializes Django models. Uses the built-in Django JSON serializer, but moves the data around to
Expand All @@ -115,27 +183,29 @@ def _get_model_dict(model: Model) -> dict:

_parse_field_values_from_string(model)

# Django's `serialize` method always returns an array, so remove the brackets from the resulting string
# Django's `serialize` method always returns a string of an array,
# so remove the brackets from the resulting string
serialized_model = serialize("json", [model])[1:-1]

# Convert the string into a dictionary and grab the `pk`
model_json = orjson.loads(serialized_model)
model_pk = model_json.get("pk")

# Shuffle around the serialized pieces to condense the size of the payload
model_json = model_json.get("fields")
model_json["pk"] = model_pk

for related_name in _get_many_to_many_field_related_names(model):
pks = []
# Set `pk` for models that subclass another model which only have `id` set
if not model_pk:
model_json["pk"] = model.pk or model.id

try:
related_descriptor = getattr(model, related_name)
# Add in m2m fields
m2m_field_names = _get_many_to_many_field_related_names(model)

# Get `pk` from `all` because it will re-use the cached data if the m-2-m field is prefetched
# Using `values_list("pk", flat=True)` or `only()` won't use the cached prefetched values
pks = [m.pk for m in related_descriptor.all()]
except ValueError:
# ValueError is thrown when the model doesn't have an id already set
pass
for m2m_field_name in m2m_field_names:
model_json[m2m_field_name] = _get_m2m_field_serialized(model, m2m_field_name)

model_json[related_name] = pks
_handle_inherited_models(model, model_json)

return model_json

Expand Down
12 changes: 11 additions & 1 deletion django_unicorn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import pickle
from inspect import signature
from pprint import pp
from typing import Dict, List, Union
from typing import get_type_hints as typing_get_type_hints

Expand Down Expand Up @@ -53,14 +54,23 @@ def dicts_equal(dictionary_one: Dict, dictionary_two: Dict) -> bool:
Return True if all keys and values are the same between two dictionaries.
"""

return all(
is_valid = all(
k in dictionary_two and dictionary_one[k] == dictionary_two[k]
for k in dictionary_one
) and all(
k in dictionary_one and dictionary_one[k] == dictionary_two[k]
for k in dictionary_two
)

if not is_valid:
print("dictionary_one:")
pp(dictionary_one)
print()
print("dictionary_two:")
pp(dictionary_two)

return is_valid


def get_cacheable_component(
component: "django_unicorn.views.UnicornView",
Expand Down
23 changes: 23 additions & 0 deletions example/books/migrations/0003_auto_20221110_0400.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 3.2.15 on 2022-11-10 04:00

from django.db import migrations, models


class Migration(migrations.Migration):
    """Alter the `id` field of `author` and `book` to a `BigAutoField` primary key."""

    dependencies = [
        ("books", "0002_author"),
    ]

    operations = [
        migrations.AlterField(
            model_name="author",
            name="id",
            field=models.BigAutoField(
                auto_created=True,
                primary_key=True,
                serialize=False,
                verbose_name="ID",
            ),
        ),
        migrations.AlterField(
            model_name="book",
            name="id",
            field=models.BigAutoField(
                auto_created=True,
                primary_key=True,
                serialize=False,
                verbose_name="ID",
            ),
        ),
    ]
37 changes: 37 additions & 0 deletions example/coffee/migrations/0005_auto_20221110_0400.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Generated by Django 3.2.15 on 2022-11-10 04:00

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
    """
    Create the `NewFlavor` model as a child of `coffee.flavor` and alter the `id`
    field of `flavor`, `origin` and `taste` to a `BigAutoField` primary key.
    """

    dependencies = [
        ("coffee", "0004_origin_taste"),
    ]

    operations = [
        migrations.CreateModel(
            name="NewFlavor",
            fields=[
                (
                    "flavor_ptr",
                    models.OneToOneField(
                        auto_created=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        parent_link=True,
                        primary_key=True,
                        serialize=False,
                        to="coffee.flavor",
                    ),
                ),
                ("new_name", models.CharField(max_length=255)),
            ],
            bases=("coffee.flavor",),
        ),
        migrations.AlterField(
            model_name="flavor",
            name="id",
            field=models.BigAutoField(
                auto_created=True,
                primary_key=True,
                serialize=False,
                verbose_name="ID",
            ),
        ),
        migrations.AlterField(
            model_name="origin",
            name="id",
            field=models.BigAutoField(
                auto_created=True,
                primary_key=True,
                serialize=False,
                verbose_name="ID",
            ),
        ),
        migrations.AlterField(
            model_name="taste",
            name="id",
            field=models.BigAutoField(
                auto_created=True,
                primary_key=True,
                serialize=False,
                verbose_name="ID",
            ),
        ),
    ]
4 changes: 4 additions & 0 deletions example/coffee/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ class Taste(models.Model):
class Origin(models.Model):
name = models.CharField(max_length=255)
flavor = models.ManyToManyField(Flavor, related_name="origins")


class NewFlavor(Flavor):
    """`Flavor` subclass that adds a `new_name` field."""

    new_name = models.CharField(max_length=255)
Loading

0 comments on commit 4c851b7

Please sign in to comment.