Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement JWT authentication with access and refresh tokens (#8)
Browse files Browse the repository at this point in the history
* Implement JWT authentication with access and refresh tokens

* Refactor battle, building, and city tests to assert 401 status code for unauthorized access

* Remove unused mock patches from battle, building, and city tests to simplify code

* Update super-linter workflow to trigger on pull requests to the main branch
Vianpyro authored Nov 18, 2024
1 parent 07e3305 commit 2b65b1c
Showing 12 changed files with 124 additions and 56 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/super-linter.yml
Original file line number Diff line number Diff line change
@@ -3,7 +3,9 @@ name: Lint

on:
push: null
pull_request: null
pull_request:
branches:
- main

permissions: {}

6 changes: 6 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -45,5 +45,11 @@ def add_status(response):
return response


@app.after_request
def add_common_headers(response):
response.headers["X-Content-Type-Options"] = "nosniff"
return response


if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000)
57 changes: 57 additions & 0 deletions jwt_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import jsonify, request

SECRET_KEY = os.getenv("SECRET_JWT_KEY", "SuperSecretKey")
ACCESS_TOKEN_EXPIRY = timedelta(hours=1)
REFRESH_TOKEN_EXPIRY = timedelta(days=30)


def generate_access_token(player_id: int) -> str:
"""Generate a JWT token for a user."""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + ACCESS_TOKEN_EXPIRY, # Expiration
"iat": datetime.now(timezone.utc), # Issued at
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")


def generate_refresh_token(player_id: int) -> str:
"""Generate a long-lived refresh token."""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + REFRESH_TOKEN_EXPIRY,
"iat": datetime.now(timezone.utc),
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")


def verify_token(token: str) -> dict | None:
"""Verify a JWT token and return the payload."""
try:
return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
return None # Token expired
except jwt.InvalidTokenError:
return None # Invalid token


def token_required(f):
@wraps(f)
def decorated(*args, **kwargs):
token = request.headers.get("Authorization")
if not token:
return jsonify(message="Token is missing"), 401

decoded = verify_token(token)
if not decoded:
return jsonify(message="Token is invalid or expired"), 401

request.player_id = decoded["player_id"] # Attach user ID to the request
return f(*args, **kwargs)

return decorated
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
argon2-cffi>=23.1.0
flask>=3.0.3
flask-cors>=5.0.0
pyjwt>=2.10.0
pymysql>=1.1.1
pytest>=8.3.3
python-dotenv>=1.0.1
42 changes: 40 additions & 2 deletions routes/authentication.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
from re import match

import jwt
from argon2 import PasswordHasher, exceptions
from dotenv import load_dotenv
from flask import Blueprint, jsonify, request
from pymysql import MySQLError

from db import get_db_connection
from jwt_helper import generate_access_token, generate_refresh_token

load_dotenv()

@@ -77,20 +79,56 @@ def login():
db = get_db_connection()
with db.cursor() as cursor:
cursor.execute(
"SELECT hashed_password, salt FROM player WHERE email = %s", (email,)
"SELECT player_id, hashed_password, salt FROM player WHERE email = %s",
(email,),
)
player = cursor.fetchone()

if not player:
return jsonify(message="Invalid credentials"), 401

player_id = player["player_id"]
stored_password = player["hashed_password"]
salt = player["salt"]
pepper = os.getenv("PEPPER").encode("utf-8")
seasoned_password = password.encode("utf-8") + salt + pepper

try:
ph.verify(stored_password, seasoned_password)
return jsonify(message="Login successful"), 200
access_token = generate_access_token(player_id)
refresh_token = generate_refresh_token(player_id)
return jsonify(
message="Login successful",
access_token=access_token,
refresh_token=refresh_token,
)
except exceptions.VerifyMismatchError:
return jsonify(message="Invalid credentials"), 401


@authentication_blueprint.route("/refresh", methods=["POST"])
def refresh_token():
auth_header = request.headers.get("Authorization")

if not auth_header or not auth_header.startswith("Bearer "):
return (
jsonify(message="Refresh token is required in the Authorization header"),
400,
)

refresh_token = auth_header.split("Bearer ")[1]

try:
decoded = jwt.decode(
refresh_token,
os.getenv("SECRET_JWT_KEY", "SuperSecretKey"),
algorithms=["HS256"],
)
player_id = decoded["player_id"]

new_access_token = generate_access_token(player_id)
return jsonify(access_token=new_access_token), 200
except jwt.ExpiredSignatureError:
return jsonify(message="Refresh token has expired, please log in again"), 401
except jwt.InvalidTokenError:
return jsonify(message="Invalid refresh token"), 401
2 changes: 2 additions & 0 deletions routes/battle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

battles_blueprint = Blueprint("battles", __name__)

@@ -28,6 +29,7 @@ def get_battle_by_id(battle_id):


@battles_blueprint.route("/<int:battle_id>/units", methods=["GET"])
@token_required
def get_battle_units(battle_id):
db = get_db_connection()
with db.cursor() as cursor:
2 changes: 2 additions & 0 deletions routes/building.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

buildings_blueprint = Blueprint("buildings", __name__)

@@ -17,6 +18,7 @@ def get_all_buildings():


@buildings_blueprint.route("/<int:building_id>", methods=["GET"])
@token_required
def get_building_by_id(building_id):
db = get_db_connection()
with db.cursor() as cursor:
3 changes: 3 additions & 0 deletions routes/city.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

cities_blueprint = Blueprint("cities", __name__)

@@ -28,6 +29,7 @@ def get_city_by_id(city_id):


@cities_blueprint.route("/<int:city_id>/buildings", methods=["GET"])
@token_required
def get_city_buildings(city_id):
db = get_db_connection()
with db.cursor() as cursor:
@@ -39,6 +41,7 @@ def get_city_buildings(city_id):


@cities_blueprint.route("/<int:city_id>/units", methods=["GET"])
@token_required
def get_city_units(city_id):
db = get_db_connection()
with db.cursor() as cursor:
2 changes: 2 additions & 0 deletions routes/player.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify, request

from db import get_db_connection
from jwt_helper import token_required

players_blueprint = Blueprint("players", __name__)

@@ -72,6 +73,7 @@ def get_player_cities(player_id):


@players_blueprint.route("/<int:player_id>/battles", methods=["GET"])
@token_required
def get_player_battles(player_id):
db = get_db_connection()
with db.cursor() as cursor:
19 changes: 2 additions & 17 deletions tests/test_battle.py
Original file line number Diff line number Diff line change
@@ -66,21 +66,6 @@ def test_get_battle_by_id(mock_get_db, client):


# Test get_battle_units endpoint
@patch("routes.battle.get_db_connection")
def test_get_battle_units(mock_get_db, client):
mock_db_response(
mock_get_db,
[
{"unit_id": 101, "name": "Swordsman", "count": 50, "side": 0},
{"unit_id": 102, "name": "Archer", "count": 30, "side": 1},
],
)

def test_get_battle_units(client):
response = client(battles_blueprint).get("/battles/1/units")
json_data = response.get_json(force=True)

assert response.status_code == 200
assert json_data == [
{"unit_id": 101, "name": "Swordsman", "count": 50, "side": 0},
{"unit_id": 102, "name": "Archer", "count": 30, "side": 1},
]
assert response.status_code == 401
10 changes: 2 additions & 8 deletions tests/test_building.py
Original file line number Diff line number Diff line change
@@ -19,15 +19,9 @@ def test_get_all_buildings(mock_get_db, client):


# Test get_building_by_id endpoint
@patch("routes.building.get_db_connection")
def test_get_building_by_id(mock_get_db, client):
mock_db_response(mock_get_db, {"id": 1, "name": "Barracks"}, fetchone=True)

def test_get_building_by_id(client):
response = client(buildings_blueprint).get("/buildings/1")
json_data = response.get_json(force=True)

assert response.status_code == 200
assert json_data == {"id": 1, "name": "Barracks"}
assert response.status_code == 401


# Test get_building_prerequisites endpoint
32 changes: 4 additions & 28 deletions tests/test_city.py
Original file line number Diff line number Diff line change
@@ -31,36 +31,12 @@ def test_get_city_by_id(mock_get_db, client):


# Test get_city_buildings endpoint
@patch("routes.city.get_db_connection")
def test_get_city_buildings(mock_get_db, client):
mock_db_response(
mock_get_db,
[{"building_id": 1, "name": "Barracks"}, {"building_id": 2, "name": "Academy"}],
)

def test_get_city_buildings(client):
response = client(cities_blueprint).get("/cities/1/buildings")
json_data = response.get_json(force=True)

assert response.status_code == 200
assert json_data == [
{"building_id": 1, "name": "Barracks"},
{"building_id": 2, "name": "Academy"},
]
assert response.status_code == 401


# Test get_city_units endpoint
@patch("routes.city.get_db_connection")
def test_get_city_units(mock_get_db, client):
mock_db_response(
mock_get_db,
[{"unit_id": 101, "name": "Infantry"}, {"unit_id": 102, "name": "Cavalry"}],
)

def test_get_city_units(client):
response = client(cities_blueprint).get("/cities/1/units")
json_data = response.get_json(force=True)

assert response.status_code == 200
assert json_data == [
{"unit_id": 101, "name": "Infantry"},
{"unit_id": 102, "name": "Cavalry"},
]
assert response.status_code == 401

0 comments on commit 2b65b1c

Please sign in to comment.