From 9601b8b5957c4e1830e60f570fbb4c246a3910e6 Mon Sep 17 00:00:00 2001 From: Vianpyro Date: Sun, 17 Nov 2024 10:48:29 -0500 Subject: [PATCH 1/6] Implement user authentication with registration and login endpoints --- .dockerignore | 1 - app.py | 2 +- requirements.txt | 1 + routes/__init__.py | 2 + routes/authentication.py | 110 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 routes/authentication.py diff --git a/.dockerignore b/.dockerignore index 8c4153c..9401b48 100644 --- a/.dockerignore +++ b/.dockerignore @@ -9,4 +9,3 @@ __pycache__ # Ignore local development files *.pyc .DS_Store -.env diff --git a/app.py b/app.py index a03a901..12e1327 100644 --- a/app.py +++ b/app.py @@ -36,7 +36,7 @@ def add_status(response): if response.is_json: original_data = response.get_json() new_response = { - "success": response.status_code == 200, + "success": response.status_code in range(200, 300), "data": original_data if original_data != [] else None, } response.set_data(jsonify(new_response).data) diff --git a/requirements.txt b/requirements.txt index 2556013..1a1bdb5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +argon2-cffi>=23.1.0 flask>=3.0.3 pymysql>=1.1.1 pytest>=8.3.3 diff --git a/routes/__init__.py b/routes/__init__.py index 713938d..25e7f97 100644 --- a/routes/__init__.py +++ b/routes/__init__.py @@ -1,3 +1,4 @@ +from .authentication import authentication_blueprint from .battle import battles_blueprint from .building import buildings_blueprint from .city import cities_blueprint @@ -8,6 +9,7 @@ def register_routes(app): + app.register_blueprint(authentication_blueprint, url_prefix="/auth") app.register_blueprint(battles_blueprint, url_prefix="/battles") app.register_blueprint(buildings_blueprint, url_prefix="/buildings") app.register_blueprint(cities_blueprint, url_prefix="/cities") diff --git a/routes/authentication.py b/routes/authentication.py new file mode 100644 index 0000000..d866eb1 --- /dev/null +++ b/routes/authentication.py @@ -0,0 +1,110 @@ +import os +from re import match + +from argon2 import PasswordHasher, exceptions +from dotenv import load_dotenv +from flask import Blueprint, jsonify, request +from pymysql import MySQLError + +from db import get_db_connection + +load_dotenv() + +authentication_blueprint = Blueprint("authentication", __name__) +ph = PasswordHasher() + + +def hash_password_with_salt_and_pepper(password: str) -> str: + salt = os.urandom(16) + pepper = os.getenv("PEPPER").encode("utf-8") + seasoned_password = password.encode("utf-8") + salt + pepper + return ph.hash(seasoned_password), salt + + +def validate_password(password): + """ + Validates a password based on the following criteria: + - At least 12 characters long. + - Contains at least one uppercase letter (A-Z). + - Contains at least one lowercase letter (a-z). + - Contains at least one digit (0-9). + - Contains at least one special character (any non-alphanumeric character). + """ + return bool( + match(r"^(?=.*[A-Z])(?=.*[a-z])(?=.*\d)(?=.*[^A-Za-z0-9]).{12,}$", password) + ) + + +@authentication_blueprint.route("/register", methods=["POST"]) +def register(): + data = request.get_json() + name = data.get("name") + email = data.get("email") + password = data.get("password") + + if not name or not email or not password: + return jsonify(message="Username, email, and password are required"), 400 + + if not validate_password(password): + return jsonify(message="Password does not meet security requirements"), 400 + + hashed_password, salt = hash_password_with_salt_and_pepper(password) + + db = get_db_connection() + with db.cursor() as cursor: + try: + cursor.callproc("register_player", (name, email, hashed_password, salt)) + db.commit() + except MySQLError as e: + if e.args[0] == 1644: + return jsonify(message="Email already in use"), 400 + else: + return jsonify(message="An error occurred during registration"), 500 + + db.close() + return jsonify(message="User created successfully"), 201 + + +@authentication_blueprint.route("/login", methods=["POST"]) +def login(): + data = request.get_json() + email = data.get("email") + password = data.get("password") + + if not email or not password: + return jsonify(message="Email and password are required"), 400 + + db = get_db_connection() + with db.cursor() as cursor: + cursor.execute( + "SELECT hashed_password, salt FROM player WHERE email = %s", (email,) + ) + player = cursor.fetchone() + + if not player: + return jsonify(message="Invalid credentials"), 401 + + stored_password = player["hashed_password"] + salt = player["salt"] + pepper = os.getenv("PEPPER").encode("utf-8") + seasoned_password = password.encode("utf-8") + salt + pepper + + try: + ph.verify(stored_password, seasoned_password) + return jsonify(message="Login successful"), 200 + except exceptions.VerifyMismatchError: + return jsonify(message="Invalid credentials"), 401 + + +@authentication_blueprint.route("/test", methods=["POST"]) +def get_player_by_email(): + data = request.get_json() + email = data.get("email") + + db = get_db_connection() + with db.cursor() as cursor: + cursor.callproc("login_player", (email,)) + player = cursor.fetchone() + db.close() + + return jsonify(player) From c8fee35bb7563d2b9d8df5541ff528e5e20ee530 Mon Sep 17 00:00:00 2001 From: Vianpyro Date: Sun, 17 Nov 2024 11:57:33 -0500 Subject: [PATCH 2/6] Remove unused test endpoint for retrieving player by email --- routes/authentication.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/routes/authentication.py b/routes/authentication.py index d866eb1..1ac0a60 100644 --- a/routes/authentication.py +++ b/routes/authentication.py @@ -94,17 +94,3 @@ def login(): return jsonify(message="Login successful"), 200 except exceptions.VerifyMismatchError: return jsonify(message="Invalid credentials"), 401 - - -@authentication_blueprint.route("/test", methods=["POST"]) -def get_player_by_email(): - data = request.get_json() - email = data.get("email") - - db = get_db_connection() - with db.cursor() as cursor: - cursor.callproc("login_player", (email,)) - player = cursor.fetchone() - db.close() - - return jsonify(player) From b44ebb066d5f747bf76ed47fc5f2881d33b9324f Mon Sep 17 00:00:00 2001 From: Vianpyro Date: Sun, 17 Nov 2024 12:39:20 -0500 Subject: [PATCH 3/6] Change hash_password_with_salt_and_pepper function to return a tuple of hashed password and salt --- routes/authentication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/routes/authentication.py b/routes/authentication.py index 1ac0a60..c82f865 100644 --- a/routes/authentication.py +++ b/routes/authentication.py @@ -14,7 +14,7 @@ ph = PasswordHasher() -def hash_password_with_salt_and_pepper(password: str) -> str: +def hash_password_with_salt_and_pepper(password: str) -> tuple[str, bytes]: salt = os.urandom(16) pepper = os.getenv("PEPPER").encode("utf-8") seasoned_password = password.encode("utf-8") + salt + pepper From 07e3305541d10253c52a40145769a1a57c386597 Mon Sep 17 00:00:00 2001 From: Vianpyro Date: Sun, 17 Nov 2024 12:50:53 -0500 Subject: [PATCH 4/6] Use a default pepper value for Linting purpose --- routes/authentication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/routes/authentication.py b/routes/authentication.py index c82f865..a899acd 100644 --- a/routes/authentication.py +++ b/routes/authentication.py @@ -16,7 +16,7 @@ def hash_password_with_salt_and_pepper(password: str) -> tuple[str, bytes]: salt = os.urandom(16) - pepper = os.getenv("PEPPER").encode("utf-8") + pepper = os.getenv("PEPPER", "SuperSecretPepper").encode("utf-8") seasoned_password = password.encode("utf-8") + salt + pepper return ph.hash(seasoned_password), salt From 2b65b1c8bc643f51bebc16673a2280135b4a1127 Mon Sep 17 00:00:00 2001 From: Vianney Veremme <10519369+Vianpyro@users.noreply.github.com> Date: Mon, 18 Nov 2024 08:15:13 -0500 Subject: [PATCH 5/6] Implement JWT authentication with access and refresh tokens (#8) * Implement JWT authentication with access and refresh tokens * Refactor battle, building, and city tests to assert 401 status code for unauthorized access * Remove unused mock patches from battle, building, and city tests to simplify code * Update super-linter workflow to trigger on pull requests to the main branch --- .github/workflows/super-linter.yml | 4 ++- app.py | 6 ++++ jwt_helper.py | 57 ++++++++++++++++++++++++++++++ requirements.txt | 1 + routes/authentication.py | 42 ++++++++++++++++++++-- routes/battle.py | 2 ++ routes/building.py | 2 ++ routes/city.py | 3 ++ routes/player.py | 2 ++ tests/test_battle.py | 19 ++-------- tests/test_building.py | 10 ++---- tests/test_city.py | 32 +++-------------- 12 files changed, 124 insertions(+), 56 deletions(-) create mode 100644 jwt_helper.py diff --git a/.github/workflows/super-linter.yml b/.github/workflows/super-linter.yml index fae11ad..5d095e9 100644 --- a/.github/workflows/super-linter.yml +++ b/.github/workflows/super-linter.yml @@ -3,7 +3,9 @@ name: Lint on: push: null - pull_request: null + pull_request: + branches: + - main permissions: {} diff --git a/app.py b/app.py index fa5fd75..f7ed7a0 100644 --- a/app.py +++ b/app.py @@ -45,5 +45,11 @@ def add_status(response): return response +@app.after_request +def add_common_headers(response): + response.headers["X-Content-Type-Options"] = "nosniff" + return response + + if __name__ == "__main__": app.run(host="0.0.0.0", port=5000) diff --git a/jwt_helper.py b/jwt_helper.py new file mode 100644 index 0000000..e33996d --- /dev/null +++ b/jwt_helper.py @@ -0,0 +1,57 @@ +import os +from datetime import datetime, timedelta, timezone +from functools import wraps + +import jwt +from flask import jsonify, request + +SECRET_KEY = os.getenv("SECRET_JWT_KEY", "SuperSecretKey") +ACCESS_TOKEN_EXPIRY = timedelta(hours=1) +REFRESH_TOKEN_EXPIRY = timedelta(days=30) + + +def generate_access_token(player_id: int) -> str: + """Generate a JWT token for a user.""" + payload = { + "player_id": player_id, + "exp": datetime.now(timezone.utc) + ACCESS_TOKEN_EXPIRY, # Expiration + "iat": datetime.now(timezone.utc), # Issued at + } + return jwt.encode(payload, SECRET_KEY, algorithm="HS256") + + +def generate_refresh_token(player_id: int) -> str: + """Generate a long-lived refresh token.""" + payload = { + "player_id": player_id, + "exp": datetime.now(timezone.utc) + REFRESH_TOKEN_EXPIRY, + "iat": datetime.now(timezone.utc), + } + return jwt.encode(payload, SECRET_KEY, algorithm="HS256") + + +def verify_token(token: str) -> dict | None: + """Verify a JWT token and return the payload.""" + try: + return jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) + except jwt.ExpiredSignatureError: + return None # Token expired + except jwt.InvalidTokenError: + return None # Invalid token + + +def token_required(f): + @wraps(f) + def decorated(*args, **kwargs): + token = request.headers.get("Authorization") + if not token: + return jsonify(message="Token is missing"), 401 + + decoded = verify_token(token) + if not decoded: + return jsonify(message="Token is invalid or expired"), 401 + + request.player_id = decoded["player_id"] # Attach user ID to the request + return f(*args, **kwargs) + + return decorated diff --git a/requirements.txt b/requirements.txt index 45fafdf..2b01573 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ argon2-cffi>=23.1.0 flask>=3.0.3 flask-cors>=5.0.0 +pyjwt>=2.10.0 pymysql>=1.1.1 pytest>=8.3.3 python-dotenv>=1.0.1 diff --git a/routes/authentication.py b/routes/authentication.py index a899acd..96f7bf7 100644 --- a/routes/authentication.py +++ b/routes/authentication.py @@ -1,12 +1,14 @@ import os from re import match +import jwt from argon2 import PasswordHasher, exceptions from dotenv import load_dotenv from flask import Blueprint, jsonify, request from pymysql import MySQLError from db import get_db_connection +from jwt_helper import generate_access_token, generate_refresh_token load_dotenv() @@ -77,13 +79,15 @@ def login(): db = get_db_connection() with db.cursor() as cursor: cursor.execute( - "SELECT hashed_password, salt FROM player WHERE email = %s", (email,) + "SELECT player_id, hashed_password, salt FROM player WHERE email = %s", + (email,), ) player = cursor.fetchone() if not player: return jsonify(message="Invalid credentials"), 401 + player_id = player["player_id"] stored_password = player["hashed_password"] salt = player["salt"] pepper = os.getenv("PEPPER").encode("utf-8") @@ -91,6 +95,40 @@ def login(): try: ph.verify(stored_password, seasoned_password) - return jsonify(message="Login successful"), 200 + access_token = generate_access_token(player_id) + refresh_token = generate_refresh_token(player_id) + return jsonify( + message="Login successful", + access_token=access_token, + refresh_token=refresh_token, + ) except exceptions.VerifyMismatchError: return jsonify(message="Invalid credentials"), 401 + + +@authentication_blueprint.route("/refresh", methods=["POST"]) +def refresh_token(): + auth_header = request.headers.get("Authorization") + + if not auth_header or not auth_header.startswith("Bearer "): + return ( + jsonify(message="Refresh token is required in the Authorization header"), + 400, + ) + + refresh_token = auth_header.split("Bearer ")[1] + + try: + decoded = jwt.decode( + refresh_token, + os.getenv("SECRET_JWT_KEY", "SuperSecretKey"), + algorithms=["HS256"], + ) + player_id = decoded["player_id"] + + new_access_token = generate_access_token(player_id) + return jsonify(access_token=new_access_token), 200 + except jwt.ExpiredSignatureError: + return jsonify(message="Refresh token has expired, please log in again"), 401 + except jwt.InvalidTokenError: + return jsonify(message="Invalid refresh token"), 401 diff --git a/routes/battle.py b/routes/battle.py index 3107301..32265cd 100644 --- a/routes/battle.py +++ b/routes/battle.py @@ -1,6 +1,7 @@ from flask import Blueprint, jsonify from db import get_db_connection +from jwt_helper import token_required battles_blueprint = Blueprint("battles", __name__) @@ -28,6 +29,7 @@ def get_battle_by_id(battle_id): @battles_blueprint.route("//units", methods=["GET"]) +@token_required def get_battle_units(battle_id): db = get_db_connection() with db.cursor() as cursor: diff --git a/routes/building.py b/routes/building.py index 02eadbe..ea7d11a 100644 --- a/routes/building.py +++ b/routes/building.py @@ -1,6 +1,7 @@ from flask import Blueprint, jsonify from db import get_db_connection +from jwt_helper import token_required buildings_blueprint = Blueprint("buildings", __name__) @@ -17,6 +18,7 @@ def get_all_buildings(): @buildings_blueprint.route("/", methods=["GET"]) +@token_required def get_building_by_id(building_id): db = get_db_connection() with db.cursor() as cursor: diff --git a/routes/city.py b/routes/city.py index 66483df..a80f6dc 100644 --- a/routes/city.py +++ b/routes/city.py @@ -1,6 +1,7 @@ from flask import Blueprint, jsonify from db import get_db_connection +from jwt_helper import token_required cities_blueprint = Blueprint("cities", __name__) @@ -28,6 +29,7 @@ def get_city_by_id(city_id): @cities_blueprint.route("//buildings", methods=["GET"]) +@token_required def get_city_buildings(city_id): db = get_db_connection() with db.cursor() as cursor: @@ -39,6 +41,7 @@ def get_city_buildings(city_id): @cities_blueprint.route("//units", methods=["GET"]) +@token_required def get_city_units(city_id): db = get_db_connection() with db.cursor() as cursor: diff --git a/routes/player.py b/routes/player.py index 061af8d..cd99ac4 100644 --- a/routes/player.py +++ b/routes/player.py @@ -1,6 +1,7 @@ from flask import Blueprint, jsonify, request from db import get_db_connection +from jwt_helper import token_required players_blueprint = Blueprint("players", __name__) @@ -72,6 +73,7 @@ def get_player_cities(player_id): @players_blueprint.route("//battles", methods=["GET"]) +@token_required def get_player_battles(player_id): db = get_db_connection() with db.cursor() as cursor: diff --git a/tests/test_battle.py b/tests/test_battle.py index 6ce8197..74d4f21 100644 --- a/tests/test_battle.py +++ b/tests/test_battle.py @@ -66,21 +66,6 @@ def test_get_battle_by_id(mock_get_db, client): # Test get_battle_units endpoint -@patch("routes.battle.get_db_connection") -def test_get_battle_units(mock_get_db, client): - mock_db_response( - mock_get_db, - [ - {"unit_id": 101, "name": "Swordsman", "count": 50, "side": 0}, - {"unit_id": 102, "name": "Archer", "count": 30, "side": 1}, - ], - ) - +def test_get_battle_units(client): response = client(battles_blueprint).get("/battles/1/units") - json_data = response.get_json(force=True) - - assert response.status_code == 200 - assert json_data == [ - {"unit_id": 101, "name": "Swordsman", "count": 50, "side": 0}, - {"unit_id": 102, "name": "Archer", "count": 30, "side": 1}, - ] + assert response.status_code == 401 diff --git a/tests/test_building.py b/tests/test_building.py index 74ef5a3..a2180b6 100644 --- a/tests/test_building.py +++ b/tests/test_building.py @@ -19,15 +19,9 @@ def test_get_all_buildings(mock_get_db, client): # Test get_building_by_id endpoint -@patch("routes.building.get_db_connection") -def test_get_building_by_id(mock_get_db, client): - mock_db_response(mock_get_db, {"id": 1, "name": "Barracks"}, fetchone=True) - +def test_get_building_by_id(client): response = client(buildings_blueprint).get("/buildings/1") - json_data = response.get_json(force=True) - - assert response.status_code == 200 - assert json_data == {"id": 1, "name": "Barracks"} + assert response.status_code == 401 # Test get_building_prerequisites endpoint diff --git a/tests/test_city.py b/tests/test_city.py index e3d9d33..7665e9f 100644 --- a/tests/test_city.py +++ b/tests/test_city.py @@ -31,36 +31,12 @@ def test_get_city_by_id(mock_get_db, client): # Test get_city_buildings endpoint -@patch("routes.city.get_db_connection") -def test_get_city_buildings(mock_get_db, client): - mock_db_response( - mock_get_db, - [{"building_id": 1, "name": "Barracks"}, {"building_id": 2, "name": "Academy"}], - ) - +def test_get_city_buildings(client): response = client(cities_blueprint).get("/cities/1/buildings") - json_data = response.get_json(force=True) - - assert response.status_code == 200 - assert json_data == [ - {"building_id": 1, "name": "Barracks"}, - {"building_id": 2, "name": "Academy"}, - ] + assert response.status_code == 401 # Test get_city_units endpoint -@patch("routes.city.get_db_connection") -def test_get_city_units(mock_get_db, client): - mock_db_response( - mock_get_db, - [{"unit_id": 101, "name": "Infantry"}, {"unit_id": 102, "name": "Cavalry"}], - ) - +def test_get_city_units(client): response = client(cities_blueprint).get("/cities/1/units") - json_data = response.get_json(force=True) - - assert response.status_code == 200 - assert json_data == [ - {"unit_id": 101, "name": "Infantry"}, - {"unit_id": 102, "name": "Cavalry"}, - ] + assert response.status_code == 401 From 5ffea9418920cfd41a24551e662916763b2941e7 Mon Sep 17 00:00:00 2001 From: Vianpyro Date: Mon, 18 Nov 2024 20:24:54 -0500 Subject: [PATCH 6/6] Enhance registration error handling for duplicate player names and emails --- routes/authentication.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/routes/authentication.py b/routes/authentication.py index 96f7bf7..9abc650 100644 --- a/routes/authentication.py +++ b/routes/authentication.py @@ -58,8 +58,11 @@ def register(): cursor.callproc("register_player", (name, email, hashed_password, salt)) db.commit() except MySQLError as e: - if e.args[0] == 1644: - return jsonify(message="Email already in use"), 400 + # Check for specific error messages in the SQL error + if "Player name already exists" in str(e): + return jsonify(message="Player name already exists"), 400 + elif "Email already exists" in str(e): + return jsonify(message="Email already exists"), 400 else: return jsonify(message="An error occurred during registration"), 500