From a23ac807ff3cbd2b2ba250bd386534a19a11be8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Guillaume?= <62661249+leoguillaume@users.noreply.github.com> Date: Tue, 10 Dec 2024 09:34:16 +0100 Subject: [PATCH] feat: add rerank endpoint (#108) * feat(rerank): add LLM based reranker * feat(rerank): add rerank types * feat(rerank): prep for choicer * feat: rebase and clean --------- Co-authored-by: camille Co-authored-by: leoguillaume --- app/clients/_modelclients.py | 45 +++++++++++++++---- app/endpoints/rerank.py | 31 +++++++++++++ app/helpers/__init__.py | 5 ++- app/helpers/_languagemodelreranker.py | 28 ++++++++++++ app/main.py | 7 +-- app/schemas/models.py | 4 +- app/schemas/rerank.py | 19 ++++++++ app/schemas/settings.py | 3 +- app/tests/test_rerank.py | 63 +++++++++++++++++++++++++++ app/utils/variables.py | 1 + 10 files changed, 189 insertions(+), 17 deletions(-) create mode 100644 app/endpoints/rerank.py create mode 100644 app/helpers/_languagemodelreranker.py create mode 100644 app/schemas/rerank.py create mode 100644 app/tests/test_rerank.py diff --git a/app/clients/_modelclients.py b/app/clients/_modelclients.py index bd88c703..50f7f4f5 100644 --- a/app/clients/_modelclients.py +++ b/app/clients/_modelclients.py @@ -1,17 +1,19 @@ from functools import partial +import json import time -from typing import Literal, Any +from typing import Any, List, Literal +from fastapi import HTTPException from openai import OpenAI import requests -from fastapi import HTTPException -import json -from app.schemas.settings import Settings + from app.schemas.embeddings import Embeddings from app.schemas.models import Model, Models -from app.utils.logging import logger +from app.schemas.rerank import Rerank +from app.schemas.settings import Settings from app.utils.exceptions import ModelNotAvailableException, ModelNotFoundException -from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE, DEFAULT_TIMEOUT +from app.utils.logging import logger +from app.utils.variables import AUDIO_MODEL_TYPE, DEFAULT_TIMEOUT, EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE def get_models_list(self, *args, **kwargs) -> Models: @@ -37,7 +39,7 @@ def get_models_list(self, *args, **kwargs) -> Models: self.created = response.get("created", round(time.time())) self.max_context_length = response.get("max_model_len", None) - elif self.type == EMBEDDINGS_MODEL_TYPE: + elif self.type == EMBEDDINGS_MODEL_TYPE or self.type == RERANK_MODEL_TYPE: endpoint = str(self.base_url).replace("/v1/", "/info") response = requests.get(url=endpoint, headers=headers, timeout=DEFAULT_TIMEOUT).json() @@ -74,6 +76,7 @@ def get_models_list(self, *args, **kwargs) -> Models: return Models(data=[data]) +# @TODO : useless ? def create_embeddings(self, *args, **kwargs): try: url = f"{self.base_url}embeddings" @@ -81,7 +84,9 @@ def create_embeddings(self, *args, **kwargs): response = requests.post(url=url, headers=headers, json=kwargs) response.raise_for_status() data = response.json() + return Embeddings(**data) + except Exception as e: raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"]) @@ -89,7 +94,7 @@ def create_embeddings(self, *args, **kwargs): class ModelClient(OpenAI): DEFAULT_TIMEOUT = 120 - def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE], *args, **kwargs) -> None: + def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE], *args, **kwargs) -> None: """ ModelClient class extends AsyncOpenAI class to support custom methods. """ @@ -111,6 +116,28 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDI self.vector_size = len(response.data[0].embedding) self.embeddings.create = partial(create_embeddings, self) + if self.type == RERANK_MODEL_TYPE: + + class RerankClient(OpenAI): + def __init__(self, model: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.model = model + + def create(self, prompt: str, input: list[str], model: str) -> List[Rerank]: + assert self.model == model, "Model not found." + json = {"query": prompt, "texts": input} + url = f"{str(self.base_url).replace("/v1/", "/rerank")}" + headers = {"Authorization": f"Bearer {self.api_key}"} + + response = requests.post(url=url, headers=headers, json=json, timeout=self.timeout) + response.raise_for_status() + data = response.json() + data = [Rerank(**item) for item in data] + + return data + + self.rerank = RerankClient(model=self.id, base_url=self.base_url, api_key=self.api_key, timeout=self.DEFAULT_TIMEOUT) + class ModelClients(dict): """ @@ -127,7 +154,7 @@ def __init__(self, settings: Settings) -> None: def __setitem__(self, key: str, value) -> None: if any(key == k for k in self.keys()): - raise KeyError(msg=f"Model id {key} is duplicated, not allowed.") + raise KeyError(f"Model id {key} is duplicated, not allowed.") super().__setitem__(key, value) def __getitem__(self, key: str) -> Any: diff --git a/app/endpoints/rerank.py b/app/endpoints/rerank.py new file mode 100644 index 00000000..010114b6 --- /dev/null +++ b/app/endpoints/rerank.py @@ -0,0 +1,31 @@ +from fastapi import APIRouter, Request, Security + +from app.helpers import LanguageModelReranker +from app.schemas.rerank import RerankRequest, Reranks +from app.schemas.security import User +from app.utils.lifespan import clients, limiter +from app.utils.security import check_api_key, check_rate_limit +from app.utils.settings import settings +from app.utils.variables import LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE + +from app.utils.exceptions import WrongModelTypeException + +router = APIRouter() + + +@router.post("/rerank") +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +async def rerank(request: Request, body: RerankRequest, user: User = Security(check_api_key)): + """ + Rerank a list of inputs with a language model or reranker model. + """ + + if clients.models[body.model].type == LANGUAGE_MODEL_TYPE: + reranker = LanguageModelReranker(model=clients.models[body.model]) + data = reranker.create(prompt=body.prompt, input=body.input) + elif clients.models[body.model].type == RERANK_MODEL_TYPE: + data = clients.models[body.model].rerank.create(prompt=body.prompt, input=body.input, model=body.model) + else: + raise WrongModelTypeException() + + return Reranks(data=data) diff --git a/app/helpers/__init__.py b/app/helpers/__init__.py index 1add6d77..1adfe8c6 100644 --- a/app/helpers/__init__.py +++ b/app/helpers/__init__.py @@ -1,7 +1,8 @@ from ._clientsmanager import ClientsManager from ._fileuploader import FileUploader -from ._metricsmiddleware import MetricsMiddleware from ._internetmanager import InternetManager +from ._languagemodelreranker import LanguageModelReranker +from ._metricsmiddleware import MetricsMiddleware from ._searchmanager import SearchManager -__all__ = ["ClientsManager", "FileUploader", "InternetManager", "MetricsMiddleware", "SearchManager"] +__all__ = ["ClientsManager", "FileUploader", "LanguageModelReranker", "InternetManager", "MetricsMiddleware", "SearchManager"] diff --git a/app/helpers/_languagemodelreranker.py b/app/helpers/_languagemodelreranker.py new file mode 100644 index 00000000..6f37d966 --- /dev/null +++ b/app/helpers/_languagemodelreranker.py @@ -0,0 +1,28 @@ +from typing import List +import re +from app.clients._modelclients import ModelClient +from app.schemas.rerank import Rerank + + +class LanguageModelReranker: + PROMPT_LLM_BASED = """Voilà un texte : {text}\n +En se basant uniquement sur ce texte, réponds 1 si ce texte peut donner des éléments de réponse à la question suivante ou 0 si aucun élément de réponse n'est présent dans le texte. Voila la question: {prompt} +Le texte n'a pas besoin de répondre parfaitement à la question, juste d'apporter des éléments de réponses et/ou de parler du même thème. Réponds uniquement 0 ou 1.""" + + def __init__(self, model: ModelClient) -> None: + self.model = model + + def create(self, prompt: str, input: list) -> List[Rerank]: + data = list() + for index, text in enumerate(input): + content = self.PROMPT_LLM_BASED.format(prompt=prompt, text=text) + + response = self.model.chat.completions.create( + messages=[{"role": "user", "content": content}], model=self.model.id, temperature=0.1, max_tokens=3, stream=False, n=1 + ) + result = response.choices[0].message.content + match = re.search(r"[0-1]", result) + result = int(match.group(0)) if match else 0 + data.append(Rerank(score=result, index=index)) + + return data diff --git a/app/main.py b/app/main.py index a9eade2a..dcd2ed55 100644 --- a/app/main.py +++ b/app/main.py @@ -2,12 +2,12 @@ from prometheus_fastapi_instrumentator import Instrumentator from slowapi.middleware import SlowAPIASGIMiddleware -from app.endpoints import audio, chat, chunks, collections, completions, documents, embeddings, files, models, search -from app.helpers._metricsmiddleware import MetricsMiddleware +from app.endpoints import audio, chat, chunks, collections, completions, documents, embeddings, files, models, rerank, search +from app.helpers import MetricsMiddleware from app.schemas.security import User -from app.utils.settings import settings from app.utils.lifespan import lifespan from app.utils.security import check_admin_api_key, check_api_key +from app.utils.settings import settings app = FastAPI( title=settings.app_name, @@ -47,6 +47,7 @@ def health(user: User = Security(dependency=check_api_key)) -> Response: app.include_router(router=completions.router, tags=["Core"], prefix="/v1") app.include_router(router=embeddings.router, tags=["Core"], prefix="/v1") app.include_router(router=audio.router, tags=["Core"], prefix="/v1") +app.include_router(router=rerank.router, tags=["Core"], prefix="/v1") # RAG app.include_router(router=search.router, tags=["Retrieval Augmented Generation"], prefix="/v1") diff --git a/app/schemas/models.py b/app/schemas/models.py index 690319e4..5458fca5 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -3,12 +3,12 @@ from openai.types import Model from pydantic import BaseModel -from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE +from app.utils.variables import AUDIO_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE class Model(Model): max_context_length: Optional[int] = None - type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE] + type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE] status: Literal["available", "unavailable"] diff --git a/app/schemas/rerank.py b/app/schemas/rerank.py new file mode 100644 index 00000000..6b97c292 --- /dev/null +++ b/app/schemas/rerank.py @@ -0,0 +1,19 @@ +from typing import List, Literal + +from pydantic import BaseModel + + +class RerankRequest(BaseModel): + prompt: str + input: List[str] + model: str + + +class Rerank(BaseModel): + score: float + index: int + + +class Reranks(BaseModel): + object: Literal["list"] = "list" + data: List[Rerank] diff --git a/app/schemas/settings.py b/app/schemas/settings.py index 62ae5e6e..dc5c3fdd 100644 --- a/app/schemas/settings.py +++ b/app/schemas/settings.py @@ -9,6 +9,7 @@ EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE, + RERANK_MODEL_TYPE, INTERNET_CLIENT_DUCKDUCKGO_TYPE, INTERNET_CLIENT_BRAVE_TYPE, SEARCH_CLIENT_ELASTIC_TYPE, @@ -32,7 +33,7 @@ class Auth(ConfigBaseModel): class Model(ConfigBaseModel): url: str - type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE] + type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE] key: Optional[str] = "EMPTY" diff --git a/app/tests/test_rerank.py b/app/tests/test_rerank.py new file mode 100644 index 00000000..c2e85093 --- /dev/null +++ b/app/tests/test_rerank.py @@ -0,0 +1,63 @@ +import logging + +import pytest + +from app.schemas.rerank import Reranks +from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE + + +@pytest.fixture(scope="module") +def setup(args, session_user): + response = session_user.get(f"{args["base_url"]}/models") + assert response.status_code == 200, f"error: retrieve models ({response.status_code})" + response_json = response.json() + + LANGUAGE_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == LANGUAGE_MODEL_TYPE][0] + logging.info(f"test model ID: {LANGUAGE_MODEL_ID}") + + RERANK_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == RERANK_MODEL_TYPE][0] + logging.info(f"test model ID: {RERANK_MODEL_ID}") + + EMBEDDINGS_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == EMBEDDINGS_MODEL_TYPE][0] + logging.info(f"test model ID: {EMBEDDINGS_MODEL_ID}") + + yield LANGUAGE_MODEL_ID, RERANK_MODEL_ID, EMBEDDINGS_MODEL_ID + + +@pytest.mark.usefixtures("args", "session_user", "setup") +class TestRerank: + def test_rerank_with_language_model(self, args, session_user, setup): + """Test the POST /rerank with a language model.""" + LANGUAGE_MODEL_ID, _, _ = setup + params = {"model": LANGUAGE_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]} + response = session_user.post(f"{args["base_url"]}/rerank", json=params) + assert response.status_code == 200, f"error: rerank with language model ({response.status_code})" + + response_json = response.json() + reranks = Reranks(**response_json) + assert isinstance(reranks, Reranks) + + def test_rerank_with_rerank_model(self, args, session_user, setup): + """Test the POST /rerank with a rerank model.""" + _, RERANK_MODEL_ID, _ = setup + params = {"model": RERANK_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]} + response = session_user.post(f"{args["base_url"]}/rerank", json=params) + assert response.status_code == 200, f"error: rerank with rerank model ({response.status_code})" + + response_json = response.json() + reranks = Reranks(**response_json) + assert isinstance(reranks, Reranks) + + def test_rerank_with_wrong_model_type(self, args, session_user, setup): + """Test the POST /rerank with a wrong model type.""" + _, _, EMBEDDINGS_MODEL_ID = setup + params = {"model": EMBEDDINGS_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]} + response = session_user.post(f"{args["base_url"]}/rerank", json=params) + assert response.status_code == 422, f"error: rerank with wrong model type ({response.status_code})" + + def test_rerank_with_unknown_model(self, args, session_user, setup): + """Test the POST /rerank with an unknown model.""" + _, _, _ = setup + params = {"model": "unknown", "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]} + response = session_user.post(f"{args["base_url"]}/rerank", json=params) + assert response.status_code == 404, f"error: rerank with unknown model ({response.status_code})" diff --git a/app/utils/variables.py b/app/utils/variables.py index c4078b44..3a417be8 100644 --- a/app/utils/variables.py +++ b/app/utils/variables.py @@ -12,6 +12,7 @@ AUDIO_MODEL_TYPE = "automatic-speech-recognition" EMBEDDINGS_MODEL_TYPE = "text-embeddings-inference" LANGUAGE_MODEL_TYPE = "text-generation" +RERANK_MODEL_TYPE = "text-classification" CHUNKERS = ["LangchainRecursiveCharacterTextSplitter", "NoChunker"] DEFAULT_CHUNKER = "LangchainRecursiveCharacterTextSplitter"