Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Medicines: replace in-memory search with pg full text search with GIN index #1439

Closed
wants to merge 10 commits into from
78 changes: 30 additions & 48 deletions care/facility/api/viewsets/prescription.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from re import IGNORECASE

from django.contrib.postgres.search import SearchQuery, SearchRank
from django.db.models.query import F
from django.shortcuts import get_object_or_404
from django_filters import rest_framework as filters
from drf_spectacular.openapi import OpenApiParameter
from drf_spectacular.utils import extend_schema
from rest_framework import mixins, status
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet, ViewSet
from rest_framework.viewsets import GenericViewSet

from care.facility.api.serializers.prescription import (
MedibaseMedicineSerializer,
MedicineAdministrationSerializer,
PrescriptionSerializer,
)
from care.facility.models import (
MedibaseMedicine,
MedicineAdministration,
Prescription,
PrescriptionType,
Expand Down Expand Up @@ -136,52 +139,31 @@ def administer(self, request, *args, **kwargs):
# return Response({"success": True}, status=status.HTTP_200_OK)


class MedibaseViewSet(ViewSet):
class MedicineViewSet(
mixins.ListModelMixin,
mixins.RetrieveModelMixin,
GenericViewSet,
):
serializer_class = MedibaseMedicineSerializer
permission_classes = (IsAuthenticated,)
queryset = MedibaseMedicine.objects.all()
lookup_field = "external_id"

def serailize_data(self, objects):
result = []
for object in objects:
if type(object) == tuple:
object = object[0]
result.append(
{
"id": object.external_id,
"name": object.name,
"type": object.type,
"generic": object.generic,
"company": object.company,
"contents": object.contents,
"cims_class": object.cims_class,
"atc_classification": object.atc_classification,
}
)
return result

def sort(self, query, results):
exact_matches = []
partial_matches = []

for result in results:
if type(result) == tuple:
result = result[0]
words = result.searchable.lower().split()
if query in words:
exact_matches.append(result)
else:
partial_matches.append(result)

return exact_matches + partial_matches

def list(self, request):
from care.facility.static_data.medibase import MedibaseMedicineTable
@extend_schema(
parameters=[
OpenApiParameter(name="search", required=False, type=str),
]
)
def list(self, request, *args, **kwargs):
rank = SearchRank(
F("search_vector"), SearchQuery(request.query_params.get("search", ""))
)
queryset = self.queryset.annotate(rank=rank).order_by("-rank")

queryset = MedibaseMedicineTable
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)

if request.GET.get("query", False):
query = request.GET.get("query").strip().lower()
queryset = queryset.where(
searchable=queryset.re_match(r".*" + query + r".*", IGNORECASE)
)
queryset = self.sort(query, queryset)
return Response(self.serailize_data(queryset[:15]))
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
69 changes: 69 additions & 0 deletions care/facility/migrations/0371_auto_20230706_2228.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Generated by Django 4.2.2 on 2023-07-06 16:58

import django.contrib.postgres.indexes
import django.contrib.postgres.search
from django.contrib.postgres.operations import BtreeGinExtension
from django.contrib.postgres.search import SearchVector
from django.db import migrations


def compute_search_vector(apps, schema_editor):
    """Backfill ``search_vector`` for every existing MedibaseMedicine row.

    Forward function for a ``RunPython`` operation. The column order, weights
    (name/generic ``A``, company ``C``, cims_class/contents ``D``) and the
    text search configuration mirror the ``medibase_search_vector_trigger``
    SQL function created later in this migration, so backfilled rows and
    trigger-maintained rows carry equivalent vectors.

    Args:
        apps: Historical app registry passed in by the migration executor;
            used to resolve the model as it exists at this migration.
        schema_editor: Schema editor passed in by the migration executor
            (unused here).
    """
    MedibaseMedicine = apps.get_model("facility", "MedibaseMedicine")

    # Pin the configuration explicitly. Without ``config`` PostgreSQL uses the
    # database's ``default_text_search_config``, which is not guaranteed to be
    # the ``pg_catalog.english`` configuration hard-coded in the trigger.
    config = "pg_catalog.english"

    MedibaseMedicine.objects.update(
        search_vector=(
            SearchVector("name", weight="A", config=config)
            + SearchVector("generic", weight="A", config=config)
            + SearchVector("company", weight="C", config=config)
            + SearchVector("cims_class", weight="D", config=config)
            + SearchVector("contents", weight="D", config=config)
        )
    )


class Migration(migrations.Migration):
    """Add a weighted full-text ``search_vector`` column to MedibaseMedicine.

    Operations, in order:
      1. install the ``btree_gin`` extension,
      2. add the nullable ``search_vector`` column,
      3. create a GIN index on it,
      4. backfill existing rows from Python,
      5. install a trigger that recomputes the vector on insert and on
         updates to any of the source columns.

    Rolling back removes the column and index, and drops both the trigger and
    its backing function.
    """

    dependencies = [
        ("facility", "0370_merge_20230705_1500"),
    ]
    operations = [
        # NOTE(review): a GIN index on a tsvector column does not appear to
        # need btree_gin; confirm whether this extension is required elsewhere.
        BtreeGinExtension(),
        migrations.AddField(
            model_name="medibasemedicine",
            name="search_vector",
            field=django.contrib.postgres.search.SearchVectorField(null=True),
        ),
        migrations.AddIndex(
            model_name="medibasemedicine",
            index=django.contrib.postgres.indexes.GinIndex(
                fields=["search_vector"], name="medibase_search_vector_idx"
            ),
        ),
        # The backfill needs no reverse step: the column itself is removed
        # when AddField is reversed.
        migrations.RunPython(
            compute_search_vector, reverse_code=migrations.RunPython.noop
        ),
        migrations.RunSQL(
            # The trailing UPDATE writes to search_vector, which is listed in
            # the trigger's UPDATE OF columns, so the trigger fires for every
            # row and rebuilds its vector from the current column values.
            sql="""
            CREATE OR REPLACE FUNCTION medibase_search_vector_trigger() RETURNS trigger AS $$
            BEGIN
                NEW.search_vector :=
                    setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.name, '')), 'A') ||
                    setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.generic, '')), 'A') ||
                    setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.company, '')), 'C') ||
                    setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.cims_class, '')), 'D') ||
                    setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.contents, '')), 'D');
                RETURN NEW;
            END
            $$ LANGUAGE plpgsql;

            CREATE TRIGGER medibase_search_vector_trigger
            BEFORE INSERT OR UPDATE OF name, generic, company, cims_class, contents, search_vector
            ON facility_medibasemedicine
            FOR EACH ROW EXECUTE FUNCTION medibase_search_vector_trigger();

            UPDATE facility_medibasemedicine SET search_vector = NULL;
            """,
            # Drop the trigger first, then the function it depends on, so a
            # rollback leaves no orphaned database objects behind.
            reverse_sql="""
            DROP TRIGGER IF EXISTS medibase_search_vector_trigger
            ON facility_medibasemedicine;

            DROP FUNCTION IF EXISTS medibase_search_vector_trigger();
            """,
        ),
    ]
9 changes: 9 additions & 0 deletions care/facility/models/prescription.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import enum

from django.contrib.postgres.indexes import GinIndex
from django.contrib.postgres.search import SearchVectorField
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models import JSONField
Expand Down Expand Up @@ -63,9 +65,16 @@ class MedibaseMedicine(BaseModel):
cims_class = models.CharField(max_length=255, blank=True, null=True)
atc_classification = models.TextField(blank=True, null=True)

search_vector = SearchVectorField(null=True)

def __str__(self):
return " - ".join([self.name, self.generic, self.company])

class Meta:
indexes = (
GinIndex(fields=["search_vector"], name="medibase_search_vector_idx"),
)


class Prescription(BaseModel):
consultation = models.ForeignKey(
Expand Down
26 changes: 0 additions & 26 deletions care/facility/static_data/medibase.py

This file was deleted.

12 changes: 6 additions & 6 deletions care/facility/tests/test_medibase_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@

class TestMedibaseApi(TestBase):
def get_url(self, query=None):
return f"/api/v1/medibase/?query={query}"
return f"/api/v1/medicine/?search={query}"

def test_search_by_name_exact_word(self):
response = self.client.get(self.get_url(query="dolo"))
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data[0]["name"], "DOLO")
self.assertEquals(response.json()["results"][0]["name"], "DOLO")

def test_search_by_generic_exact_word(self):
response = self.client.get(self.get_url(query="pAraCetAmoL"))
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data[0]["generic"], "paracetamol")
self.assertEquals(response.json()["results"][0]["generic"], "paracetamol")

def test_search_by_name_and_generic_exact_word(self):
response = self.client.get(self.get_url(query="panadol paracetamol"))
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data[0]["name"], "PANADOL")
self.assertEquals(response.data[0]["generic"], "paracetamol")
self.assertEquals(response.data[0]["company"], "GSK")
self.assertEquals(response.json()["results"][0]["name"], "PANADOL")
self.assertEquals(response.json()["results"][0]["generic"], "paracetamol")
self.assertEquals(response.json()["results"][0]["company"], "GSK")
4 changes: 2 additions & 2 deletions config/api_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
from care.facility.api.viewsets.patient_sample import PatientSampleViewSet
from care.facility.api.viewsets.prescription import (
ConsultationPrescriptionViewSet,
MedibaseViewSet,
MedicineAdministrationViewSet,
MedicineViewSet,
)
from care.facility.api.viewsets.prescription_supplier import (
PrescriptionSupplierConsultationViewSet,
Expand Down Expand Up @@ -200,7 +200,7 @@
consultation_nested_router.register(
r"prescription_administration", MedicineAdministrationViewSet
)
router.register("medibase", MedibaseViewSet, basename="medibase")
router.register("medicine", MedicineViewSet)

# HCX
router.register("hcx/policy", PolicyViewSet)
Expand Down
Loading