Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: always return proof.file_path for proof uploaded by the user #132

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 55 additions & 7 deletions app/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import time
import uuid
from pathlib import Path
Expand Down Expand Up @@ -79,6 +80,11 @@ def get_db():
# Authentication helpers
# ------------------------------------------------------------------------------
oauth2_scheme = OAuth2PasswordBearerOrAuthCookie(tokenUrl="auth")
# Version of oauth2_scheme that does not raise an error if the token is
# invalid or missing
oauth2_scheme_no_error = OAuth2PasswordBearerOrAuthCookie(
tokenUrl="auth", auto_error=False
)


def create_token(user_id: str):
Expand All @@ -87,11 +93,20 @@ def create_token(user_id: str):

def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db)
):
) -> schemas.UserBase:
"""Get the current user if authenticated.

This function is used as a dependency in endpoints that require
authentication. It raises an HTTPException if the user is not
authenticated.

:param token: the authentication token
:param db: the database session
:raises HTTPException: if the user is not authenticated
:return: the current user
"""
if token and "__U" in token:
current_user: schemas.UserBase = crud.update_user_last_used_field(
db, token=token
)
current_user = crud.update_user_last_used_field(db, token=token)
if current_user:
return current_user
raise HTTPException(
Expand All @@ -101,6 +116,24 @@ def get_current_user(
)


def get_current_user_optional(
token: Annotated[str, Depends(oauth2_scheme_no_error)],
db: Session = Depends(get_db),
) -> schemas.UserBase | None:
"""Get the current user if authenticated, None otherwise.

This function is used as a dependency in endpoints that require
authentication, but where the user is optional.

:param token: the authentication token
:param db: the database session
:return: the current user if authenticated, None otherwise
"""
if token and "__U" in token:
return crud.update_user_last_used_field(db, token=token)
return None


# Routes
# ------------------------------------------------------------------------------

Expand Down Expand Up @@ -166,14 +199,26 @@ def authentication(
)


def price_transformer(prices: list[Price]) -> list[Price]:
def price_transformer(
prices: list[Price], current_user: schemas.UserBase | None = None
) -> list[Price]:
"""Transformer function used to remove the file_path of private proofs.

If current_user is None, the file_path is removed for all proofs that are
not public. Otherwise, the file_path is removed for all proofs that are not
public and do not belong to the current user.

:param prices: the list of prices to transform
:param current_user: the current user, if authenticated
:return: the transformed list of prices
"""
user_id = current_user.user_id if current_user else None
for price in prices:
if price.proof and price.proof.is_public is False:
if (
price.proof
and price.proof.is_public is False
and price.proof.owner != user_id
):
price.proof.file_path = None
return prices

Expand All @@ -182,9 +227,12 @@ def price_transformer(prices: list[Price]) -> list[Price]:
def get_price(
filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter),
db: Session = Depends(get_db),
current_user: schemas.UserBase | None = Depends(get_current_user_optional),
):
return paginate(
db, crud.get_prices_query(filters=filters), transformer=price_transformer
db,
crud.get_prices_query(filters=filters),
transformer=functools.partial(price_transformer, current_user=current_user),
)


Expand Down
4 changes: 2 additions & 2 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def create_user(db: Session, user: UserBase):
return db_user


def update_user_last_used_field(db: Session, token: str):
def update_user_last_used_field(db: Session, token: str) -> UserBase | None:
db_user = get_user_by_token(db, token=token)
if db_user:
db.query(User).filter(User.user_id == db_user.user_id).update(
Expand All @@ -57,7 +57,7 @@ def update_user_last_used_field(db: Session, token: str):
db.commit()
db.refresh(db_user)
return db_user
return False
return None


def delete_user(db: Session, user_id: UserBase):
Expand Down
Loading