diff --git a/care/facility/api/viewsets/facility.py b/care/facility/api/viewsets/facility.py index 8e4f4bd3f7..c24aa1578a 100644 --- a/care/facility/api/viewsets/facility.py +++ b/care/facility/api/viewsets/facility.py @@ -39,6 +39,12 @@ class FacilityFilter(filters.FilterSet): state = filters.NumberFilter(field_name="state__id") state_name = filters.CharFilter(field_name="state__name", lookup_expr="icontains") kasp_empanelled = filters.BooleanFilter(field_name="kasp_empanelled") + exclude_user = filters.CharFilter(method="filter_exclude_user") + + def filter_exclude_user(self, queryset, name, value): + if value: + queryset = queryset.exclude(facilityuser__user__username=value) + return queryset class FacilityQSPermissions(DRYPermissionFiltersBase): diff --git a/care/facility/tests/test_facilityuser_api.py b/care/facility/tests/test_facilityuser_api.py index 0f57868a08..8b7f0d8c2d 100644 --- a/care/facility/tests/test_facilityuser_api.py +++ b/care/facility/tests/test_facilityuser_api.py @@ -1,7 +1,6 @@ from rest_framework import status from rest_framework.test import APITestCase -from care.users.models import Skill from care.utils.tests.test_utils import TestUtils @@ -14,9 +13,16 @@ def setUpTestData(cls) -> None: cls.super_user = cls.create_super_user("su", cls.district) cls.facility = cls.create_facility(cls.super_user, cls.district, cls.local_body) cls.user = cls.create_user("staff", cls.district, home_facility=cls.facility) - cls.skill1 = Skill.objects.create(name="Skill 1") - cls.skill2 = Skill.objects.create(name="Skill 2") - cls.user.skills.add(cls.skill1, cls.skill2) + + cls.facility1 = cls.create_facility( + cls.super_user, cls.district, cls.local_body + ) + cls.facility2 = cls.create_facility( + cls.super_user, cls.district, cls.local_body + ) + + def setUp(self) -> None: + self.client.force_authenticate(self.super_user) def test_get_queryset_with_prefetching(self): response = self.client.get( @@ -25,3 +31,17 @@ def test_get_queryset_with_prefetching(self): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNumQueries(2) + + def test_link_new_facility(self): + response = self.client.get("/api/v1/facility/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 3) + + def test_link_existing_facility(self): + response = self.client.get( + f"/api/v1/facility/?exclude_user={self.user.username}" + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 2)