Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace pyjwkest with pyjwt package #349

Merged
merged 4 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Please See the `releases tab <https://github.com/openedx/xblock-lti-consumer/rel
Unreleased
~~~~~~~~~~

9.13.0 - 2025-01-08
-------------------
* Removed pyjwkset package and replace with pyjwt package

9.12.0 - 2024-11-14
-------------------
* Dropped support for Python 3.8 and added support for Python 3.12.
Expand Down
2 changes: 1 addition & 1 deletion lti_consumer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .apps import LTIConsumerApp
from .lti_xblock import LtiConsumerXBlock

__version__ = '9.12.1'
__version__ = '9.13.0'
195 changes: 64 additions & 131 deletions lti_consumer/lti_1p3/key_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
This handles validating messages sent by the tool and generating
access token with LTI scopes.
"""
import codecs
import copy
import time
import json
import math
import sys
import time
import logging

from Cryptodome.PublicKey import RSA
import jwt
from edx_django_utils.monitoring import function_trace
from jwkest import BadSignature, BadSyntax, WrongNumberOfParts, jwk
from jwkest.jwk import RSAKey, load_jwks_from_url
from jwkest.jws import JWS, NoSuitableSigningKeys, UnknownAlgorithm
from jwkest.jwt import JWT
from jwt.api_jwk import PyJWK

from . import exceptions

Expand Down Expand Up @@ -52,14 +50,11 @@ def __init__(self, public_key=None, keyset_url=None):
# Import from public key
if public_key:
try:
new_key = RSAKey(use='sig')

# Unescape key before importing it
raw_key = codecs.decode(public_key, 'unicode_escape')

# Import Key and save to internal state
new_key.load_key(RSA.import_key(raw_key))
self.public_key = new_key
algo_obj = jwt.get_algorithm_by_name('RS256')
public_key = algo_obj.prepare_key(public_key)
public_jwk = json.loads(algo_obj.to_jwk(public_key))
self.public_key = PyJWK.from_dict(public_jwk)
except ValueError as err:
log.warning(
'An error was encountered while loading the LTI tool\'s key from the public key. '
Expand All @@ -78,7 +73,7 @@ def _get_keyset(self, kid=None):

if self.keyset_url:
try:
keys = load_jwks_from_url(self.keyset_url)
keys = jwt.PyJWKClient(self.keyset_url).get_jwk_set()
except Exception as err:
# Broad Exception is required here because jwkest raises
# an Exception object explicitly.
Expand All @@ -89,7 +84,7 @@ def _get_keyset(self, kid=None):
'The RSA keys could not be loaded.'
)
raise exceptions.NoSuitableKeys() from err
keyset.extend(keys)
keyset.extend(keys.keys)

if self.public_key and kid:
# Fill in key id of stored key.
Expand All @@ -98,6 +93,7 @@ def _get_keyset(self, kid=None):
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid

if self.public_key:
# Add to keyset
keyset.append(self.public_key)

Expand All @@ -112,49 +108,29 @@ def validate_and_decode(self, token):
The authorization server decodes the JWT and MUST validate the values for the
iss, sub, exp, aud and jti claims.
"""
try:
# Get KID from JWT header
jwt = JWT().unpack(token)
key_set = self._get_keyset()

# Verify message signature
message = JWS().verify_compact(
token,
keys=self._get_keyset(
jwt.headers.get('kid')
)
)

# If message is valid, check expiration from JWT
if 'exp' in message and message['exp'] < time.time():
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT has expired.'
for i, obj in enumerate(key_set):
try:
if hasattr(obj.key, 'public_key'):
key = obj.key.public_key()
else:
key = obj.key
message = jwt.decode(
token,
key,
algorithms=['RS256', 'RS512',],
options={
'verify_signature': True,
'verify_aud': False
}
)
raise exceptions.TokenSignatureExpired()
return message
except Exception: # pylint: disable=broad-except
if i == len(key_set) - 1:
raise

# TODO: Validate other JWT claims

# Else returns decoded message
return message

except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except (BadSyntax, WrongNumberOfParts) as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT is malformed.'
)
raise exceptions.MalformedJwtToken() from err
except BadSignature as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT signature is incorrect.'
)
raise exceptions.BadJwtSignature() from err
raise exceptions.NoSuitableKeys()


class PlatformKeyHandler:
Expand All @@ -174,14 +150,11 @@ def __init__(self, key_pem, kid=None):
if key_pem:
# Import JWK from RSA key
try:
self.key = RSAKey(
# Using the same key ID as client id
# This way we can easily serve multiple public
# keys on the same endpoint and keep all
# LTI 1.3 blocks working
kid=kid,
key=RSA.import_key(key_pem)
)
algo = jwt.get_algorithm_by_name('RS256')
private_key = algo.prepare_key(key_pem)
private_jwk = json.loads(algo.to_jwk(private_key))
private_jwk['kid'] = kid
self.key = PyJWK.from_dict(private_jwk)
except ValueError as err:
log.warning(
'An error was encountered while loading the LTI platform\'s key. '
Expand All @@ -206,92 +179,52 @@ def encode_and_sign(self, message, expiration=None):
# Set iat and exp if expiration is set
if expiration:
_message.update({
"iat": int(round(time.time())),
"exp": int(round(time.time()) + expiration),
"iat": int(math.floor(time.time())),
"exp": int(math.floor(time.time()) + expiration),
})

# The class instance that sets up the signing operation
# An RS 256 key is required for LTI 1.3
_jws = JWS(_message, alg="RS256", cty="JWT")

try:
# Encode and sign LTI message
return _jws.sign_compact([self.key])
except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while signing the OAuth 2.0 access token JWT. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except UnknownAlgorithm as err:
log.warning(
'An error was encountered while signing the OAuth 2.0 access token JWT. '
'There algorithm is unknown.'
)
raise exceptions.MalformedJwtToken() from err
return jwt.encode(_message, self.key.key, algorithm="RS256")

def get_public_jwk(self):
"""
Export Public JWK
"""
public_keys = jwk.KEYS()
jwk = {"keys": []}

# Only append to keyset if a key exists
if self.key:
public_keys.append(self.key)

return json.loads(public_keys.dump_jwks())
algo_obj = jwt.get_algorithm_by_name('RS256')
public_key = algo_obj.prepare_key(self.key.key).public_key()
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
return jwk

def validate_and_decode(self, token, iss=None, aud=None):
def validate_and_decode(self, token, iss=None, aud=None, exp=True):
"""
Check if a platform token is valid, and return allowed scopes.

Validates a token sent by the tool using the platform's RSA Key.
Optionally validate iss and aud claims if provided.
"""
if not self.key:
raise exceptions.RsaKeyNotSet()
try:
# Verify message signature
message = JWS().verify_compact(token, keys=[self.key])

# If message is valid, check expiration from JWT
if 'exp' in message and message['exp'] < time.time():
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'The JWT has expired.'
)
raise exceptions.TokenSignatureExpired()

# Validate issuer claim (if present)
log_message_base = 'An error was encountered while verifying the OAuth 2.0 access token. '
if iss:
if 'iss' not in message or message['iss'] != iss:
error_message = 'The required iss claim is missing or does not match the expected iss value. '
log_message = log_message_base + error_message

log.warning(log_message)
raise exceptions.InvalidClaimValue(error_message)

# Validate audience claim (if present)
if aud:
if 'aud' not in message or aud not in message['aud']:
error_message = 'The required aud claim is missing.'
log_message = log_message_base + error_message

log.warning(log_message)
raise exceptions.InvalidClaimValue(error_message)

# Else return token contents
message = jwt.decode(
token,
key=self.key.key.public_key(),
audience=aud,
issuer=iss,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_exp': bool(exp),
'verify_iss': bool(iss),
'verify_aud': bool(aud)
}
)
return message

except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except BadSyntax as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'The JWT is malformed.'
)
raise exceptions.MalformedJwtToken() from err
except Exception as token_error:
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error
Loading
Loading