diff --git a/jwt_helper.py b/jwt_helper.py index 8a8cf01..dabfb93 100644 --- a/jwt_helper.py +++ b/jwt_helper.py @@ -55,12 +55,15 @@ def extract_token_from_header() -> str: return auth_header.split("Bearer ")[1] -def verify_token(token: str) -> dict: +def verify_token(token: str, required_type: str) -> dict: """ Verify and decode a JWT token. """ try: - return jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) + decoded = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) + if decoded.get("token_type") != required_type: + raise jwt.InvalidTokenError("Invalid token type") + return decoded except jwt.ExpiredSignatureError: raise TokenError("Token has expired", 401) except jwt.InvalidTokenError: @@ -76,7 +79,7 @@ def token_required(f): def decorated(*args, **kwargs): try: token = extract_token_from_header() - decoded = verify_token(token) + decoded = verify_token(token, required_type="access") request.player_id = decoded["player_id"] return f(*args, **kwargs) except TokenError as e: diff --git a/routes/authentication.py b/routes/authentication.py index 247255a..5befde2 100644 --- a/routes/authentication.py +++ b/routes/authentication.py @@ -8,11 +8,11 @@ from db import get_db_connection from jwt_helper import ( + TokenError, + extract_token_from_header, generate_access_token, generate_refresh_token, verify_token, - extract_token_from_header, - TokenError, ) load_dotenv() @@ -117,7 +117,7 @@ def login(): def refresh_token(): try: token = extract_token_from_header() - decoded = verify_token(token) + decoded = verify_token(token, required_type="refresh") player_id = decoded["player_id"] new_access_token = generate_access_token(player_id)