diff --git a/app/api.py b/app/api.py index c6d5d48c..4d292231 100644 --- a/app/api.py +++ b/app/api.py @@ -1,3 +1,4 @@ +import functools import time import uuid from pathlib import Path @@ -79,6 +80,11 @@ def get_db(): # Authentication helpers # ------------------------------------------------------------------------------ oauth2_scheme = OAuth2PasswordBearerOrAuthCookie(tokenUrl="auth") +# Version of oauth2_scheme that does not raise an error if the token is +# invalid or missing +oauth2_scheme_no_error = OAuth2PasswordBearerOrAuthCookie( + tokenUrl="auth", auto_error=False +) def create_token(user_id: str): @@ -87,11 +93,20 @@ def create_token(user_id: str): def get_current_user( token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db) -): +) -> schemas.UserBase: + """Get the current user if authenticated. + + This function is used as a dependency in endpoints that require + authentication. It raises an HTTPException if the user is not + authenticated. + + :param token: the authentication token + :param db: the database session + :raises HTTPException: if the user is not authenticated + :return: the current user + """ if token and "__U" in token: - current_user: schemas.UserBase = crud.update_user_last_used_field( - db, token=token - ) + current_user = crud.update_user_last_used_field(db, token=token) if current_user: return current_user raise HTTPException( @@ -101,6 +116,24 @@ def get_current_user( ) +def get_current_user_optional( + token: Annotated[str, Depends(oauth2_scheme_no_error)], + db: Session = Depends(get_db), +) -> schemas.UserBase | None: + """Get the current user if authenticated, None otherwise. + + This function is used as a dependency in endpoints that require + authentication, but where the user is optional. + + :param token: the authentication token + :param db: the database session + :return: the current user if authenticated, None otherwise + """ + if token and "__U" in token: + return crud.update_user_last_used_field(db, token=token) + return None + + # Routes # ------------------------------------------------------------------------------ @@ -166,14 +199,26 @@ def authentication( ) -def price_transformer(prices: list[Price]) -> list[Price]: +def price_transformer( + prices: list[Price], current_user: schemas.UserBase | None = None +) -> list[Price]: """Transformer function used to remove the file_path of private proofs. + If current_user is None, the file_path is removed for all proofs that are + not public. Otherwise, the file_path is removed for all proofs that are not + public and do not belong to the current user. + :param prices: the list of prices to transform + :param current_user: the current user, if authenticated :return: the transformed list of prices """ + user_id = current_user.user_id if current_user else None for price in prices: - if price.proof and price.proof.is_public is False: + if ( + price.proof + and price.proof.is_public is False + and price.proof.owner != user_id + ): price.proof.file_path = None return prices @@ -182,9 +227,12 @@ def price_transformer(prices: list[Price]) -> list[Price]: def get_price( filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter), db: Session = Depends(get_db), + current_user: schemas.UserBase | None = Depends(get_current_user_optional), ): return paginate( - db, crud.get_prices_query(filters=filters), transformer=price_transformer + db, + crud.get_prices_query(filters=filters), + transformer=functools.partial(price_transformer, current_user=current_user), ) diff --git a/app/crud.py b/app/crud.py index 493d2c4e..cdf1a203 100644 --- a/app/crud.py +++ b/app/crud.py @@ -48,7 +48,7 @@ def create_user(db: Session, user: UserBase): return db_user -def update_user_last_used_field(db: Session, token: str): +def update_user_last_used_field(db: Session, token: str) -> UserBase | None: db_user = get_user_by_token(db, token=token) if db_user: db.query(User).filter(User.user_id == db_user.user_id).update( @@ -57,7 +57,7 @@ def update_user_last_used_field(db: Session, token: str): db.commit() db.refresh(db_user) return db_user - return False + return None def delete_user(db: Session, user_id: UserBase):