diff --git a/app/endpoints/files.py b/app/endpoints/files.py index 713f118a..7b6de908 100644 --- a/app/endpoints/files.py +++ b/app/endpoints/files.py @@ -5,6 +5,7 @@ from app.schemas.security import User from app.utils.lifespan import clients from app.utils.security import check_api_key +from app.utils.exceptions import FileSizeLimitExceededException router = APIRouter() @@ -20,10 +21,13 @@ async def upload_file(file: UploadFile = File(...), request: FilesRequest = Body For JSON, file structure like a list of documents: [{"text": "hello world", "title": "my document", "metadata": {"autor": "me"}}, ...]} or [{"text": "hello world", "title": "my document"}, ...]} Each document must have a "text" and "title" keys and "metadata" key (optional) with dict type value. - html: Hypertext Markup Language file. - - Max file size is 10MB. """ + file_size = len(file.file.read()) + if file_size > FileSizeLimitExceededException.MAX_CONTENT_SIZE: + raise FileSizeLimitExceededException() + file.file.seek(0) # reset file pointer to the beginning of the file + if request.chunker: chunker_args = request.chunker.args.model_dump() if request.chunker.args else ChunkerArgs().model_dump() chunker_name = request.chunker.name diff --git a/app/helpers/__init__.py b/app/helpers/__init__.py index 0b1cd6cd..9d998e07 100644 --- a/app/helpers/__init__.py +++ b/app/helpers/__init__.py @@ -1,9 +1,9 @@ from ._authenticationclient import AuthenticationClient from ._clientsmanager import ClientsManager -from ._contentsizelimitmiddleware import ContentSizeLimitMiddleware from ._fileuploader import FileUploader from ._internetclient import InternetClient from ._modelclients import ModelClients from .searchclients import SearchClient +from ._metricsmiddleware import MetricsMiddleware -__all__ = ["AuthenticationClient", "ClientsManager", "ContentSizeLimitMiddleware", "FileUploader", "InternetClient", "ModelClients", "SearchClient"] +__all__ = ["AuthenticationClient", "ClientsManager", "ContentSizeLimitMiddleware", "FileUploader", "InternetClient", "ModelClients", "SearchClient", "MetricsMiddleware"] diff --git a/app/helpers/_authenticationclient.py b/app/helpers/_authenticationclient.py index fe44f712..b3ef3baf 100644 --- a/app/helpers/_authenticationclient.py +++ b/app/helpers/_authenticationclient.py @@ -1,23 +1,21 @@ +import base64 import datetime as dt +import hashlib import json from typing import Optional import uuid +from typing import Any, Callable from grist_api import GristDocAPI from redis import Redis -from app.utils.variables import ROLE_LEVEL_0, ROLE_LEVEL_1, ROLE_LEVEL_2 +from app.schemas.security import Role, User class AuthenticationClient(GristDocAPI): CACHE_EXPIRATION = 3600 # 1h - ROLE_DICT = { - "user": ROLE_LEVEL_0, - "client": ROLE_LEVEL_1, - "admin": ROLE_LEVEL_2, - } - def __init__(self, cache: Redis, table_id: str, *args, **kwargs): + def __init__(self, cache: Redis, table_id: str, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.session_id = str(uuid.uuid4()) self.table_id = table_id @@ -35,14 +33,14 @@ def check_api_key(self, key: str) -> Optional[str]: """ keys = self._get_api_keys() if key in keys: - return keys[key] + return User(id=self._api_key_to_user_id(input=key), role=Role[keys[key]["role"]], name=keys[key]["name"]) - def cache(func): + def cache(func) -> Callable[..., Any]: """ Decorator to cache the result of a function in Redis. """ - def wrapper(self): + def wrapper(self) -> Any: key = f"auth-{self.session_id}" result = self.redis.get(key) if result: @@ -56,18 +54,40 @@ def wrapper(self): return wrapper @cache - def _get_api_keys(self): + def _get_api_keys(self) -> dict: """ Get all keys from a table in the Grist document. Returns: dict: dictionary of keys and their corresponding access level """ - records = self.fetch_table(self.table_id) + records = self.fetch_table(table_name=self.table_id) keys = dict() for record in records: if record.EXPIRATION > dt.datetime.now().timestamp(): - keys[record.KEY] = self.ROLE_DICT.get(record.ROLE, ROLE_LEVEL_0) + keys[record.KEY] = { + "id": self._api_key_to_user_id(input=record.KEY), + "role": Role.get(name=record.ROLE.upper(), default=Role.USER)._name_, + "name": record.USER, + } return keys + + @staticmethod + def _api_key_to_user_id(input: str) -> str: + """ + Generate a 16 length unique code from an input string using salted SHA-256 hashing. + + Args: + input_string (str): The input string to generate the code from. + + Returns: + tuple[str, bytes]: A tuple containing the generated code and the salt used. + """ + hash = hashlib.sha256((input).encode()).digest() + hash = base64.urlsafe_b64encode(hash).decode() + # remove special characters and limit length + hash = "".join(c for c in hash if c.isalnum())[:16].lower() + + return hash diff --git a/app/helpers/_contentsizelimitmiddleware.py b/app/helpers/_contentsizelimitmiddleware.py deleted file mode 100644 index 9ccf4acf..00000000 --- a/app/helpers/_contentsizelimitmiddleware.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Optional - -from app.utils.exceptions import FileSizeLimitExceededException - - -class ContentSizeLimitMiddleware: - """ - Content size limiting middleware for ASGI applications - - Args: - app (ASGI application): ASGI application - max_content_size (optional): the maximum content size allowed in bytes, default is MAX_CONTENT_SIZE - """ - - MAX_CONTENT_SIZE = 20 * 1024 * 1024 # 20MB - - def __init__(self, app, max_content_size: Optional[int] = None): - self.app = app - self.max_content_size = max_content_size or self.MAX_CONTENT_SIZE - - def receive_wrapper(self, receive): - received = 0 - - async def inner(): - nonlocal received - message = await receive() - if message["type"] != "http.request" or self.max_content_size is None: - return message - body_len = len(message.get("body", b"")) - received += body_len - if received > self.max_content_size: - raise FileSizeLimitExceededException() - - return message - - return inner - - async def __call__(self, scope, receive, send) -> None: - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - wrapper = self.receive_wrapper(receive) - await self.app(scope, wrapper, send) diff --git a/app/helpers/_metricsmiddleware.py b/app/helpers/_metricsmiddleware.py new file mode 100644 index 00000000..2ae3568a --- /dev/null +++ b/app/helpers/_metricsmiddleware.py @@ -0,0 +1,39 @@ +import json + +from prometheus_client import Counter + +from app.helpers._authenticationclient import AuthenticationClient +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + + +class MetricsMiddleware(BaseHTTPMiddleware): + # TODO: add audio endpoint (support for multipart/form-data) + MODELS_ENDPOINTS = ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"] + http_requests_by_user = Counter( + name="http_requests_by_user_and_endpoint", + documentation="Number of HTTP requests by user and endpoint", + labelnames=["user", "endpoint", "model"], + ) + + async def dispatch(self, request: Request, call_next) -> Response: + endpoint = request.url.path + content_type = request.headers.get("Content-Type", "") + + if endpoint.startswith("/v1"): + authorization = request.headers.get("Authorization") + model = None + if not content_type.startswith("multipart/form-data"): + body = await request.body() + body = body.decode(encoding="utf-8") + model = json.loads(body).get("model") if body else None + + user_id = AuthenticationClient._api_key_to_user_id(input=authorization.split(sep=" ")[1]) + + if authorization and authorization.startswith("Bearer "): + user_id = AuthenticationClient._api_key_to_user_id(input=authorization.split(sep=" ")[1]) + self.http_requests_by_user.labels(user=user_id, endpoint=endpoint[3:], model=model).inc() + + response = await call_next(request) + + return response diff --git a/app/helpers/searchclients/_elasticsearchclient.py b/app/helpers/searchclients/_elasticsearchclient.py index c2d33c32..3f85a4c7 100644 --- a/app/helpers/searchclients/_elasticsearchclient.py +++ b/app/helpers/searchclients/_elasticsearchclient.py @@ -9,13 +9,14 @@ from app.schemas.collections import Collection from app.schemas.documents import Document from app.schemas.chunks import Chunk +from app.schemas.security import Role from app.schemas.security import User from app.schemas.search import Filter, Search from app.utils.exceptions import ( DifferentCollectionsModelsException, - WrongCollectionTypeException, WrongModelTypeException, CollectionNotFoundException, + InsufficientRightsException, SearchMethodNotAvailableException, ) from app.utils.variables import ( @@ -23,7 +24,6 @@ HYBRID_SEARCH_TYPE, LEXICAL_SEARCH_TYPE, SEMANTIC_SEARCH_TYPE, - ROLE_LEVEL_2, PUBLIC_COLLECTION_TYPE, PRIVATE_COLLECTION_TYPE, ) @@ -71,8 +71,8 @@ def __init__(self, models: List[str] = None, hybrid_limit_factor: float = 1.5, * def upsert(self, chunks: List[Chunk], collection_id: str, user: Optional[User] = None) -> None: collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: - raise WrongCollectionTypeException() + if user.role != Role.ADMIN and collection.type == PUBLIC_COLLECTION_TYPE: + raise InsufficientRightsException() for i in range(0, len(chunks), self.BATCH_SIZE): batched_chunks = chunks[i : i + self.BATCH_SIZE] @@ -168,8 +168,8 @@ def create_collection( if self.models[collection_model].type != EMBEDDINGS_MODEL_TYPE: raise WrongModelTypeException() - if user.role != ROLE_LEVEL_2 and collection_type == PUBLIC_COLLECTION_TYPE: - raise WrongCollectionTypeException() + if user.role != Role.ADMIN and collection_type == PUBLIC_COLLECTION_TYPE: + raise InsufficientRightsException() settings = { "similarity": {"default": {"type": "BM25"}}, @@ -224,8 +224,8 @@ def delete_collection(self, collection_id: str, user: User) -> None: """ collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: - raise WrongCollectionTypeException() + if user.role != Role.ADMIN and collection.type == PUBLIC_COLLECTION_TYPE: + raise InsufficientRightsException() self.indices.delete(index=collection_id, ignore_unavailable=True) @@ -279,8 +279,8 @@ def delete_document(self, collection_id: str, document_id: str, user: Optional[U """ collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: - raise WrongCollectionTypeException() + if user.role != Role.ADMIN and collection.type == PUBLIC_COLLECTION_TYPE: + raise InsufficientRightsException() # delete chunks body = {"query": {"match": {"metadata.document_id": document_id}}} diff --git a/app/helpers/searchclients/_qdrantsearchclient.py b/app/helpers/searchclients/_qdrantsearchclient.py index bd6e434d..d83e07f0 100644 --- a/app/helpers/searchclients/_qdrantsearchclient.py +++ b/app/helpers/searchclients/_qdrantsearchclient.py @@ -20,20 +20,20 @@ from app.schemas.collections import Collection from app.schemas.documents import Document from app.schemas.search import Search +from app.schemas.security import Role from app.schemas.security import User from app.utils.exceptions import ( CollectionNotFoundException, DifferentCollectionsModelsException, SearchMethodNotAvailableException, - WrongCollectionTypeException, WrongModelTypeException, + InsufficientRightsException, ) from app.utils.variables import ( EMBEDDINGS_MODEL_TYPE, LEXICAL_SEARCH_TYPE, HYBRID_SEARCH_TYPE, PUBLIC_COLLECTION_TYPE, - ROLE_LEVEL_2, SEMANTIC_SEARCH_TYPE, ) @@ -59,8 +59,8 @@ def upsert(self, chunks: List[Chunk], collection_id: str, user: User) -> None: """ collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: - raise WrongCollectionTypeException() + if user.role != Role.ADMIN and collection.type == PUBLIC_COLLECTION_TYPE: + raise InsufficientRightsException() for i in range(0, len(chunks), self.BATCH_SIZE): batch = chunks[i : i + self.BATCH_SIZE] @@ -214,8 +214,8 @@ def create_collection( if self.models[collection_model].type != EMBEDDINGS_MODEL_TYPE: raise WrongModelTypeException() - if user.role != ROLE_LEVEL_2 and collection_type == PUBLIC_COLLECTION_TYPE: - raise WrongCollectionTypeException() + if user.role != Role.ADMIN and collection_type == PUBLIC_COLLECTION_TYPE: + raise InsufficientRightsException() # create metadata metadata = { @@ -241,8 +241,8 @@ def delete_collection(self, collection_id: str, user: User) -> None: """ collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: - raise WrongCollectionTypeException() + if user.role != Role.ADMIN and collection.type == PUBLIC_COLLECTION_TYPE: + raise InsufficientRightsException() super().delete_collection(collection_name=collection.id) super().delete(collection_name=self.METADATA_COLLECTION_ID, points_selector=PointIdsList(points=[collection.id])) @@ -294,8 +294,8 @@ def delete_document(self, collection_id: str, document_id: str, user: User): """ collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: - raise WrongCollectionTypeException() + if user.role != Role.ADMIN and collection.type == PUBLIC_COLLECTION_TYPE: + raise InsufficientRightsException() # delete chunks filter = Filter(must=[FieldCondition(key="metadata.document_id", match=MatchAny(any=[document_id]))]) diff --git a/app/main.py b/app/main.py index 3fb2158b..a9eade2a 100644 --- a/app/main.py +++ b/app/main.py @@ -1,15 +1,13 @@ -from fastapi import FastAPI, Response, Security - - +from fastapi import Depends, FastAPI, Response, Security +from prometheus_fastapi_instrumentator import Instrumentator from slowapi.middleware import SlowAPIASGIMiddleware from app.endpoints import audio, chat, chunks, collections, completions, documents, embeddings, files, models, search -from app.helpers import ContentSizeLimitMiddleware +from app.helpers._metricsmiddleware import MetricsMiddleware from app.schemas.security import User from app.utils.settings import settings from app.utils.lifespan import lifespan -from app.utils.security import check_api_key - +from app.utils.security import check_admin_api_key, check_api_key app = FastAPI( title=settings.app_name, @@ -22,9 +20,13 @@ redoc_url="/documentation", ) +# Prometheus metrics +# @TODO: env_var_name="ENABLE_METRICS" +app.instrumentator = Instrumentator().instrument(app=app) + # Middlewares -app.add_middleware(middleware_class=ContentSizeLimitMiddleware) app.add_middleware(middleware_class=SlowAPIASGIMiddleware) +app.add_middleware(middleware_class=MetricsMiddleware) # Monitoring @@ -37,6 +39,8 @@ def health(user: User = Security(dependency=check_api_key)) -> Response: return Response(status_code=200) +app.instrumentator.expose(app=app, should_gzip=True, tags=["Monitoring"], dependencies=[Depends(dependency=check_admin_api_key)]) + # Core app.include_router(router=models.router, tags=["Core"], prefix="/v1") app.include_router(router=chat.router, tags=["Core"], prefix="/v1") diff --git a/app/schemas/security.py b/app/schemas/security.py index cbf3068e..1c7c57bf 100644 --- a/app/schemas/security.py +++ b/app/schemas/security.py @@ -1,6 +1,23 @@ +from enum import Enum +from typing import Any, Optional + from pydantic import BaseModel +class Role(Enum): + USER = 0 + CLIENT = 1 + ADMIN = 2 + + @classmethod + def get(cls, name: str, default=None) -> Enum | Any: + try: + return cls.__getitem__(name=name) + except KeyError: + return default + + class User(BaseModel): id: str - role: int + name: Optional[str] = None + role: Role diff --git a/app/tests/conftest.py b/app/tests/conftest.py index bf7fc0dc..1ffbb679 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -4,7 +4,7 @@ import pytest import requests -from app.utils.security import encode_string +from app.helpers._authenticationclient import AuthenticationClient from app.utils.variables import PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE @@ -45,8 +45,8 @@ def session_admin(args): @pytest.fixture(scope="session") def cleanup_collections(args, session_user, session_admin): - USER = encode_string(input=args["api_key_user"]) - ADMIN = encode_string(input=args["api_key_admin"]) + USER = AuthenticationClient._api_key_to_user_id(input=args["api_key_user"]) + ADMIN = AuthenticationClient._api_key_to_user_id(input=args["api_key_admin"]) yield USER, ADMIN diff --git a/app/tests/test_collections.py b/app/tests/test_collections.py index 229d99ce..1117bac2 100644 --- a/app/tests/test_collections.py +++ b/app/tests/test_collections.py @@ -3,10 +3,9 @@ import pytest from app.schemas.collections import Collection, Collections -from app.utils.security import encode_string +from app.helpers._authenticationclient import AuthenticationClient from app.utils.variables import ( EMBEDDINGS_MODEL_TYPE, - INTERNET_COLLECTION_DISPLAY_ID, LANGUAGE_MODEL_TYPE, PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE, @@ -15,8 +14,8 @@ @pytest.fixture(scope="module") def setup(args, session_user): - USER = encode_string(input=args["api_key_user"]) - ADMIN = encode_string(input=args["api_key_admin"]) + USER = AuthenticationClient._api_key_to_user_id(input=args["api_key_user"]) + ADMIN = AuthenticationClient._api_key_to_user_id(input=args["api_key_admin"]) logging.info(f"test user ID: {USER}") logging.info(f"test admin ID: {ADMIN}") @@ -48,7 +47,7 @@ def test_create_public_collection_with_user(self, args, session_user, setup): params = {"name": PUBLIC_COLLECTION_NAME, "model": EMBEDDINGS_MODEL_ID, "type": PUBLIC_COLLECTION_TYPE} response = session_user.post(f"{args["base_url"]}/collections", json=params) - assert response.status_code == 422 + assert response.status_code == 403 def test_create_public_collection_with_admin(self, args, session_admin, setup): PUBLIC_COLLECTION_NAME, _, _, _, EMBEDDINGS_MODEL_ID, _ = setup @@ -117,7 +116,7 @@ def test_delete_public_collection_with_user(self, args, session_user, setup): response = session_user.get(f"{args["base_url"]}/collections") collection_id = [collection["id"] for collection in response.json()["data"] if collection["name"] == PUBLIC_COLLECTION_NAME][0] response = session_user.delete(f"{args["base_url"]}/collections/{collection_id}") - assert response.status_code == 422 + assert response.status_code == 403 def test_delete_public_collection_with_admin(self, args, session_admin, setup): PUBLIC_COLLECTION_NAME, _, _, _, _, _ = setup @@ -127,13 +126,6 @@ def test_delete_public_collection_with_admin(self, args, session_admin, setup): response = session_admin.delete(f"{args["base_url"]}/collections/{collection_id}") assert response.status_code == 204 - def test_create_internet_collection_with_user(self, args, session_user, setup): - _, _, _, _, EMBEDDINGS_MODEL_ID, _ = setup - - params = {"name": INTERNET_COLLECTION_DISPLAY_ID, "model": EMBEDDINGS_MODEL_ID, "type": PUBLIC_COLLECTION_TYPE} - response = session_user.post(f"{args["base_url"]}/collections", json=params) - assert response.status_code == 422 - def test_create_collection_with_empty_name(self, args, session_user, setup): _, _, _, _, EMBEDDINGS_MODEL_ID, _ = setup diff --git a/app/tests/test_files.py b/app/tests/test_files.py index 2d1d9a1c..d8b94da0 100644 --- a/app/tests/test_files.py +++ b/app/tests/test_files.py @@ -111,4 +111,4 @@ def test_upload_in_public_collection_with_user(self, args, session_user, setup): files = {"file": (os.path.basename(file_path), open(file_path, "rb"), "application/pdf")} data = {"request": '{"collection": "%s"}' % PUBLIC_COLLECTION_ID} response = session_user.post(f"{args["base_url"]}/files", data=data, files=files) - assert response.status_code == 422, f"error: upload file ({response.status_code} - {response.text})" + assert response.status_code == 403, f"error: upload file ({response.status_code} - {response.text})" diff --git a/app/utils/exceptions.py b/app/utils/exceptions.py index 86b098ab..618634bb 100644 --- a/app/utils/exceptions.py +++ b/app/utils/exceptions.py @@ -3,17 +3,17 @@ # 400 class ParsingFileFailedException(HTTPException): - def __init__(self, detail: str = "Parsing file failed."): + def __init__(self, detail: str = "Parsing file failed.") -> None: super().__init__(status_code=400, detail=detail) class NoChunksToUpsertException(HTTPException): - def __init__(self, detail: str = "No chunks to upsert."): + def __init__(self, detail: str = "No chunks to upsert.") -> None: super().__init__(status_code=400, detail=detail) class ModelNotAvailableException(HTTPException): - def __init__(self, detail: str = "Model not available."): + def __init__(self, detail: str = "Model not available.") -> None: super().__init__(status_code=400, detail=detail) @@ -24,63 +24,65 @@ def __init__(self, detail: str = "Method not available."): # 403 class InvalidAuthenticationSchemeException(HTTPException): - def __init__(self, detail: str = "Invalid authentication scheme."): + def __init__(self, detail: str = "Invalid authentication scheme.") -> None: super().__init__(status_code=403, detail=detail) class InvalidAPIKeyException(HTTPException): - def __init__(self, detail: str = "Invalid API key."): + def __init__(self, detail: str = "Invalid API key.") -> None: + super().__init__(status_code=403, detail=detail) + + +class InsufficientRightsException(HTTPException): + def __init__(self, detail: str = "Insufficient rights.") -> None: super().__init__(status_code=403, detail=detail) # 404 class CollectionNotFoundException(HTTPException): - def __init__(self, detail: str = "Collection not found."): + def __init__(self, detail: str = "Collection not found.") -> None: super().__init__(status_code=404, detail=detail) class ModelNotFoundException(HTTPException): - def __init__(self, detail: str = "Model not found."): + def __init__(self, detail: str = "Model not found.") -> None: super().__init__(status_code=404, detail=detail) # 413 class ContextLengthExceededException(HTTPException): - def __init__(self, detail: str = "Context length exceeded."): + def __init__(self, detail: str = "Context length exceeded.") -> None: super().__init__(status_code=413, detail=detail) class FileSizeLimitExceededException(HTTPException): - def __init__(self, detail: str = "File size limit exceeded."): + MAX_CONTENT_SIZE = 20 * 1024 * 1024 # 20MB + + def __init__(self, detail: str = f"File size limit exceeded (max: {MAX_CONTENT_SIZE} bytes).") -> None: super().__init__(status_code=413, detail=detail) # 422 class InvalidJSONFormatException(HTTPException): - def __init__(self, detail: str = "Invalid JSON file format."): + def __init__(self, detail: str = "Invalid JSON file format.") -> None: super().__init__(status_code=422, detail=detail) class WrongModelTypeException(HTTPException): - def __init__(self, detail: str = "Wrong model type."): + def __init__(self, detail: str = "Wrong model type.") -> None: super().__init__(status_code=422, detail=detail) class MaxTokensExceededException(HTTPException): - def __init__(self, detail: str = "Max tokens exceeded."): - super().__init__(status_code=422, detail=detail) - - -class WrongCollectionTypeException(HTTPException): - def __init__(self, detail: str = "Wrong collection type."): + def __init__(self, detail: str = "Max tokens exceeded.") -> None: super().__init__(status_code=422, detail=detail) class DifferentCollectionsModelsException(HTTPException): - def __init__(self, detail: str = "Different collections models."): + def __init__(self, detail: str = "Different collections models.") -> None: super().__init__(status_code=422, detail=detail) class UnsupportedFileTypeException(HTTPException): - def __init__(self, detail: str = "Unsupported file type."): + def __init__(self, detail: str = "Unsupported file type.") -> None: super().__init__(status_code=422, detail=detail) diff --git a/app/utils/security.py b/app/utils/security.py index f3b91c8d..93a189f3 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -1,5 +1,3 @@ -import base64 -import hashlib from typing import Annotated, Optional from fastapi import Depends, Request @@ -7,32 +5,30 @@ from app.schemas.security import User from app.utils.settings import settings -from app.utils.exceptions import InvalidAPIKeyException, InvalidAuthenticationSchemeException +from app.utils.exceptions import InvalidAPIKeyException, InvalidAuthenticationSchemeException, InsufficientRightsException from app.utils.lifespan import clients -from app.utils.variables import ROLE_LEVEL_0, ROLE_LEVEL_2 +from app.schemas.security import Role -def encode_string(input: str) -> str: - """ - Generate a 16 length unique code from an input string using salted SHA-256 hashing. - - Args: - input_string (str): The input string to generate the code from. +if settings.auth: - Returns: - tuple[str, bytes]: A tuple containing the generated code and the salt used. - """ - hash = hashlib.sha256((input).encode()).digest() - hash = base64.urlsafe_b64encode(hash).decode() - # remove special characters and limit length - hash = "".join(c for c in hash if c.isalnum())[:16].lower() + def check_admin_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> User: + """ + Check if the API key is valid and if the user has admin rights. - return hash + Args: + api_key (Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key")]): The API key to check. + Returns: + User: User object, corresponding to the encoded API key or "no-auth" if no authentication is set in the configuration file. + """ + user = check_api_key(api_key=api_key) + if user.role != Role.ADMIN: + raise InsufficientRightsException() -if settings.auth: + return user - def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> str: + def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> User: """ Check if the API key is valid. @@ -40,24 +36,25 @@ def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPB api_key (Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key")]): The API key to check. Returns: - str: User ID, corresponding to the encoded API key or "no-auth" if no authentication is set in the configuration file. + User: User object, corresponding to the encoded API key or "no-auth" if no authentication is set in the configuration file. """ if api_key.scheme != "Bearer": raise InvalidAuthenticationSchemeException() - role = clients.auth.check_api_key(api_key.credentials) - if role is None: + user = clients.auth.check_api_key(api_key.credentials) + if user is None: raise InvalidAPIKeyException() - user_id = encode_string(input=api_key.credentials) - - return User(id=user_id, role=role) + return user else: - def check_api_key(api_key: Optional[str] = None) -> str: - return User(id="no-auth", role=ROLE_LEVEL_2) + def check_admin_api_key(api_key: Optional[str] = None) -> User: + return User(id="no-auth", role=Role.ADMIN) + + def check_api_key(api_key: Optional[str] = None) -> User: + return User(id="no-auth", role=Role.ADMIN) def check_rate_limit(request: Request) -> Optional[str]: @@ -76,7 +73,7 @@ def check_rate_limit(request: Request) -> Optional[str]: api_key = HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) user = check_api_key(api_key=api_key) - if user.role > ROLE_LEVEL_0: + if user.role.value > Role.USER.value: return None else: return user.id diff --git a/app/utils/variables.py b/app/utils/variables.py index 46258cf4..eb5b039c 100644 --- a/app/utils/variables.py +++ b/app/utils/variables.py @@ -26,6 +26,7 @@ SEARCH_ELASTIC_TYPE = "elastic" SEARCH_QDRANT_TYPE = "qdrant" + SUPPORTED_LANGUAGES = { "afrikaans": "af", "albanian": "sq", diff --git a/pyproject.toml b/pyproject.toml index 33d14801..a9f8e92c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ app = [ "fastapi==0.111.0", "pydantic==2.10.2", "pydantic-settings==2.6.1", + "prometheus-fastapi-instrumentator==7.0.0", "pyyaml==6.0.1", "grist-api==0.1.0", "six==1.16.0",