Skip to content

Commit

Permalink
Fix token types (#14)
Browse files Browse the repository at this point in the history
* Enhance token verification to check for required token types in JWT handling

* Add token type to access and refresh token payloads
  • Loading branch information
Vianpyro authored Nov 22, 2024
1 parent 965bf98 commit 9b6bf10
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
11 changes: 8 additions & 3 deletions jwt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def generate_access_token(player_id: int) -> str:
"player_id": player_id,
"exp": datetime.now(timezone.utc) + ACCESS_TOKEN_EXPIRY, # Expiration
"iat": datetime.now(timezone.utc), # Issued at
"token_type": "access",
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

Expand All @@ -41,6 +42,7 @@ def generate_refresh_token(player_id: int) -> str:
"player_id": player_id,
"exp": datetime.now(timezone.utc) + REFRESH_TOKEN_EXPIRY,
"iat": datetime.now(timezone.utc),
"token_type": "refresh",
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

Expand All @@ -55,12 +57,15 @@ def extract_token_from_header() -> str:
return auth_header.split("Bearer ")[1]


def verify_token(token: str) -> dict:
def verify_token(token: str, required_type: str) -> dict:
"""
Verify and decode a JWT token.
"""
try:
return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
decoded = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
if decoded.get("token_type") != required_type:
raise jwt.InvalidTokenError("Invalid token type")
return decoded
except jwt.ExpiredSignatureError:
raise TokenError("Token has expired", 401)
except jwt.InvalidTokenError:
Expand All @@ -76,7 +81,7 @@ def token_required(f):
def decorated(*args, **kwargs):
try:
token = extract_token_from_header()
decoded = verify_token(token)
decoded = verify_token(token, required_type="access")
request.player_id = decoded["player_id"]
return f(*args, **kwargs)
except TokenError as e:
Expand Down
6 changes: 3 additions & 3 deletions routes/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

from db import get_db_connection
from jwt_helper import (
TokenError,
extract_token_from_header,
generate_access_token,
generate_refresh_token,
verify_token,
extract_token_from_header,
TokenError,
)

load_dotenv()
Expand Down Expand Up @@ -117,7 +117,7 @@ def login():
def refresh_token():
try:
token = extract_token_from_header()
decoded = verify_token(token)
decoded = verify_token(token, required_type="refresh")
player_id = decoded["player_id"]

new_access_token = generate_access_token(player_id)
Expand Down

0 comments on commit 9b6bf10

Please sign in to comment.