Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement user authentication with registration and login endpoints #9

Merged
merged 7 commits into from
Nov 19, 2024
Merged
1 change: 0 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ __pycache__
# Ignore local development files
*.pyc
.DS_Store
.env
4 changes: 3 additions & 1 deletion .github/workflows/super-linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ name: Lint

on:
push: null
pull_request: null
pull_request:
branches:
- main

permissions: {}

Expand Down
8 changes: 7 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@ def add_status(response):
if response.is_json:
original_data = response.get_json()
new_response = {
"success": response.status_code == 200,
"success": response.status_code in range(200, 300),
"data": original_data if original_data != [] else None,
}
response.set_data(jsonify(new_response).data)
return response


@app.after_request
def add_common_headers(response):
response.headers["X-Content-Type-Options"] = "nosniff"
return response


if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000)
57 changes: 57 additions & 0 deletions jwt_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import jsonify, request

SECRET_KEY = os.getenv("SECRET_JWT_KEY", "SuperSecretKey")
ACCESS_TOKEN_EXPIRY = timedelta(hours=1)
REFRESH_TOKEN_EXPIRY = timedelta(days=30)


def generate_access_token(player_id: int) -> str:
"""Generate a JWT token for a user."""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + ACCESS_TOKEN_EXPIRY, # Expiration
"iat": datetime.now(timezone.utc), # Issued at
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")


def generate_refresh_token(player_id: int) -> str:
"""Generate a long-lived refresh token."""
payload = {
"player_id": player_id,
"exp": datetime.now(timezone.utc) + REFRESH_TOKEN_EXPIRY,
"iat": datetime.now(timezone.utc),
}
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")


def verify_token(token: str) -> dict | None:
"""Verify a JWT token and return the payload."""
try:
return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
return None # Token expired
except jwt.InvalidTokenError:
return None # Invalid token


def token_required(f):
@wraps(f)
def decorated(*args, **kwargs):
token = request.headers.get("Authorization")
if not token:
return jsonify(message="Token is missing"), 401

decoded = verify_token(token)
if not decoded:
return jsonify(message="Token is invalid or expired"), 401

request.player_id = decoded["player_id"] # Attach user ID to the request
return f(*args, **kwargs)

return decorated
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
argon2-cffi>=23.1.0
flask>=3.0.3
flask-cors>=5.0.0
pyjwt>=2.10.0
pymysql>=1.1.1
pytest>=8.3.3
python-dotenv>=1.0.1
2 changes: 2 additions & 0 deletions routes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .authentication import authentication_blueprint
from .battle import battles_blueprint
from .building import buildings_blueprint
from .city import cities_blueprint
Expand All @@ -8,6 +9,7 @@


def register_routes(app):
app.register_blueprint(authentication_blueprint, url_prefix="/auth")
app.register_blueprint(battles_blueprint, url_prefix="/battles")
app.register_blueprint(buildings_blueprint, url_prefix="/buildings")
app.register_blueprint(cities_blueprint, url_prefix="/cities")
Expand Down
137 changes: 137 additions & 0 deletions routes/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
from re import match

import jwt
from argon2 import PasswordHasher, exceptions
from dotenv import load_dotenv
from flask import Blueprint, jsonify, request
from pymysql import MySQLError

from db import get_db_connection
from jwt_helper import generate_access_token, generate_refresh_token

load_dotenv()

authentication_blueprint = Blueprint("authentication", __name__)
ph = PasswordHasher()


def hash_password_with_salt_and_pepper(password: str) -> tuple[str, bytes]:
salt = os.urandom(16)
pepper = os.getenv("PEPPER", "SuperSecretPepper").encode("utf-8")
seasoned_password = password.encode("utf-8") + salt + pepper
return ph.hash(seasoned_password), salt


def validate_password(password):
"""
Validates a password based on the following criteria:
- At least 12 characters long.
- Contains at least one uppercase letter (A-Z).
- Contains at least one lowercase letter (a-z).
- Contains at least one digit (0-9).
- Contains at least one special character (any non-alphanumeric character).
"""
return bool(
match(r"^(?=.*[A-Z])(?=.*[a-z])(?=.*\d)(?=.*[^A-Za-z0-9]).{12,}$", password)
)


@authentication_blueprint.route("/register", methods=["POST"])
def register():
data = request.get_json()
name = data.get("name")
email = data.get("email")
password = data.get("password")

if not name or not email or not password:
return jsonify(message="Username, email, and password are required"), 400

if not validate_password(password):
return jsonify(message="Password does not meet security requirements"), 400

hashed_password, salt = hash_password_with_salt_and_pepper(password)

db = get_db_connection()
with db.cursor() as cursor:
try:
cursor.callproc("register_player", (name, email, hashed_password, salt))
db.commit()
except MySQLError as e:
# Check for specific error messages in the SQL error
if "Player name already exists" in str(e):
return jsonify(message="Player name already exists"), 400
elif "Email already exists" in str(e):
return jsonify(message="Email already exists"), 400
else:
return jsonify(message="An error occurred during registration"), 500

db.close()
return jsonify(message="User created successfully"), 201


@authentication_blueprint.route("/login", methods=["POST"])
def login():
data = request.get_json()
email = data.get("email")
password = data.get("password")

if not email or not password:
return jsonify(message="Email and password are required"), 400

db = get_db_connection()
with db.cursor() as cursor:
cursor.execute(
"SELECT player_id, hashed_password, salt FROM player WHERE email = %s",
(email,),
)
player = cursor.fetchone()

if not player:
return jsonify(message="Invalid credentials"), 401

player_id = player["player_id"]
stored_password = player["hashed_password"]
salt = player["salt"]
pepper = os.getenv("PEPPER").encode("utf-8")
seasoned_password = password.encode("utf-8") + salt + pepper

try:
ph.verify(stored_password, seasoned_password)
access_token = generate_access_token(player_id)
refresh_token = generate_refresh_token(player_id)
return jsonify(
message="Login successful",
access_token=access_token,
refresh_token=refresh_token,
)
except exceptions.VerifyMismatchError:
return jsonify(message="Invalid credentials"), 401


@authentication_blueprint.route("/refresh", methods=["POST"])
def refresh_token():
auth_header = request.headers.get("Authorization")

if not auth_header or not auth_header.startswith("Bearer "):
return (
jsonify(message="Refresh token is required in the Authorization header"),
400,
)

refresh_token = auth_header.split("Bearer ")[1]

try:
decoded = jwt.decode(
refresh_token,
os.getenv("SECRET_JWT_KEY", "SuperSecretKey"),
algorithms=["HS256"],
)
player_id = decoded["player_id"]

new_access_token = generate_access_token(player_id)
return jsonify(access_token=new_access_token), 200
except jwt.ExpiredSignatureError:
return jsonify(message="Refresh token has expired, please log in again"), 401
except jwt.InvalidTokenError:
return jsonify(message="Invalid refresh token"), 401
2 changes: 2 additions & 0 deletions routes/battle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

battles_blueprint = Blueprint("battles", __name__)

Expand Down Expand Up @@ -28,6 +29,7 @@ def get_battle_by_id(battle_id):


@battles_blueprint.route("/<int:battle_id>/units", methods=["GET"])
@token_required
def get_battle_units(battle_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand Down
2 changes: 2 additions & 0 deletions routes/building.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

buildings_blueprint = Blueprint("buildings", __name__)

Expand All @@ -17,6 +18,7 @@ def get_all_buildings():


@buildings_blueprint.route("/<int:building_id>", methods=["GET"])
@token_required
def get_building_by_id(building_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand Down
3 changes: 3 additions & 0 deletions routes/city.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify

from db import get_db_connection
from jwt_helper import token_required

cities_blueprint = Blueprint("cities", __name__)

Expand Down Expand Up @@ -28,6 +29,7 @@ def get_city_by_id(city_id):


@cities_blueprint.route("/<int:city_id>/buildings", methods=["GET"])
@token_required
def get_city_buildings(city_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand All @@ -39,6 +41,7 @@ def get_city_buildings(city_id):


@cities_blueprint.route("/<int:city_id>/units", methods=["GET"])
@token_required
def get_city_units(city_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand Down
2 changes: 2 additions & 0 deletions routes/player.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify, request

from db import get_db_connection
from jwt_helper import token_required

players_blueprint = Blueprint("players", __name__)

Expand Down Expand Up @@ -72,6 +73,7 @@ def get_player_cities(player_id):


@players_blueprint.route("/<int:player_id>/battles", methods=["GET"])
@token_required
def get_player_battles(player_id):
db = get_db_connection()
with db.cursor() as cursor:
Expand Down
19 changes: 2 additions & 17 deletions tests/test_battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,6 @@ def test_get_battle_by_id(mock_get_db, client):


# Test get_battle_units endpoint
@patch("routes.battle.get_db_connection")
def test_get_battle_units(mock_get_db, client):
mock_db_response(
mock_get_db,
[
{"unit_id": 101, "name": "Swordsman", "count": 50, "side": 0},
{"unit_id": 102, "name": "Archer", "count": 30, "side": 1},
],
)

def test_get_battle_units(client):
response = client(battles_blueprint).get("/battles/1/units")
json_data = response.get_json(force=True)

assert response.status_code == 200
assert json_data == [
{"unit_id": 101, "name": "Swordsman", "count": 50, "side": 0},
{"unit_id": 102, "name": "Archer", "count": 30, "side": 1},
]
assert response.status_code == 401
10 changes: 2 additions & 8 deletions tests/test_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,9 @@ def test_get_all_buildings(mock_get_db, client):


# Test get_building_by_id endpoint
@patch("routes.building.get_db_connection")
def test_get_building_by_id(mock_get_db, client):
mock_db_response(mock_get_db, {"id": 1, "name": "Barracks"}, fetchone=True)

def test_get_building_by_id(client):
response = client(buildings_blueprint).get("/buildings/1")
json_data = response.get_json(force=True)

assert response.status_code == 200
assert json_data == {"id": 1, "name": "Barracks"}
assert response.status_code == 401


# Test get_building_prerequisites endpoint
Expand Down
Loading