Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect n+1 queries caused by deferred fields in Django #41

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
language: python
sudo: false
python:
- "2.7"
- "3.6"
- "3.7"
- "3.8"
- "3.9"
env:
global:
- PYTHONPATH=tests/testapp
matrix:
- DJANGO_VERSION=">=1.8,<1.9"
- DJANGO_VERSION=">=1.9,<1.10"
- DJANGO_VERSION=">=1.10,<1.11"
- DJANGO_VERSION=">=1.11,<2.0"
- DJANGO_VERSION=">=2.0,<2.1"
- DJANGO_VERSION=">=2.1,<2.2"
matrix:
exclude:
- python: "2.7"
env: DJANGO_VERSION=">=2.0,<2.1"
- python: "2.7"
env: DJANGO_VERSION=">=2.1,<2.2"
- DJANGO_VERSION=">=2.2,<3.0"
- DJANGO_VERSION=">=3.1,<3.2"
- DJANGO_VERSION=">=3.2,<4.0"
install:
- travis_retry pip install codecov
- travis_retry pip install -U -r dev-requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion nplusone/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.0.0'
__version__ = '1.1.0a1'
7 changes: 7 additions & 0 deletions nplusone/core/notifiers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

import logging
import traceback

from nplusone.core import exceptions

Expand Down Expand Up @@ -40,6 +41,12 @@ def notify(self, message):
self.logger.log(self.level, message.message)


class TraceNotifier(LogNotifier):
def notify(self, message):
self.logger.log(self.level, "".join(traceback.format_stack()))
self.logger.log(self.level, message.message)


class ErrorNotifier(Notifier):

CONFIG_KEY = 'NPLUSONE_RAISE'
Expand Down
5 changes: 1 addition & 4 deletions nplusone/ext/django/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@

from django.conf import settings

try:
from django.utils.deprecation import MiddlewareMixin
except ImportError:
MiddlewareMixin = object
from django.utils.deprecation import MiddlewareMixin

from nplusone.core import listeners
from nplusone.core import notifiers
Expand Down
92 changes: 62 additions & 30 deletions nplusone/ext/django/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,20 @@
import importlib
import threading

import django
from django.db.models import query
from django.db.models import Model

from nplusone.core import signals

if django.VERSION >= (1, 9): # pragma: no cover
from django.db.models.fields.related_descriptors import (
ReverseOneToOneDescriptor,
ForwardManyToOneDescriptor,
create_reverse_many_to_one_manager,
create_forward_many_to_many_manager,
)
else: # pragma: no cover
from django.db.models.fields.related import (
SingleRelatedObjectDescriptor as ReverseOneToOneDescriptor,
ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor,
create_foreign_related_manager as create_reverse_many_to_one_manager,
create_many_related_manager as create_forward_many_to_many_manager,
)
from django.db.models.fields.related_descriptors import (
ReverseOneToOneDescriptor,
ForwardManyToOneDescriptor,
create_reverse_many_to_one_manager,
create_forward_many_to_many_manager,
)
from django.db.models.query_utils import DeferredAttribute

NPLUSONE_WRAPPED = 'nplusone_wrapped'


def get_worker():
Expand Down Expand Up @@ -88,16 +82,8 @@ def get_related_name(model):

def parse_field(field):
return (
(
field.related_model # Django >= 1.8
if hasattr(field, 'related_model')
else field.related_field.model # Django <= 1.8
),
(
field.remote_field.name # Django >= 1.8
if hasattr(field, 'remote_field')
else field.rel.related_name # Django <= 1.8
) or get_related_name(field.related_model),
field.related_model,
field.remote_field.name or get_related_name(field.related_model),
)


Expand Down Expand Up @@ -142,11 +128,7 @@ def parse_many_related_queryset(args, kwargs, context):
rel = context['rel']
manager = context['args'][0]
model = manager.instance.__class__
related_model = (
manager.target_field.related_model # Django >= 1.8
if hasattr(manager.target_field, 'related_model')
else manager.target_field.related_field.model # Django <= 1.8
)
related_model = manager.target_field.related_model
field = manager.prefetch_cache_name if rel.related_name else None
return (
model,
Expand Down Expand Up @@ -366,3 +348,53 @@ def getitem_queryset(self, index):
)
return original_getitem_queryset(self, index)
query.QuerySet.__getitem__ = getitem_queryset


def parse_refresh_from_db(instance, fields, args, kwargs, context):
# Instance & fields passed via partial
model = type(instance)
return model, to_key(instance), fields[0]


original_deferred_attribute_get = DeferredAttribute.__get__
def deferred_attribute_get(self, instance, cls=None):
"""
DeferredAttribute.__get__() is called when a deferred
field is accessed. It may or may not trigger a db query;
if it does, it's going to be a refresh_from_db() call
So we'll emit a `touch` from there
"""
if instance is None:
return self
# Refresh-from-db, intenally, calls QuerySet.get() on our
# instance. Normally, this would make our instance immune
# to further notifications. We don't want that to happen,
# so we disable the ignore_load signal within refresh_from_db
ensure_wrapped_refresh_from_db(instance)
return original_deferred_attribute_get(self, instance, cls)
DeferredAttribute.__get__ = deferred_attribute_get


def ensure_wrapped_refresh_from_db(instance):
orig_refresh_from_db = instance.refresh_from_db
if getattr(orig_refresh_from_db, NPLUSONE_WRAPPED, False):
return
@functools.wraps(orig_refresh_from_db)
def refresh_from_db(fields=None, *args, **kwargs):
with signals.ignore(signals.ignore_load):
ret = orig_refresh_from_db(fields=fields, **kwargs)
# and now, if the refresh_from_db was called for specific fields,
# then it's a lazy load
if fields:
parser = functools.partial(parse_refresh_from_db, instance, fields)
signals.lazy_load.send(
get_worker(),
args=args,
kwargs=kwargs,
ret=ret,
context={},
parser=parser,
)
return ret
setattr(refresh_from_db, NPLUSONE_WRAPPED, True)
instance.refresh_from_db = refresh_from_db
10 changes: 5 additions & 5 deletions nplusone/ext/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

def to_key(instance):
model = type(instance)
return ':'.join(
[model.__name__] +
[
return ':'.join(itertools.chain(
[model.__name__],
(
format(instance.__dict__.get(key.key)) # Avoid recursion on __get__
for key in get_primary_keys(model)
]
)
)
))


def get_primary_keys(model):
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ universal=1
# E265: block comment should start with #
# E301: expected 1 blank line, found 1
# E302: expected 2 blank lines, found 0
# W504: line break after binary operator
[flake8]
ignore = E127,E128,E265,E301,E302,E305,E306
ignore = E127,E128,E265,E301,E302,E305,E306,W504
max-line-length = 90
5 changes: 5 additions & 0 deletions tests/testapp/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ class Address(models.Model):

class Hobby(models.Model):
pass


class Medicine(models.Model):
name = models.CharField(max_length=20)
prescription = models.BooleanField(default=False)
36 changes: 34 additions & 2 deletions tests/testapp/testapp/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.conf import settings
from django.http.request import HttpRequest
from django.http.response import HttpResponse
from django.test import override_settings

from nplusone.ext.django.patch import setup_state
from nplusone.ext.django.middleware import NPlusOneMiddleware
Expand All @@ -32,6 +33,7 @@ def objects():
address = models.Address.objects.create(user=user)
hobby = models.Hobby.objects.create()
user.hobbies.add(hobby)
medicine = models.Medicine.objects.create(name="Allergix")
return locals()


Expand Down Expand Up @@ -133,6 +135,23 @@ def test_many_to_many_reverse_prefetch(self, objects, calls):
assert len(calls) == 0


@pytest.mark.django_db
class TestDeferred:

def test_deferred(self, objects, calls):
medicine = list(models.Medicine.objects.defer('name'))[0]
medicine.name
assert len(calls) == 1
call = calls[0]
assert call.objects == (models.Medicine, 'Medicine:1', 'name')
assert 'medicine.name' in ''.join(call.frame[4])

def test_non_deferred(self, objects, calls):
medicine = list(models.Medicine.objects.all())[0]
medicine.name
assert len(calls) == 0


@pytest.fixture
def logger(monkeypatch):
mock_logger = mock.Mock()
Expand Down Expand Up @@ -272,16 +291,29 @@ def test_select_nested_unused(self, objects, client, logger):
assert any('Pet.user' in call[1] for call in calls)
assert any('User.occupation' in call[1] for call in calls)

@override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.User'}])
def test_many_to_many_whitelist(self, objects, client, logger):
settings.NPLUSONE_WHITELIST = [{'model': 'testapp.User'}]
client.get('/many_to_many/')
assert not logger.log.called

@override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.*'}])
def test_many_to_many_whitelist_wildcard(self, objects, client, logger):
settings.NPLUSONE_WHITELIST = [{'model': 'testapp.*'}]
client.get('/many_to_many/')
assert not logger.log.called

def test_deferred(self, objects, client, logger):
client.get('/deferred/')
assert len(logger.log.call_args_list) == 1
args = logger.log.call_args[0]
assert 'Medicine.name' in args[1]

def test_double_deferred(self, objects, client, logger):
client.get('/double_deferred/')
assert len(logger.log.call_args_list) == 2
messages = sorted({args[0][1] for args in logger.log.call_args_list})
assert 'Medicine.name' in messages[0]
assert 'Medicine.prescription' in messages[1]


@pytest.mark.django_db
def test_values(objects, lazy_listener):
Expand Down
2 changes: 2 additions & 0 deletions tests/testapp/testapp/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@
url(r'^prefetch_nested_unused/$', views.prefetch_nested_unused),
url(r'^select_nested/$', views.select_nested),
url(r'^select_nested_unused/$', views.select_nested_unused),
url(r'^deferred/$', views.deferred),
url(r'^double_deferred/$', views.double_deferred),
]
11 changes: 11 additions & 0 deletions tests/testapp/testapp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,14 @@ def select_nested(request):
def select_nested_unused(request):
pets = list(models.Pet.objects.all().select_related('user__occupation'))
return HttpResponse(pets[0])


def deferred(request):
meds = list(models.Medicine.objects.defer('name'))
return HttpResponse("; ".join(med.name for med in meds))

def double_deferred(request):
meds = list(models.Medicine.objects.only('id'))
return HttpResponse("; ".join(
med.name + (' *' if med.prescription else '') for med in meds
))
11 changes: 4 additions & 7 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tox]
envlist={py27,py33,py34,py35,py36}-{django18,django19,django110,django111,django20,django21}
envlist={py37,py38,py39}-{django22,django31,django32}

[testenv]
deps=
Expand All @@ -19,12 +19,9 @@ deps=
peewee

; Django versions
django18: django>=1.8,<1.9
django19: django>=1.9,<1.10
django110: django>=1.10,<1.11
django111: django>=1.11,<2.0
django20: django>=2.0,<2.1
django21: django>=2.1,<2.2
django22: django>=2.2,<3.0
django31: django>=3.1,<3.2
django32: django>=3.2,<4.0
commands=
flake8 .
py.test