Skip to content

Commit

Permalink
Detect cases of n+1 caused by Django deferred fields
Browse files Browse the repository at this point in the history
  • Loading branch information
shaib committed Jun 22, 2021
1 parent 7045b8a commit 3a14cc6
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nplusone/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.0.0'
__version__ = '1.1.0a1'
53 changes: 53 additions & 0 deletions nplusone/ext/django/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
create_reverse_many_to_one_manager,
create_forward_many_to_many_manager,
)
from django.db.models.query_utils import DeferredAttribute

NPLUSONE_WRAPPED = 'nplusone_wrapped'


def get_worker():
Expand Down Expand Up @@ -345,3 +348,53 @@ def getitem_queryset(self, index):
)
return original_getitem_queryset(self, index)
query.QuerySet.__getitem__ = getitem_queryset


def parse_refresh_from_db(instance, fields, args, kwargs, context):
# Instance & fields passed via partial
model = type(instance)
return model, to_key(instance), fields[0]


original_deferred_attribute_get = DeferredAttribute.__get__
def deferred_attribute_get(self, instance, cls=None):
"""
DeferredAttribute.__get__() is called when a deferred
field is accessed. It may or may not trigger a db query;
if it does, it's going to be a refresh_from_db() call
So we'll emit a `touch` from there
"""
if instance is None:
return self
# Refresh-from-db, intenally, calls QuerySet.get() on our
# instance. Normally, this would make our instance immune
# to further notifications. We don't want that to happen,
# so we disable the ignore_load signal within refresh_from_db
ensure_wrapped_refresh_from_db(instance)
return original_deferred_attribute_get(self, instance, cls)
DeferredAttribute.__get__ = deferred_attribute_get


def ensure_wrapped_refresh_from_db(instance):
orig_refresh_from_db = instance.refresh_from_db
if getattr(orig_refresh_from_db, NPLUSONE_WRAPPED, False):
return
@functools.wraps(orig_refresh_from_db)
def refresh_from_db(fields=None, *args, **kwargs):
with signals.ignore(signals.ignore_load):
ret = orig_refresh_from_db(fields=fields, **kwargs)
# and now, if the refresh_from_db was called for specific fields,
# then it's a lazy load
if fields:
parser = functools.partial(parse_refresh_from_db, instance, fields)
signals.lazy_load.send(
get_worker(),
args=args,
kwargs=kwargs,
ret=ret,
context={},
parser=parser,
)
return ret
setattr(refresh_from_db, NPLUSONE_WRAPPED, True)
instance.refresh_from_db = refresh_from_db
5 changes: 5 additions & 0 deletions tests/testapp/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ class Address(models.Model):

class Hobby(models.Model):
pass


class Medicine(models.Model):
name = models.CharField(max_length=20)
prescription = models.BooleanField(default=False)
36 changes: 34 additions & 2 deletions tests/testapp/testapp/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.conf import settings
from django.http.request import HttpRequest
from django.http.response import HttpResponse
from django.test import override_settings

from nplusone.ext.django.patch import setup_state
from nplusone.ext.django.middleware import NPlusOneMiddleware
Expand All @@ -32,6 +33,7 @@ def objects():
address = models.Address.objects.create(user=user)
hobby = models.Hobby.objects.create()
user.hobbies.add(hobby)
medicine = models.Medicine.objects.create(name="Allergix")
return locals()


Expand Down Expand Up @@ -133,6 +135,23 @@ def test_many_to_many_reverse_prefetch(self, objects, calls):
assert len(calls) == 0


@pytest.mark.django_db
class TestDeferred:

def test_deferred(self, objects, calls):
medicine = list(models.Medicine.objects.defer('name'))[0]
medicine.name
assert len(calls) == 1
call = calls[0]
assert call.objects == (models.Medicine, 'Medicine:1', 'name')
assert 'medicine.name' in ''.join(call.frame[4])

def test_non_deferred(self, objects, calls):
medicine = list(models.Medicine.objects.all())[0]
medicine.name
assert len(calls) == 0


@pytest.fixture
def logger(monkeypatch):
mock_logger = mock.Mock()
Expand Down Expand Up @@ -272,16 +291,29 @@ def test_select_nested_unused(self, objects, client, logger):
assert any('Pet.user' in call[1] for call in calls)
assert any('User.occupation' in call[1] for call in calls)

@override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.User'}])
def test_many_to_many_whitelist(self, objects, client, logger):
settings.NPLUSONE_WHITELIST = [{'model': 'testapp.User'}]
client.get('/many_to_many/')
assert not logger.log.called

@override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.*'}])
def test_many_to_many_whitelist_wildcard(self, objects, client, logger):
settings.NPLUSONE_WHITELIST = [{'model': 'testapp.*'}]
client.get('/many_to_many/')
assert not logger.log.called

def test_deferred(self, objects, client, logger):
client.get('/deferred/')
assert len(logger.log.call_args_list) == 1
args = logger.log.call_args[0]
assert 'Medicine.name' in args[1]

def test_double_deferred(self, objects, client, logger):
client.get('/double_deferred/')
assert len(logger.log.call_args_list) == 2
messages = sorted({args[0][1] for args in logger.log.call_args_list})
assert 'Medicine.name' in messages[0]
assert 'Medicine.prescription' in messages[1]


@pytest.mark.django_db
def test_values(objects, lazy_listener):
Expand Down
2 changes: 2 additions & 0 deletions tests/testapp/testapp/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@
url(r'^prefetch_nested_unused/$', views.prefetch_nested_unused),
url(r'^select_nested/$', views.select_nested),
url(r'^select_nested_unused/$', views.select_nested_unused),
url(r'^deferred/$', views.deferred),
url(r'^double_deferred/$', views.double_deferred),
]
11 changes: 11 additions & 0 deletions tests/testapp/testapp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,14 @@ def select_nested(request):
def select_nested_unused(request):
pets = list(models.Pet.objects.all().select_related('user__occupation'))
return HttpResponse(pets[0])


def deferred(request):
meds = list(models.Medicine.objects.defer('name'))
return HttpResponse("; ".join(med.name for med in meds))

def double_deferred(request):
meds = list(models.Medicine.objects.only('id'))
return HttpResponse("; ".join(
med.name + (' *' if med.prescription else '') for med in meds
))

0 comments on commit 3a14cc6

Please sign in to comment.