Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JWT Token authenticator #542

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion src/argus/auth/authentication.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from datetime import timedelta
from urllib.request import urlopen
from urllib.parse import urljoin
import json
import jwt

from django.conf import settings
from django.utils import timezone
from rest_framework.authentication import TokenAuthentication
from rest_framework.authentication import TokenAuthentication, BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed

from .models import User


class ExpiringTokenAuthentication(TokenAuthentication):
EXPIRATION_DURATION = timedelta(days=settings.AUTH_TOKEN_EXPIRES_AFTER_DAYS)
Expand All @@ -17,3 +23,80 @@ def authenticate_credentials(self, key):
raise AuthenticationFailed("Token has expired.")

return user, token


class JWTAuthentication(BaseAuthentication):
REQUIRED_CLAIMS = ["exp", "nbf", "aud", "iss", "sub"]
SUPPORTED_ALGORITHMS = ["RS256", "RS384", "RS512"]
AUTH_SCHEME = "Bearer"

def authenticate(self, request):
try:
raw_token = self.get_raw_token(request)
except ValueError:
return None
validated_token = self.decode_token(raw_token)
return self.get_user(validated_token), validated_token

def get_public_key(self, kid):
r = urlopen(self.get_jwk_endpoint())
jwks = json.loads(r.read())
for jwk in jwks.get("keys"):
if jwk["kid"] == kid:
return jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk))
raise AuthenticationFailed(f"Invalid kid '{kid}'")

def get_raw_token(self, request):
"""Raises ValueError if a jwt token could not be found"""
auth_header = request.META.get("HTTP_AUTHORIZATION")
if not auth_header:
raise ValueError("No Authorization header found")
try:
scheme, token = auth_header.split()
except ValueError as e:
raise ValueError(f"Failed to parse Authorization header: {e}")
if scheme != self.AUTH_SCHEME:
raise ValueError(f"Invalid Authorization scheme '{scheme}'")
return token

def decode_token(self, raw_token):
kid = self.get_kid(raw_token)
try:
validated_token = jwt.decode(
jwt=raw_token,
algorithms=self.SUPPORTED_ALGORITHMS,
key=self.get_public_key(kid),
options={"require": self.REQUIRED_CLAIMS},
audience=settings.JWT_AUDIENCE,
issuer=self.get_openid_issuer(),
)
return validated_token
except jwt.exceptions.PyJWTError as e:
raise AuthenticationFailed(f"Error validating token: {e}")

def get_user(self, token):
username = token["sub"]
try:
return User.objects.get(username=username)
except User.DoesNotExist:
raise AuthenticationFailed(f"No user found for username '{username}'")

def get_openid_config(self):
url = urljoin(settings.OIDC_ENDPOINT, ".well-known/openid-configuration")
r = urlopen(url)
return json.loads(r.read())

def get_jwk_endpoint(self):
openid_config = self.get_openid_config()
return openid_config["jwks_uri"]

def get_openid_issuer(self):
openid_config = self.get_openid_config()
return openid_config["issuer"]

def get_kid(self, token):
header = jwt.get_unverified_header(token)
kid = header.get("kid")
if not kid:
raise AuthenticationFailed("Token must include the 'kid' header")
return kid
4 changes: 4 additions & 0 deletions src/argus/site/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
"argus.auth.authentication.ExpiringTokenAuthentication",
# For BrowsableAPIRenderer
"rest_framework.authentication.SessionAuthentication",
"argus.auth.authentication.JWTAuthentication",
),
"DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
"DEFAULT_RENDERER_CLASSES": (
Expand Down Expand Up @@ -301,3 +302,6 @@
#
# SOCIAL_AUTH_DATAPORTEN_FEIDE_KEY = SOCIAL_AUTH_DATAPORTEN_KEY
# SOCIAL_AUTH_DATAPORTEN_FEIDE_SECRET = SOCIAL_AUTH_DATAPORTEN_SECRET

OIDC_ENDPOINT = get_str_env("OIDC_ENDPOINT")
JWT_AUDIENCE = get_str_env("JWT_AUDIENCE")