diff --git a/license_manager/apps/api/pagination.py b/license_manager/apps/api/pagination.py index 2b5dd79d..b77f4eb6 100644 --- a/license_manager/apps/api/pagination.py +++ b/license_manager/apps/api/pagination.py @@ -3,8 +3,14 @@ """ from django.core.paginator import Paginator as DjangoPaginator from django.utils.functional import cached_property +from edx_rest_framework_extensions.paginators import DefaultPagination from rest_framework.pagination import PageNumberPagination +from license_manager.apps.api.serializers import ( + MinimalCustomerAgreementSerializer, +) +from license_manager.apps.subscriptions.models import CustomerAgreement + class PageNumberPaginationWithCount(PageNumberPagination): """ @@ -31,6 +37,7 @@ class EstimatedCountDjangoPaginator(DjangoPaginator): A lazy paginator that determines it's count from the upstream `estimated_count` """ + def __init__(self, *args, estimated_count=None, **kwargs): self.estimated_count = estimated_count super().__init__(*args, **kwargs) @@ -49,6 +56,7 @@ class EstimatedCountLicensePagination(LicensePagination): which means the downstream django paginator does *not* perform an additional query to get the count of the queryset. """ + def __init__(self, *args, estimated_count=None, **kwargs): """ Optionally stores an `estimated_count` to pass along @@ -70,3 +78,37 @@ def django_paginator_class(self, queryset, page_size): queryset, page_size, estimated_count=self.estimated_count, ) return DjangoPaginator(queryset, page_size) + + +class LearnerLicensesPaginationCustomerAgreement(DefaultPagination): + """ + Adds the customer agreement object to the learner-licenses endpoint. + The learner licenses endpoint currently contains the subscription_licenses, with its + corresponding subscription_plan. In order to reduce the number of calls to the client, + we incorporate the customer_agreement accessible within a single call. + """ + + def get_paginated_response(self, data): + """ + Modifies the DefaultPagination response to include ``customer_agreement`` dict. + + Arguments: + self: LearnerLicensesPaginationCustomerAgreement instance. + data (dict): Results for current page. + + Returns: + (Response): DRF response object containing ``customer_agreement`` dict. + """ + paginated_response = super().get_paginated_response(data) + enterprise_customer_uuid = self.request.query_params.get('enterprise_customer_uuid') + try: + customer_agreement = CustomerAgreement.objects.get(enterprise_customer_uuid=enterprise_customer_uuid) + paginated_response.data.update({ + 'customer_agreement': MinimalCustomerAgreementSerializer(customer_agreement).data + }) + except CustomerAgreement.DoesNotExist: + paginated_response.data.update({ + 'customer_agreement': None + }) + + return paginated_response diff --git a/license_manager/apps/api/serializers.py b/license_manager/apps/api/serializers.py index 712f2bf4..2cf79e82 100644 --- a/license_manager/apps/api/serializers.py +++ b/license_manager/apps/api/serializers.py @@ -118,6 +118,8 @@ class MinimalCustomerAgreementSerializer(serializers.ModelSerializer): include a nested representation of related subscription plans. """ + subscription_for_auto_applied_licenses = serializers.SerializerMethodField() + class Meta: model = CustomerAgreement fields = [ @@ -127,21 +129,24 @@ class Meta: 'default_enterprise_catalog_uuid', 'disable_expiration_notifications', 'net_days_until_expiration', + 'subscription_for_auto_applied_licenses', ] + def get_subscription_for_auto_applied_licenses(self, obj): + subscription_plan = obj.auto_applicable_subscription + return subscription_plan.uuid if subscription_plan else None + class CustomerAgreementSerializer(MinimalCustomerAgreementSerializer): """ Expanded serializer for the `CustomerAgreement` model. """ subscriptions = SerializerMethodField() - subscription_for_auto_applied_licenses = serializers.SerializerMethodField() class Meta: model = CustomerAgreement fields = MinimalCustomerAgreementSerializer.Meta.fields + [ 'subscriptions', - 'subscription_for_auto_applied_licenses' ] @property @@ -160,10 +165,6 @@ def get_subscriptions(self, obj): serializer = SubscriptionPlanSerializer(plans, many=True) return serializer.data - def get_subscription_for_auto_applied_licenses(self, obj): - subscription_plan = obj.auto_applicable_subscription - return subscription_plan.uuid if subscription_plan else None - class LicenseSerializer(serializers.ModelSerializer): """ diff --git a/license_manager/apps/api/v1/tests/test_views.py b/license_manager/apps/api/v1/tests/test_views.py index 1a1c53a6..ecf1397f 100644 --- a/license_manager/apps/api/v1/tests/test_views.py +++ b/license_manager/apps/api/v1/tests/test_views.py @@ -2685,6 +2685,22 @@ def test_endpoint_request_missing_customer_uuid(self, system_role, subs_role): assert response.status_code == status.HTTP_400_BAD_REQUEST assert 'missing enterprise_customer_uuid query param' in str(response.content) + def test_endpoint_results_contains_customer_agreement(self): + """ + Tests if the learner-licenses endpoint contains the customer agreement object + on the paginator. + + Checks if the serialized customer agreement from the response matches the mocked + customer agreement. + """ + self._assign_learner_roles() + + response = self._get_url_with_customer_uuid(self.enterprise_customer_uuid) + + assert response.status_code == status.HTTP_200_OK + customer_agreement_response = response.json().get('customer_agreement') + assert customer_agreement_response['uuid'] == str(self.customer_agreement.uuid) + def test_endpoint_results_correctly_ordered(self): """ Test the ordering of responses from the endpoint matches the following: diff --git a/license_manager/apps/api/v1/views.py b/license_manager/apps/api/v1/views.py index 18bb48d6..01a79886 100644 --- a/license_manager/apps/api/v1/views.py +++ b/license_manager/apps/api/v1/views.py @@ -67,7 +67,11 @@ localized_utcnow, ) -from ..pagination import EstimatedCountLicensePagination, LicensePagination +from ..pagination import ( + EstimatedCountLicensePagination, + LearnerLicensesPaginationCustomerAgreement, + LicensePagination, +) logger = logging.getLogger(__name__) @@ -440,6 +444,7 @@ class LearnerLicensesViewSet( list_lookup_field = 'subscription_plan__customer_agreement__enterprise_customer_uuid' allowed_roles = [constants.SUBSCRIPTIONS_ADMIN_ROLE, constants.SUBSCRIPTIONS_LEARNER_ROLE] role_assignment_class = SubscriptionsRoleAssignment + pagination_class = LearnerLicensesPaginationCustomerAgreement @property def enterprise_customer_uuid(self):