Skip to content

Commit

Permalink
feat: add model aliases (#122)
Browse files Browse the repository at this point in the history
Co-authored-by: leoguillaume <[email protected]>
  • Loading branch information
leoguillaume and leoguillaumegouv authored Dec 20, 2024
1 parent e209d26 commit fac466e
Show file tree
Hide file tree
Showing 24 changed files with 380 additions and 133 deletions.
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ Merci, avant chaque pull request, de vérifier le bon déploiement de votre API
1. Lancez l'API en local avec la commande suivante:

```bash
DEFAULT_RATE_LIMIT="30/minute" uvicorn app.main:app --port 8080 --log-level debug --reload
uvicorn app.main:app --port 8080 --log-level debug --reload
```

2. Exécutez les tests unitaires à la racine du projet

```bash
DEFAULT_RATE_LIMIT="30/minute" PYTHONPATH=. pytest --config-file=pyproject.toml --base-url http://localhost:8080/v1 --api-key-user API_KEY_USER --api-key-admin API_KEY_ADMIN --log-cli-level=INFO
PYTHONPATH=. pytest --config-file=pyproject.toml --base-url http://localhost:8080/v1 --api-key-user API_KEY_USER --api-key-admin API_KEY_ADMIN --log-cli-level=INFO
```

# Notebooks
Expand Down
51 changes: 46 additions & 5 deletions app/clients/_modelclients.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import requests

from app.schemas.embeddings import Embeddings
from app.schemas.chat import ChatCompletion
from app.schemas.models import Model, Models
from app.schemas.rerank import Rerank
from app.schemas.settings import Settings
Expand Down Expand Up @@ -69,15 +70,36 @@ def get_models_list(self, *args, **kwargs) -> Models:
owned_by=self.owned_by,
created=self.created,
max_context_length=self.max_context_length,
aliases=self.aliases,
type=self.type,
status=self.status,
)

return Models(data=[data])


def create_chat_completions(self, *args, **kwargs):
"""
Custom method to overwrite OpenAI's create method to raise HTTPException from model API.
"""
try:
url = f"{self.base_url}chat/completions"
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post(url=url, headers=headers, json=kwargs)
response.raise_for_status()
data = response.json()

return ChatCompletion(**data)

except Exception as e:
raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"])


# @TODO : useless ?
def create_embeddings(self, *args, **kwargs):
"""
Custom method to overwrite OpenAI's create method to raise HTTPException from model API.
"""
try:
url = f"{self.base_url}embeddings"
headers = {"Authorization": f"Bearer {self.api_key}"}
Expand All @@ -104,13 +126,17 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDI
# set attributes for unavailable models
self.id = ""
self.owned_by = ""
self.aliases = []
self.created = round(number=time.time())
self.max_context_length = None

# set real attributes if model is available
self.models.list = partial(get_models_list, self)
response = self.models.list()

if self.type == LANGUAGE_MODEL_TYPE:
self.chat.completions.create = partial(create_chat_completions, self)

if self.type == EMBEDDINGS_MODEL_TYPE:
response = self.embeddings.create(model=self.id, input="hello world")
self.vector_size = len(response.data[0].embedding)
Expand Down Expand Up @@ -145,25 +171,40 @@ class ModelClients(dict):
"""

def __init__(self, settings: Settings) -> None:
for model_config in settings.models:
model = ModelClient(base_url=model_config.url, api_key=model_config.key, type=model_config.type)
self.aliases = {alias: model_id for model_id, aliases in settings.models.aliases.items() for alias in aliases}

for model_settings in settings.clients.models:
model = ModelClient(
base_url=model_settings.url,
api_key=model_settings.key,
type=model_settings.type,
)
if model.status == "unavailable":
logger.error(msg=f"unavailable model API on {model_config.url}, skipping.")
logger.error(msg=f"unavailable model API on {model_settings.url}, skipping.")
continue
try:
logger.info(msg=f"Adding model API {model_config.url} to the client...")
logger.info(msg=f"Adding model API {model_settings.url} to the client...")
self.__setitem__(key=model.id, value=model)
logger.info(msg="done.")
except Exception as e:
logger.error(msg=e)

model.aliases = settings.models.aliases.get(model.id, [])

for alias in self.aliases.keys():
assert alias not in self.keys(), "Alias is already used by another model."

assert settings.internet.default_language_model in self.keys(), "Default internet language model not found."
assert settings.internet.default_embeddings_model in self.keys(), "Default internet embeddings model not found."

def __setitem__(self, key: str, value) -> None:
if any(key == k for k in self.keys()):
if key in self.keys():
raise ValueError(f"duplicated model ID {key}, skipping.")
else:
super().__setitem__(key, value)

def __getitem__(self, key: str) -> Any:
key = self.aliases.get(key, key)
try:
item = super().__getitem__(key)
assert item.status == "available", "Model not available."
Expand Down
5 changes: 2 additions & 3 deletions app/endpoints/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import httpx

from app.schemas.audio import AudioTranscription
from app.schemas.settings import AUDIO_MODEL_TYPE
from app.utils.exceptions import ModelNotFoundException
from app.utils.lifespan import clients, limiter
from app.utils.security import User, check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT
from app.utils.variables import DEFAULT_TIMEOUT, AUDIO_MODEL_TYPE

router = APIRouter()

Expand Down Expand Up @@ -135,7 +134,7 @@


@router.post("/audio/transcriptions")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def audio_transcriptions(
request: Request,
file: UploadFile = File(...),
Expand Down
13 changes: 9 additions & 4 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,28 @@
from app.schemas.search import Search
from app.schemas.security import User
from app.schemas.settings import Settings
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT
from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE

router = APIRouter()


@router.post(path="/chat/completions")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def chat_completions(
request: Request, body: ChatCompletionRequest, user: User = Security(dependency=check_api_key)
) -> Union[ChatCompletion, ChatCompletionChunk]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create for the API specification.
"""

client = clients.models[body.model]
if client.type != LANGUAGE_MODEL_TYPE:
raise WrongModelTypeException()
body.model = client.id # replace alias by model id
url = f"{client.base_url}chat/completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

Expand All @@ -42,8 +47,8 @@ def retrieval_augmentation_generation(
internet_manager=InternetManager(
model_clients=clients.models,
internet_client=clients.internet,
default_language_model_id=settings.internet.args.default_language_model,
default_embeddings_model_id=settings.internet.args.default_embeddings_model,
default_language_model_id=settings.internet.default_language_model,
default_embeddings_model_id=settings.internet.default_embeddings_model,
),
)
searches = search_manager.query(
Expand Down
2 changes: 1 addition & 1 deletion app/endpoints/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@router.get("/chunks/{collection}/{document}")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def get_chunks(
request: Request,
collection: UUID,
Expand Down
6 changes: 3 additions & 3 deletions app/endpoints/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@router.post("/collections")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def create_collection(request: Request, body: CollectionRequest, user: User = Security(check_api_key)) -> Response:
"""
Create a new collection.
Expand All @@ -35,7 +35,7 @@ async def create_collection(request: Request, body: CollectionRequest, user: Use


@router.get("/collections")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def get_collections(request: Request, user: User = Security(check_api_key)) -> Union[Collection, Collections]:
"""
Get list of collections.
Expand All @@ -54,7 +54,7 @@ async def get_collections(request: Request, user: User = Security(check_api_key)


@router.delete("/collections/{collection}")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def delete_collections(request: Request, collection: UUID, user: User = Security(check_api_key)) -> Response:
"""
Delete a collection.
Expand Down
15 changes: 10 additions & 5 deletions app/endpoints/completions.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from fastapi import APIRouter, Request, Security, HTTPException
import httpx
import json

from fastapi import APIRouter, HTTPException, Request, Security
import httpx

from app.schemas.completions import CompletionRequest, Completions
from app.schemas.security import User
from app.utils.settings import settings
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import DEFAULT_TIMEOUT
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE

router = APIRouter()


@router.post(path="/completions")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def completions(request: Request, body: CompletionRequest, user: User = Security(dependency=check_api_key)) -> Completions:
"""
Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create for the API specification.
"""
client = clients.models[body.model]
if client.type != LANGUAGE_MODEL_TYPE:
raise WrongModelTypeException()
body.model = client.id # replace alias by model id
url = f"{client.base_url}completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

Expand Down
4 changes: 2 additions & 2 deletions app/endpoints/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@router.get("/documents/{collection}")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def get_documents(
request: Request,
collection: UUID,
Expand All @@ -31,7 +31,7 @@ async def get_documents(


@router.delete("/documents/{collection}/{document}")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def delete_document(request: Request, collection: UUID, document: UUID, user: User = Security(check_api_key)) -> Response:
"""
Delete a document and relative collections.
Expand Down
9 changes: 2 additions & 7 deletions app/endpoints/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@router.post(path="/embeddings")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Security(dependency=check_api_key)) -> Embeddings:
"""
Embedding API similar to OpenAI's API.
Expand All @@ -24,19 +24,14 @@ async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Sec
client = clients.models[body.model]
if client.type != EMBEDDINGS_MODEL_TYPE:
raise WrongModelTypeException()

body.model = client.id # replace alias by model id
url = f"{client.base_url}embeddings"
headers = {"Authorization": f"Bearer {client.api_key}"}

try:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
# try:
response.raise_for_status()
# except httpx.HTTPStatusError as e:
# if "`inputs` must have less than" in e.response.text:
# raise ContextLengthExceededException()
# raise e
data = response.json()
return Embeddings(**data)
except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions app/endpoints/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

@router.get("/models/{model:path}")
@router.get("/models")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def models(request: Request, model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]:
"""
Model API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/models/list for the API specification.
"""
if model is not None:
client = clients.models[model]
response = [row for row in client.models.list().data if row.id == model][0]
model = clients.models[model]
response = [row for row in model.models.list().data if row.id == model.id][0]
else:
response = {"object": "list", "data": []}
for model_id, client in clients.models.items():
Expand Down
11 changes: 6 additions & 5 deletions app/endpoints/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@


@router.post("/rerank")
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def rerank(request: Request, body: RerankRequest, user: User = Security(check_api_key)):
"""
Rerank a list of inputs with a language model or reranker model.
"""
model = clients.models[body.model]

if clients.models[body.model].type == LANGUAGE_MODEL_TYPE:
reranker = LanguageModelReranker(model=clients.models[body.model])
if model.type == LANGUAGE_MODEL_TYPE:
reranker = LanguageModelReranker(model=model)
data = reranker.create(prompt=body.prompt, input=body.input)
elif clients.models[body.model].type == RERANK_MODEL_TYPE:
data = clients.models[body.model].rerank.create(prompt=body.prompt, input=body.input, model=body.model)
elif model.type == RERANK_MODEL_TYPE:
data = model.rerank.create(prompt=body.prompt, input=body.input, model=model.id)
else:
raise WrongModelTypeException()

Expand Down
6 changes: 3 additions & 3 deletions app/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@router.post(path="/search")
@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request))
async def search(request: Request, body: SearchRequest, user: User = Security(dependency=check_api_key)) -> Searches:
"""
Endpoint to search on the internet or with our search client.
Expand All @@ -26,8 +26,8 @@ async def search(request: Request, body: SearchRequest, user: User = Security(de
internet_manager=InternetManager(
model_clients=clients.models,
internet_client=clients.internet,
default_language_model_id=settings.internet.args.default_language_model,
default_embeddings_model_id=settings.internet.args.default_embeddings_model,
default_language_model_id=settings.internet.default_language_model,
default_embeddings_model_id=settings.internet.default_embeddings_model,
),
)

Expand Down
Loading

0 comments on commit fac466e

Please sign in to comment.