-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat(rerank): add LLM based reranker * feat(rerank): add rerank types * feat(rerank): prep for choicer * feat: rebase and clean --------- Co-authored-by: camille <[email protected]> Co-authored-by: leoguillaume <[email protected]>
- Loading branch information
1 parent
e762c05
commit a23ac80
Showing
10 changed files
with
189 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from fastapi import APIRouter, Request, Security | ||
|
||
from app.helpers import LanguageModelReranker | ||
from app.schemas.rerank import RerankRequest, Reranks | ||
from app.schemas.security import User | ||
from app.utils.lifespan import clients, limiter | ||
from app.utils.security import check_api_key, check_rate_limit | ||
from app.utils.settings import settings | ||
from app.utils.variables import LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE | ||
|
||
from app.utils.exceptions import WrongModelTypeException | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.post("/rerank") | ||
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) | ||
async def rerank(request: Request, body: RerankRequest, user: User = Security(check_api_key)): | ||
""" | ||
Rerank a list of inputs with a language model or reranker model. | ||
""" | ||
|
||
if clients.models[body.model].type == LANGUAGE_MODEL_TYPE: | ||
reranker = LanguageModelReranker(model=clients.models[body.model]) | ||
data = reranker.create(prompt=body.prompt, input=body.input) | ||
elif clients.models[body.model].type == RERANK_MODEL_TYPE: | ||
data = clients.models[body.model].rerank.create(prompt=body.prompt, input=body.input, model=body.model) | ||
else: | ||
raise WrongModelTypeException() | ||
|
||
return Reranks(data=data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
from ._clientsmanager import ClientsManager | ||
from ._fileuploader import FileUploader | ||
from ._metricsmiddleware import MetricsMiddleware | ||
from ._internetmanager import InternetManager | ||
from ._languagemodelreranker import LanguageModelReranker | ||
from ._metricsmiddleware import MetricsMiddleware | ||
from ._searchmanager import SearchManager | ||
|
||
__all__ = ["ClientsManager", "FileUploader", "InternetManager", "MetricsMiddleware", "SearchManager"] | ||
__all__ = ["ClientsManager", "FileUploader", "LanguageModelReranker", "InternetManager", "MetricsMiddleware", "SearchManager"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import List | ||
import re | ||
from app.clients._modelclients import ModelClient | ||
from app.schemas.rerank import Rerank | ||
|
||
|
||
class LanguageModelReranker: | ||
PROMPT_LLM_BASED = """Voilà un texte : {text}\n | ||
En se basant uniquement sur ce texte, réponds 1 si ce texte peut donner des éléments de réponse à la question suivante ou 0 si aucun élément de réponse n'est présent dans le texte. Voila la question: {prompt} | ||
Le texte n'a pas besoin de répondre parfaitement à la question, juste d'apporter des éléments de réponses et/ou de parler du même thème. Réponds uniquement 0 ou 1.""" | ||
|
||
def __init__(self, model: ModelClient) -> None: | ||
self.model = model | ||
|
||
def create(self, prompt: str, input: list) -> List[Rerank]: | ||
data = list() | ||
for index, text in enumerate(input): | ||
content = self.PROMPT_LLM_BASED.format(prompt=prompt, text=text) | ||
|
||
response = self.model.chat.completions.create( | ||
messages=[{"role": "user", "content": content}], model=self.model.id, temperature=0.1, max_tokens=3, stream=False, n=1 | ||
) | ||
result = response.choices[0].message.content | ||
match = re.search(r"[0-1]", result) | ||
result = int(match.group(0)) if match else 0 | ||
data.append(Rerank(score=result, index=index)) | ||
|
||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import List, Literal | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class RerankRequest(BaseModel): | ||
prompt: str | ||
input: List[str] | ||
model: str | ||
|
||
|
||
class Rerank(BaseModel): | ||
score: float | ||
index: int | ||
|
||
|
||
class Reranks(BaseModel): | ||
object: Literal["list"] = "list" | ||
data: List[Rerank] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import logging | ||
|
||
import pytest | ||
|
||
from app.schemas.rerank import Reranks | ||
from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, RERANK_MODEL_TYPE | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def setup(args, session_user): | ||
response = session_user.get(f"{args["base_url"]}/models") | ||
assert response.status_code == 200, f"error: retrieve models ({response.status_code})" | ||
response_json = response.json() | ||
|
||
LANGUAGE_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == LANGUAGE_MODEL_TYPE][0] | ||
logging.info(f"test model ID: {LANGUAGE_MODEL_ID}") | ||
|
||
RERANK_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == RERANK_MODEL_TYPE][0] | ||
logging.info(f"test model ID: {RERANK_MODEL_ID}") | ||
|
||
EMBEDDINGS_MODEL_ID = [model["id"] for model in response_json["data"] if model["type"] == EMBEDDINGS_MODEL_TYPE][0] | ||
logging.info(f"test model ID: {EMBEDDINGS_MODEL_ID}") | ||
|
||
yield LANGUAGE_MODEL_ID, RERANK_MODEL_ID, EMBEDDINGS_MODEL_ID | ||
|
||
|
||
@pytest.mark.usefixtures("args", "session_user", "setup") | ||
class TestRerank: | ||
def test_rerank_with_language_model(self, args, session_user, setup): | ||
"""Test the POST /rerank with a language model.""" | ||
LANGUAGE_MODEL_ID, _, _ = setup | ||
params = {"model": LANGUAGE_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]} | ||
response = session_user.post(f"{args["base_url"]}/rerank", json=params) | ||
assert response.status_code == 200, f"error: rerank with language model ({response.status_code})" | ||
|
||
response_json = response.json() | ||
reranks = Reranks(**response_json) | ||
assert isinstance(reranks, Reranks) | ||
|
||
def test_rerank_with_rerank_model(self, args, session_user, setup): | ||
"""Test the POST /rerank with a rerank model.""" | ||
_, RERANK_MODEL_ID, _ = setup | ||
params = {"model": RERANK_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]} | ||
response = session_user.post(f"{args["base_url"]}/rerank", json=params) | ||
assert response.status_code == 200, f"error: rerank with rerank model ({response.status_code})" | ||
|
||
response_json = response.json() | ||
reranks = Reranks(**response_json) | ||
assert isinstance(reranks, Reranks) | ||
|
||
def test_rerank_with_wrong_model_type(self, args, session_user, setup): | ||
"""Test the POST /rerank with a wrong model type.""" | ||
_, _, EMBEDDINGS_MODEL_ID = setup | ||
params = {"model": EMBEDDINGS_MODEL_ID, "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]} | ||
response = session_user.post(f"{args["base_url"]}/rerank", json=params) | ||
assert response.status_code == 422, f"error: rerank with wrong model type ({response.status_code})" | ||
|
||
def test_rerank_with_unknown_model(self, args, session_user, setup): | ||
"""Test the POST /rerank with an unknown model.""" | ||
_, _, _ = setup | ||
params = {"model": "unknown", "prompt": "Sort these sentences by relevance.", "input": ["Sentence 1", "Sentence 2", "Sentence 3"]} | ||
response = session_user.post(f"{args["base_url"]}/rerank", json=params) | ||
assert response.status_code == 404, f"error: rerank with unknown model ({response.status_code})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters