Skip to content

Commit

Permalink
Implement user authentication with registration and login endpoints (#9)
Browse files Browse the repository at this point in the history
* Implement user authentication with registration and login endpoints

* Remove unused test endpoint for retrieving player by email

* Change hash_password_with_salt_and_pepper function to return a tuple of hashed password and salt

* Use a default pepper value for Linting purpose

* Implement JWT authentication with access and refresh tokens (#8)

* Implement JWT authentication with access and refresh tokens

* Refactor battle, building, and city tests to assert 401 status code for unauthorized access

* Remove unused mock patches from battle, building, and city tests to simplify code

* Update super-linter workflow to trigger on pull requests to the main branch

* Enhance registration error handling for duplicate player names and emails
  • Loading branch information
Vianpyro authored Nov 19, 2024
1 parent c7abe76 commit 59b8f7a
Show file tree
Hide file tree
Showing 14 changed files with 225 additions and 56 deletions.
1 change: 0 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ __pycache__
# Ignore local development files
*.pyc
.DS_Store
.env
4 changes: 3 additions & 1 deletion .github/workflows/super-linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ name: Lint

on:
push: null
pull_request: null
pull_request:
branches:
- main

permissions: {}

Expand Down
8 changes: 7 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@ def add_status(response):
if response.is_json:
original_data = response.get_json()
new_response = {
"success": response.status_code == 200,
"success": response.status_code in range(200, 300),
"data": original_data if original_data != [] else None,
}
response.set_data(jsonify(new_response).data)
return response


@app.after_request
def add_common_headers(response):
response.headers["X-Content-Type-Options"] = "nosniff"
return response


if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000)
57 changes: 57 additions & 0 deletions jwt_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import jsonify, request

SECRET_KEY = os.getenv("SECRET_JWT_KEY", "SuperSecretKey")
ACCESS_TOKEN_EXPIRY = timedelta(hours=1)
REFRESH_TOKEN_EXPIRY = timedelta(days=30)


def generate_access_token(player_id: int) -> str:
"""Generate a JWT token for a user."""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + ACCESS_TOKEN_EXPIRY, # Expiration
"iat": datetime.now(timezone.utc), # Issued at
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")


def generate_refresh_token(player_id: int) -> str:
"""Generate a long-lived refresh token."""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + REFRESH_TOKEN_EXPIRY,
"iat": datetime.now(timezone.utc),
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")


def verify_token(token: str) -> dict | None:
"""Verify a JWT token and return the payload."""
try:
return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
return None # Token expired
except jwt.InvalidTokenError:
return None # Invalid token


def token_required(f):
@wraps(f)
def decorated(*args, **kwargs):
token = request.headers.get("Authorization")
if not token:
return jsonify(message="Token is missing"), 401

decoded = verify_token(token)
if not decoded:
return jsonify(message="Token is invalid or expired"), 401

request.player_id = decoded["player_id"] # Attach user ID to the request
return f(*args, **kwargs)

return decorated
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
argon2-cffi>=23.1.0
flask>=3.0.3
flask-cors>=5.0.0
pyjwt>=2.10.0
pymysql>=1.1.1
pytest>=8.3.3
python-dotenv>=1.0.1
2 changes: 2 additions & 0 deletions routes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .authentication import authentication_blueprint
from .battle import battles_blueprint
from .building import buildings_blueprint
from .city import cities_blueprint
Expand All @@ -8,6 +9,7 @@


def register_routes(app):
app.register_blueprint(authentication_blueprint, url_prefix="/auth")
app.register_blueprint(battles_blueprint, url_prefix="/battles")
app.register_blueprint(buildings_blueprint, url_prefix="/buildings")
app.register_blueprint(cities_blueprint, url_prefix="/cities")
Expand Down
137 changes: 137 additions & 0 deletions routes/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
from re import match

import jwt
from argon2 import PasswordHasher, exceptions
from dotenv import load_dotenv
from flask import Blueprint, jsonify, request
from pymysql import MySQLError

from db import get_db_connection
from jwt_helper import generate_access_token, generate_refresh_token

load_dotenv()

authentication_blueprint = Blueprint("authentication", __name__)
ph = PasswordHasher()


def hash_password_with_salt_and_pepper(password: str) -> tuple[str, bytes]:
salt = os.urandom(16)
pepper = os.getenv("PEPPER", "SuperSecretPepper").encode("utf-8")
seasoned_password = password.encode("utf-8") + salt + pepper
return ph.hash(seasoned_password), salt


def validate_password(password):
"""
Validates a password based on the following criteria:
- At least 12 characters long.
- Contains at least one uppercase letter (A-Z).
- Contains at least one lowercase letter (a-z).
- Contains at least one digit (0-9).
- Contains at least one special character (any non-alphanumeric character).
"""
return bool(
match(r"^(?=.*[A-Z])(?=.*[a-z])(?=.*\d)(?=.*[^A-Za-z0-9]).{12,}$", password)
)


@authentication_blueprint.route("/register", methods=["POST"])
def register():
data = request.get_json()
name = data.get("name")
email = data.get("email")
password = data.get("password")

if not name or not email or not password:
return jsonify(message="Username, email, and password are required"), 400

if not validate_password(password):
return jsonify(message="Password does not meet security requirements"), 400

hashed_password, salt = hash_password_with_salt_and_pepper(password)

db = get_db_connection()
with db.cursor() as cursor:
try:
cursor.callproc("register_player", (name, email, hashed_password, salt))
db.commit()
except MySQLError as e:
# Check for specific error messages in the SQL error
if "Player name already exists" in str(e):
return jsonify(message="Player name already exists"), 400
elif "Email already exists" in str(e):
return jsonify(message="Email already exists"), 400
else:
return jsonify(message="An error occurred during registration"), 500

db.close()
return jsonify(message="User created successfully"), 201


@authentication_blueprint.route("/login", methods=["POST"])
def login():
data = request.get_json()
email = data.get("email")
password = data.get("password")

if not email or not password:
return jsonify(message="Email and password are required"), 400

db = get_db_connection()
with db.cursor() as cursor:
cursor.execute(
"SELECT player_id, hashed_password, salt FROM player WHERE email = %s",
(email,),
)
player = cursor.fetchone()

if not player:
return jsonify(message="Invalid credentials"), 401

player_id = player["player_id"]
stored_password = player["hashed_password"]
salt = player["salt"]
pepper = os.getenv("PEPPER").encode("utf-8")
seasoned_password = password.encode("utf-8") + salt + pepper

try:
ph.verify(stored_password, seasoned_password)
access_token = generate_access_token(player_id)
refresh_token = generate_refresh_token(player_id)
return jsonify(
message="Login successful",
access_token=access_token,
refresh_token=refresh_token,
)
except exceptions.VerifyMismatchError:
return jsonify(message="Invalid credentials"), 401


@authentication_blueprint.route("/refresh", methods=["POST"])
def refresh_token():
auth_header = request.headers.get("Authorization")

if not auth_header or not auth_header.startswith("Bearer "):
return (
jsonify(message="Refresh token is required in the Authorization header"),
400,
)

refresh_token = auth_header.split("Bearer ")[1]

try:
decoded = jwt.decode(
refresh_token,
os.getenv("SECRET_JWT_KEY", "SuperSecretKey"),
algorithms=["HS256"],
)
player_id = decoded["player_id"]

new_access_token = generate_access_token(player_id)
return jsonify(access_token=new_access_token), 200
except jwt.ExpiredSignatureError:
return jsonify(message="Refresh token has expired, please log in again"), 401
except jwt.InvalidTokenError:
return jsonify(message="Invalid refresh token"), 401
2 changes: 2 additions & 0 deletions routes/battle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

battles_blueprint = Blueprint("battles", __name__)

Expand Down Expand Up @@ -28,6 +29,7 @@ def get_battle_by_id(battle_id):


@battles_blueprint.route("/<int:battle_id>/units", methods=["GET"])
@token_required
def get_battle_units(battle_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand Down
2 changes: 2 additions & 0 deletions routes/building.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

buildings_blueprint = Blueprint("buildings", __name__)

Expand All @@ -17,6 +18,7 @@ def get_all_buildings():


@buildings_blueprint.route("/<int:building_id>", methods=["GET"])
@token_required
def get_building_by_id(building_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand Down
3 changes: 3 additions & 0 deletions routes/city.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

cities_blueprint = Blueprint("cities", __name__)

Expand Down Expand Up @@ -28,6 +29,7 @@ def get_city_by_id(city_id):


@cities_blueprint.route("/<int:city_id>/buildings", methods=["GET"])
@token_required
def get_city_buildings(city_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand All @@ -39,6 +41,7 @@ def get_city_buildings(city_id):


@cities_blueprint.route("/<int:city_id>/units", methods=["GET"])
@token_required
def get_city_units(city_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand Down
2 changes: 2 additions & 0 deletions routes/player.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify, request

from db import get_db_connection
from jwt_helper import token_required

players_blueprint = Blueprint("players", __name__)

Expand Down Expand Up @@ -72,6 +73,7 @@ def get_player_cities(player_id):


@players_blueprint.route("/<int:player_id>/battles", methods=["GET"])
@token_required
def get_player_battles(player_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand Down
19 changes: 2 additions & 17 deletions tests/test_battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,6 @@ def test_get_battle_by_id(mock_get_db, client):


# Test get_battle_units endpoint
@patch("routes.battle.get_db_connection")
def test_get_battle_units(mock_get_db, client):
mock_db_response(
mock_get_db,
[
{"unit_id": 101, "name": "Swordsman", "count": 50, "side": 0},
{"unit_id": 102, "name": "Archer", "count": 30, "side": 1},
],
)

def test_get_battle_units(client):
response = client(battles_blueprint).get("/battles/1/units")
json_data = response.get_json(force=True)

assert response.status_code == 200
assert json_data == [
{"unit_id": 101, "name": "Swordsman", "count": 50, "side": 0},
{"unit_id": 102, "name": "Archer", "count": 30, "side": 1},
]
assert response.status_code == 401
10 changes: 2 additions & 8 deletions tests/test_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,9 @@ def test_get_all_buildings(mock_get_db, client):


# Test get_building_by_id endpoint
@patch("routes.building.get_db_connection")
def test_get_building_by_id(mock_get_db, client):
mock_db_response(mock_get_db, {"id": 1, "name": "Barracks"}, fetchone=True)

def test_get_building_by_id(client):
response = client(buildings_blueprint).get("/buildings/1")
json_data = response.get_json(force=True)

assert response.status_code == 200
assert json_data == {"id": 1, "name": "Barracks"}
assert response.status_code == 401


# Test get_building_prerequisites endpoint
Expand Down
Loading

0 comments on commit 59b8f7a

Please sign in to comment.