diff --git a/model_utils/models.py b/model_utils/models.py index 268db8c5..9b974b26 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -92,6 +92,10 @@ def add_status_query_managers(sender, **kwargs): default_manager = sender._meta.default_manager + status_manager_class = QueryManager.from_queryset( + queryset_class=default_manager.get_queryset().__class__, + class_name='{}StatusManager'.format(sender.__name__) + ) for value, display in getattr(sender, 'STATUS', ()): if _field_exists(sender, value): raise ImproperlyConfigured( @@ -99,7 +103,7 @@ def add_status_query_managers(sender, **kwargs): "conflicts with a status of the same name." % (sender.__name__, value) ) - sender.add_to_class(value, QueryManager(status=value)) + sender.add_to_class(value, status_manager_class(status=value)) sender._meta.default_manager_name = default_manager.name diff --git a/tests/models.py b/tests/models.py index c789267e..e0a379e4 100644 --- a/tests/models.py +++ b/tests/models.py @@ -157,6 +157,23 @@ class StatusCustomManager(AbstractStatusCustomManager): title = models.CharField(max_length=50) +class StatusCustomQuerySet(models.QuerySet): + def custom_exists(self): + return self.exists() + + def custom_count(self): + return self.count() + + +class StatusCustomQuerySet(StatusModel): + STATUS = Choices( + ("first_choice", _("First choice")), + ("second_choice", _("Second choice")), + ) + + objects = StatusCustomQuerySet.as_manager() + + class Post(models.Model): published = models.BooleanField(default=False) confirmed = models.BooleanField(default=False) diff --git a/tests/test_models/test_status_model.py b/tests/test_models/test_status_model.py index 2e311438..4fbe328c 100644 --- a/tests/test_models/test_status_model.py +++ b/tests/test_models/test_status_model.py @@ -3,7 +3,12 @@ from django.test.testcases import TestCase from freezegun import freeze_time -from tests.models import Status, StatusCustomManager, StatusPlainTuple +from tests.models import ( + Status, + StatusCustomManager, + StatusCustomQuerySet, + StatusPlainTuple, +) class StatusModelTests(TestCase): @@ -97,3 +102,22 @@ def test_default_manager_is_not_status_model_generated_ones(self): # ...and this one equal to 0, because of 2 successive filters of 'first_choice' # (default manager) and 'second_choice' (explicit filter below). self.assertEqual(StatusCustomManager._default_manager.filter(status='second_choice').count(), 2) + + +class StatusModelStatusManagerTests(TestCase): + + def test_manager_has_custom_qs_methods(self): + StatusCustomQuerySet.objects.create(status='first_choice') + + StatusCustomQuerySet.objects.create(status='second_choice') + StatusCustomQuerySet.objects.create(status='second_choice') + + self.assertTrue(StatusCustomQuerySet.first_choice.custom_exists()) + self.assertEqual(StatusCustomQuerySet.first_choice.custom_count(), 1) + + self.assertTrue(StatusCustomQuerySet.second_choice.custom_exists()) + self.assertEqual(StatusCustomQuerySet.second_choice.custom_count(), 2) + + def test_manager_class_name(self): + self.assertEqual(StatusCustomQuerySet.first_choice.__class__.__name__, 'StatusCustomQuerySetStatusManager') + self.assertEqual(StatusCustomQuerySet.second_choice.__class__.__name__, 'StatusCustomQuerySetStatusManager')