Skip to content

Commit

Permalink
feat: add rerank endpoint (#108)
Browse files Browse the repository at this point in the history
* feat(rerank): add LLM based reranker

* feat(rerank): add rerank types

* feat(rerank): prep for choicer

* feat: rebase and clean

---------

Co-authored-by: camille <[email protected]>
Co-authored-by: leoguillaume <[email protected]>
  • Loading branch information
3 people authored Dec 10, 2024
1 parent e762c05 commit a23ac80
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 17 deletions.
45 changes: 36 additions & 9 deletions app/clients/_modelclients.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from functools import partial
import json
import time
from typing import Literal, Any
from typing import Any, List, Literal

from fastapi import HTTPException
from openai import OpenAI
import requests
from fastapi import HTTPException
import json
from app.schemas.settings import Settings

from app.schemas.embeddings import Embeddings
from app.schemas.models import Model, Models
from app.utils.logging import logger
from app.schemas.rerank import Rerank
from app.schemas.settings import Settings
from app.utils.exceptions import ModelNotAvailableException, ModelNotFoundException
from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE, DEFAULT_TIMEOUT
from app.utils.logging import logger
from app.utils.variables import AUDIO_MODEL_TYPE, DEFAULT_TIMEOUT, EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE


def get_models_list(self, *args, **kwargs) -> Models:
Expand All @@ -37,7 +39,7 @@ def get_models_list(self, *args, **kwargs) -> Models:
self.created = response.get("created", round(time.time()))
self.max_context_length = response.get("max_model_len", None)

elif self.type == EMBEDDINGS_MODEL_TYPE:
elif self.type == EMBEDDINGS_MODEL_TYPE or self.type == RERANK_MODEL_TYPE:
endpoint = str(self.base_url).replace("/v1/", "/info")
response = requests.get(url=endpoint, headers=headers, timeout=DEFAULT_TIMEOUT).json()

Expand Down Expand Up @@ -74,22 +76,25 @@ def get_models_list(self, *args, **kwargs) -> Models:
return Models(data=[data])


# TODO(review): confirm this manual override of embeddings.create is still needed.
def create_embeddings(self, *args, **kwargs):
    """Forward an embeddings request to the model backend.

    Posts the keyword arguments as the JSON payload to the backend's
    /embeddings endpoint and parses the reply into an Embeddings schema.

    Raises:
        HTTPException: carrying the backend's status code and error message
            when the backend answers with an HTTP error status.
    """
    url = f"{self.base_url}embeddings"
    headers = {"Authorization": f"Bearer {self.api_key}"}

    # Bound the request: the original call had no timeout and could hang forever.
    response = requests.post(url=url, headers=headers, json=kwargs, timeout=DEFAULT_TIMEOUT)
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # Only HTTPError reliably carries a response object; the previous broad
        # `except Exception` raised AttributeError on connection/timeout errors.
        # Fall back to the raw body when it is not a {"message": ...} JSON object.
        try:
            detail = json.loads(e.response.text)["message"]
        except (ValueError, KeyError):
            detail = e.response.text
        raise HTTPException(status_code=e.response.status_code, detail=detail)

    return Embeddings(**response.json())


class ModelClient(OpenAI):
DEFAULT_TIMEOUT = 120

def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE], *args, **kwargs) -> None:
def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE], *args, **kwargs) -> None:
"""
ModelClient class extends AsyncOpenAI class to support custom methods.
"""
Expand All @@ -111,6 +116,28 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDI
self.vector_size = len(response.data[0].embedding)
self.embeddings.create = partial(create_embeddings, self)

if self.type == RERANK_MODEL_TYPE:

    class RerankClient(OpenAI):
        """Thin client exposing a create() call against the backend's /rerank route."""

        def __init__(self, model: str, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The client is bound to a single model id at construction time.
            self.model = model

        def create(self, prompt: str, input: list[str], model: str) -> List[Rerank]:
            """Score each text in *input* against *prompt* via the backend /rerank endpoint."""
            assert self.model == model, "Model not found."
            # Named `payload` so the imported `json` module is not shadowed.
            payload = {"query": prompt, "texts": input}
            # Avoid quote reuse inside an f-string expression: valid only on
            # Python 3.12+ (PEP 701); the plain expression works everywhere.
            url = str(self.base_url).replace("/v1/", "/rerank")
            headers = {"Authorization": f"Bearer {self.api_key}"}

            response = requests.post(url=url, headers=headers, json=payload, timeout=self.timeout)
            response.raise_for_status()
            return [Rerank(**item) for item in response.json()]

    self.rerank = RerankClient(model=self.id, base_url=self.base_url, api_key=self.api_key, timeout=self.DEFAULT_TIMEOUT)


class ModelClients(dict):
"""
Expand All @@ -127,7 +154,7 @@ def __init__(self, settings: Settings) -> None:

def __setitem__(self, key: str, value) -> None:
if any(key == k for k in self.keys()):
raise KeyError(msg=f"Model id {key} is duplicated, not allowed.")
raise KeyError(f"Model id {key} is duplicated, not allowed.")
super().__setitem__(key, value)

def __getitem__(self, key: str) -> Any:
Expand Down
31 changes: 31 additions & 0 deletions app/endpoints/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fastapi import APIRouter, Request, Security

from app.helpers import LanguageModelReranker
from app.schemas.rerank import RerankRequest, Reranks
from app.schemas.security import User
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE

from app.utils.exceptions import WrongModelTypeException

router = APIRouter()


@router.post("/rerank")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def rerank(request: Request, body: RerankRequest, user: User = Security(check_api_key)):
"""
Rerank a list of inputs with a language model or reranker model.
"""

if clients.models[body.model].type == LANGUAGE_MODEL_TYPE:
reranker = LanguageModelReranker(model=clients.models[body.model])
data = reranker.create(prompt=body.prompt, input=body.input)
elif clients.models[body.model].type == RERANK_MODEL_TYPE:
data = clients.models[body.model].rerank.create(prompt=body.prompt, input=body.input, model=body.model)
else:
raise WrongModelTypeException()

return Reranks(data=data)
5 changes: 3 additions & 2 deletions app/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ._clientsmanager import ClientsManager
from ._fileuploader import FileUploader
from ._metricsmiddleware import MetricsMiddleware
from ._internetmanager import InternetManager
from ._languagemodelreranker import LanguageModelReranker
from ._metricsmiddleware import MetricsMiddleware
from ._searchmanager import SearchManager

__all__ = ["ClientsManager", "FileUploader", "InternetManager", "MetricsMiddleware", "SearchManager"]
__all__ = ["ClientsManager", "FileUploader", "LanguageModelReranker", "InternetManager", "MetricsMiddleware", "SearchManager"]
28 changes: 28 additions & 0 deletions app/helpers/_languagemodelreranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List
import re
from app.clients._modelclients import ModelClient
from app.schemas.rerank import Rerank


class LanguageModelReranker:
    """Rerank texts with a language model by asking a binary relevance question per text."""

    PROMPT_LLM_BASED = """Voilà un texte : {text}\n
En se basant uniquement sur ce texte, réponds 1 si ce texte peut donner des éléments de réponse à la question suivante ou 0 si aucun élément de réponse n'est présent dans le texte. Voila la question: {prompt}
Le texte n'a pas besoin de répondre parfaitement à la question, juste d'apporter des éléments de réponses et/ou de parler du même thème. Réponds uniquement 0 ou 1."""

    def __init__(self, model: ModelClient) -> None:
        self.model = model

    def create(self, prompt: str, input: list) -> List[Rerank]:
        """Return one Rerank (score 0 or 1, original index) per input text; one LLM call each."""
        scores = []
        for position, text in enumerate(input):
            message = self.PROMPT_LLM_BASED.format(prompt=prompt, text=text)

            completion = self.model.chat.completions.create(
                messages=[{"role": "user", "content": message}], model=self.model.id, temperature=0.1, max_tokens=3, stream=False, n=1
            )
            answer = completion.choices[0].message.content
            # Extract the first 0/1 digit from the reply; default to 0 (not relevant).
            found = re.search(r"[0-1]", answer)
            scores.append(Rerank(score=int(found.group(0)) if found else 0, index=position))

        return scores
7 changes: 4 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from prometheus_fastapi_instrumentator import Instrumentator
from slowapi.middleware import SlowAPIASGIMiddleware

from app.endpoints import audio, chat, chunks, collections, completions, documents, embeddings, files, models, search
from app.helpers._metricsmiddleware import MetricsMiddleware
from app.endpoints import audio, chat, chunks, collections, completions, documents, embeddings, files, models, rerank, search
from app.helpers import MetricsMiddleware
from app.schemas.security import User
from app.utils.settings import settings
from app.utils.lifespan import lifespan
from app.utils.security import check_admin_api_key, check_api_key
from app.utils.settings import settings

app = FastAPI(
title=settings.app_name,
Expand Down Expand Up @@ -47,6 +47,7 @@ def health(user: User = Security(dependency=check_api_key)) -> Response:
app.include_router(router=completions.router, tags=["Core"], prefix="/v1")
app.include_router(router=embeddings.router, tags=["Core"], prefix="/v1")
app.include_router(router=audio.router, tags=["Core"], prefix="/v1")
app.include_router(router=rerank.router, tags=["Core"], prefix="/v1")

# RAG
app.include_router(router=search.router, tags=["Retrieval Augmented Generation"], prefix="/v1")
Expand Down
4 changes: 2 additions & 2 deletions app/schemas/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from openai.types import Model
from pydantic import BaseModel

from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE
from app.utils.variables import AUDIO_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE


class Model(Model):
max_context_length: Optional[int] = None
type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE]
type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE]
status: Literal["available", "unavailable"]


Expand Down
19 changes: 19 additions & 0 deletions app/schemas/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import List, Literal

from pydantic import BaseModel


class RerankRequest(BaseModel):
    """Request body for POST /rerank."""

    # The query the candidate texts are scored against.
    prompt: str
    # Candidate texts to rerank.
    input: List[str]
    # Id of the language or rerank model to use.
    model: str


class Rerank(BaseModel):
    """Relevance score for one input text."""

    # Relevance score (the LLM-based reranker emits 0 or 1).
    score: float
    # Position of the text in the request's input list.
    index: int


class Reranks(BaseModel):
    """Response body for POST /rerank: the list of per-text scores."""

    object: Literal["list"] = "list"
    data: List[Rerank]
3 changes: 2 additions & 1 deletion app/schemas/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
EMBEDDINGS_MODEL_TYPE,
LANGUAGE_MODEL_TYPE,
AUDIO_MODEL_TYPE,
RERANK_MODEL_TYPE,
INTERNET_CLIENT_DUCKDUCKGO_TYPE,
INTERNET_CLIENT_BRAVE_TYPE,
SEARCH_CLIENT_ELASTIC_TYPE,
Expand All @@ -32,7 +33,7 @@ class Auth(ConfigBaseModel):

class Model(ConfigBaseModel):
url: str
type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE]
type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE]
key: Optional[str] = "EMPTY"


Expand Down
63 changes: 63 additions & 0 deletions app/tests/test_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging

import pytest

from app.schemas.rerank import Reranks
from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE


@pytest.fixture(scope="module")
def setup(args, session_user):
response = session_user.get(f"{args["base_url"]}/models")
assert response.status_code == 200, f"error: retrieve models ({response.status_code})"
response_json = response.json()

LANGUAGE_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == LANGUAGE_MODEL_TYPE][0]
logging.info(f"test model ID: {LANGUAGE_MODEL_ID}")

RERANK_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == RERANK_MODEL_TYPE][0]
logging.info(f"test model ID: {RERANK_MODEL_ID}")

EMBEDDINGS_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == EMBEDDINGS_MODEL_TYPE][0]
logging.info(f"test model ID: {EMBEDDINGS_MODEL_ID}")

yield LANGUAGE_MODEL_ID, RERANK_MODEL_ID, EMBEDDINGS_MODEL_ID


@pytest.mark.usefixtures("args", "session_user", "setup")
class TestRerank:
def test_rerank_with_language_model(self, args, session_user, setup):
"""Test the POST /rerank with a language model."""
LANGUAGE_MODEL_ID, _, _ = setup
params = {"model": LANGUAGE_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]}
response = session_user.post(f"{args["base_url"]}/rerank", json=params)
assert response.status_code == 200, f"error: rerank with language model ({response.status_code})"

response_json = response.json()
reranks = Reranks(**response_json)
assert isinstance(reranks, Reranks)

def test_rerank_with_rerank_model(self, args, session_user, setup):
"""Test the POST /rerank with a rerank model."""
_, RERANK_MODEL_ID, _ = setup
params = {"model": RERANK_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]}
response = session_user.post(f"{args["base_url"]}/rerank", json=params)
assert response.status_code == 200, f"error: rerank with rerank model ({response.status_code})"

response_json = response.json()
reranks = Reranks(**response_json)
assert isinstance(reranks, Reranks)

def test_rerank_with_wrong_model_type(self, args, session_user, setup):
"""Test the POST /rerank with a wrong model type."""
_, _, EMBEDDINGS_MODEL_ID = setup
params = {"model": EMBEDDINGS_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]}
response = session_user.post(f"{args["base_url"]}/rerank", json=params)
assert response.status_code == 422, f"error: rerank with wrong model type ({response.status_code})"

def test_rerank_with_unknown_model(self, args, session_user, setup):
"""Test the POST /rerank with an unknown model."""
_, _, _ = setup
params = {"model": "unknown", "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]}
response = session_user.post(f"{args["base_url"]}/rerank", json=params)
assert response.status_code == 404, f"error: rerank with unknown model ({response.status_code})"
1 change: 1 addition & 0 deletions app/utils/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Model type identifiers, matching the backend task names they are served as.
AUDIO_MODEL_TYPE = "automatic-speech-recognition"
EMBEDDINGS_MODEL_TYPE = "text-embeddings-inference"
LANGUAGE_MODEL_TYPE = "text-generation"
# Rerank models are served under the text-classification task type.
RERANK_MODEL_TYPE = "text-classification"

# Available text chunking strategies and the default used when none is given.
CHUNKERS = ["LangchainRecursiveCharacterTextSplitter", "NoChunker"]
DEFAULT_CHUNKER = "LangchainRecursiveCharacterTextSplitter"
Expand Down

0 comments on commit a23ac80

Please sign in to comment.