From 52c520d7caff9ecf4e6a1741809eb2d59de6ab7b Mon Sep 17 00:00:00 2001 From: M Umar Khan Date: Wed, 5 Apr 2023 23:51:03 +0500 Subject: [PATCH 1/4] chore: add pyjwt requirement --- lti_consumer/lti_1p3/key_handlers.py | 195 ++++++------------ lti_consumer/lti_1p3/tests/test_consumer.py | 46 +++-- .../lti_1p3/tests/test_key_handlers.py | 70 +++---- lti_consumer/lti_1p3/tests/utils.py | 8 +- lti_consumer/plugin/views.py | 42 ++-- .../tests/unit/plugin/test_proctoring.py | 4 +- lti_consumer/tests/unit/plugin/test_views.py | 9 +- lti_consumer/tests/unit/test_lti_xblock.py | 11 +- requirements/base.in | 1 + requirements/base.txt | 2 + requirements/ci.txt | 2 + requirements/dev.txt | 2 + requirements/quality.txt | 2 + requirements/test.txt | 2 + 14 files changed, 173 insertions(+), 223 deletions(-) diff --git a/lti_consumer/lti_1p3/key_handlers.py b/lti_consumer/lti_1p3/key_handlers.py index 0d530df7..7993d2e6 100644 --- a/lti_consumer/lti_1p3/key_handlers.py +++ b/lti_consumer/lti_1p3/key_handlers.py @@ -4,18 +4,16 @@ This handles validating messages sent by the tool and generating access token with LTI scopes. """ -import codecs import copy -import time import json +import math +import time +import sys import logging +import jwt from Cryptodome.PublicKey import RSA from edx_django_utils.monitoring import function_trace -from jwkest import BadSignature, BadSyntax, WrongNumberOfParts, jwk -from jwkest.jwk import RSAKey, load_jwks_from_url -from jwkest.jws import JWS, NoSuitableSigningKeys, UnknownAlgorithm -from jwkest.jwt import JWT from . import exceptions @@ -52,14 +50,9 @@ def __init__(self, public_key=None, keyset_url=None): # Import from public key if public_key: try: - new_key = RSAKey(use='sig') - - # Unescape key before importing it - raw_key = codecs.decode(public_key, 'unicode_escape') - # Import Key and save to internal state - new_key.load_key(RSA.import_key(raw_key)) - self.public_key = new_key + algo_obj = jwt.get_algorithm_by_name('RS256') + self.public_key = algo_obj.prepare_key(public_key) except ValueError as err: log.warning( 'An error was encountered while loading the LTI tool\'s key from the public key. ' @@ -78,7 +71,7 @@ def _get_keyset(self, kid=None): if self.keyset_url: try: - keys = load_jwks_from_url(self.keyset_url) + keys = jwt.PyJWKClient(self.keyset_url).get_jwk_set() except Exception as err: # Broad Exception is required here because jwkest raises # an Exception object explicitly. @@ -91,13 +84,13 @@ def _get_keyset(self, kid=None): raise exceptions.NoSuitableKeys() from err keyset.extend(keys) - if self.public_key and kid: - # Fill in key id of stored key. - # This is needed because if the JWS is signed with a - # key with a kid, pyjwkest doesn't match them with - # keys without kid (kid=None) and fails verification - self.public_key.kid = kid - + if self.public_key: + if kid: + # Fill in key id of stored key. + # This is needed because if the JWS is signed with a + # key with a kid, pyjwkest doesn't match them with + # keys without kid (kid=None) and fails verification + self.public_key.kid = kid # Add to keyset keyset.append(self.public_key) @@ -113,48 +106,24 @@ def validate_and_decode(self, token): iss, sub, exp, aud and jti claims. """ try: - # Get KID from JWT header - jwt = JWT().unpack(token) - - # Verify message signature - message = JWS().verify_compact( - token, - keys=self._get_keyset( - jwt.headers.get('kid') - ) - ) - - # If message is valid, check expiration from JWT - if 'exp' in message and message['exp'] < time.time(): - log.warning( - 'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. ' - 'The JWT has expired.' - ) - raise exceptions.TokenSignatureExpired() - - # TODO: Validate other JWT claims - - # Else returns decoded message - return message - - except NoSuitableSigningKeys as err: - log.warning( - 'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. ' - 'There is no suitable signing key.' - ) - raise exceptions.NoSuitableKeys() from err - except (BadSyntax, WrongNumberOfParts) as err: - log.warning( - 'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. ' - 'The JWT is malformed.' - ) - raise exceptions.MalformedJwtToken() from err - except BadSignature as err: - log.warning( - 'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. ' - 'The JWT signature is incorrect.' - ) - raise exceptions.BadJwtSignature() from err + key_set = self._get_keyset() + if not key_set: + raise exceptions.NoSuitableKeys() + for i in range(len(key_set)): + try: + message = jwt.decode( + token, + key=key_set[i], + algorithms=['RS256', 'RS512',], + options={'verify_signature': True} + ) + return message + except Exception: + if i == len(key_set) - 1: + raise + except Exception as token_error: + exc_info = sys.exc_info() + raise jwt.InvalidTokenError(exc_info[2]) from token_error class PlatformKeyHandler: @@ -174,14 +143,8 @@ def __init__(self, key_pem, kid=None): if key_pem: # Import JWK from RSA key try: - self.key = RSAKey( - # Using the same key ID as client id - # This way we can easily serve multiple public - # keys on the same endpoint and keep all - # LTI 1.3 blocks working - kid=kid, - key=RSA.import_key(key_pem) - ) + algo = jwt.get_algorithm_by_name('RS256') + self.key = algo.prepare_key(key_pem) except ValueError as err: log.warning( 'An error was encountered while loading the LTI platform\'s key. ' @@ -206,41 +169,26 @@ def encode_and_sign(self, message, expiration=None): # Set iat and exp if expiration is set if expiration: _message.update({ - "iat": int(round(time.time())), - "exp": int(round(time.time()) + expiration), + "iat": int(math.floor(time.time())), + "exp": int(math.floor(time.time()) + expiration), }) # The class instance that sets up the signing operation # An RS 256 key is required for LTI 1.3 - _jws = JWS(_message, alg="RS256", cty="JWT") - - try: - # Encode and sign LTI message - return _jws.sign_compact([self.key]) - except NoSuitableSigningKeys as err: - log.warning( - 'An error was encountered while signing the OAuth 2.0 access token JWT. ' - 'There is no suitable signing key.' - ) - raise exceptions.NoSuitableKeys() from err - except UnknownAlgorithm as err: - log.warning( - 'An error was encountered while signing the OAuth 2.0 access token JWT. ' - 'There algorithm is unknown.' - ) - raise exceptions.MalformedJwtToken() from err + return jwt.encode(_message, self.key, algorithm="RS256") def get_public_jwk(self): """ Export Public JWK """ - public_keys = jwk.KEYS() + jwk = {"keys": []} # Only append to keyset if a key exists if self.key: - public_keys.append(self.key) - - return json.loads(public_keys.dump_jwks()) + algo_obj = jwt.get_algorithm_by_name('RS256') + public_key = algo_obj.prepare_key(self.key).public_key() + jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key))) + return jwk def validate_and_decode(self, token, iss=None, aud=None): """ @@ -249,49 +197,22 @@ def validate_and_decode(self, token, iss=None, aud=None): Validates a token sent by the tool using the platform's RSA Key. Optionally validate iss and aud claims if provided. """ + if not self.key: + raise exceptions.RsaKeyNotSet() try: - # Verify message signature - message = JWS().verify_compact(token, keys=[self.key]) - - # If message is valid, check expiration from JWT - if 'exp' in message and message['exp'] < time.time(): - log.warning( - 'An error was encountered while verifying the OAuth 2.0 access token. ' - 'The JWT has expired.' - ) - raise exceptions.TokenSignatureExpired() - - # Validate issuer claim (if present) - log_message_base = 'An error was encountered while verifying the OAuth 2.0 access token. ' - if iss: - if 'iss' not in message or message['iss'] != iss: - error_message = 'The required iss claim is missing or does not match the expected iss value. ' - log_message = log_message_base + error_message - - log.warning(log_message) - raise exceptions.InvalidClaimValue(error_message) - - # Validate audience claim (if present) - if aud: - if 'aud' not in message or aud not in message['aud']: - error_message = 'The required aud claim is missing.' - log_message = log_message_base + error_message - - log.warning(log_message) - raise exceptions.InvalidClaimValue(error_message) - - # Else return token contents + message = jwt.decode( + token, + key=self.key.public_key(), + audience=aud, + issuer=iss, + algorithms=['RS256', 'RS512'], + options={ + 'verify_signature': True, + 'verify_aud': True if aud else False + } + ) return message - except NoSuitableSigningKeys as err: - log.warning( - 'An error was encountered while verifying the OAuth 2.0 access token. ' - 'There is no suitable signing key.' - ) - raise exceptions.NoSuitableKeys() from err - except BadSyntax as err: - log.warning( - 'An error was encountered while verifying the OAuth 2.0 access token. ' - 'The JWT is malformed.' - ) - raise exceptions.MalformedJwtToken() from err + except Exception as token_error: + exc_info = sys.exc_info() + raise jwt.InvalidTokenError(exc_info[2]) from token_error diff --git a/lti_consumer/lti_1p3/tests/test_consumer.py b/lti_consumer/lti_1p3/tests/test_consumer.py index c261502a..fa7fb1af 100644 --- a/lti_consumer/lti_1p3/tests/test_consumer.py +++ b/lti_consumer/lti_1p3/tests/test_consumer.py @@ -2,18 +2,18 @@ Unit tests for LTI 1.3 consumer implementation """ -import json from unittest.mock import patch from urllib.parse import parse_qs, urlparse import uuid import ddt +import jwt +import sys from Cryptodome.PublicKey import RSA from django.conf import settings from django.test.testcases import TestCase from edx_django_utils.cache import get_cache_key, TieredCache -from jwkest.jwk import load_jwks -from jwkest.jws import JWS +from jwt.api_jwk import PyJWKSet from lti_consumer.data import Lti1p3LaunchData from lti_consumer.lti_1p3 import exceptions @@ -36,7 +36,9 @@ STATE = "ABCD" # Consider storing a fixed key RSA_KEY_ID = "1" -RSA_KEY = RSA.generate(2048).export_key('PEM') +RSA_KEY = RSA.generate(2048) +RSA_PRIVATE_KEY = RSA_KEY.export_key('PEM') +RSA_PUBLIC_KEY = RSA_KEY.public_key().export_key('PEM') def _generate_token_request_data(token, scope): @@ -69,11 +71,11 @@ def setUp(self): lti_launch_url=LAUNCH_URL, client_id=CLIENT_ID, deployment_id=DEPLOYMENT_ID, - rsa_key=RSA_KEY, + rsa_key=RSA_PRIVATE_KEY, rsa_key_id=RSA_KEY_ID, redirect_uris=REDIRECT_URIS, # Use the same key for testing purposes - tool_key=RSA_KEY + tool_key=RSA_PUBLIC_KEY ) def _setup_lti_launch_data(self): @@ -118,9 +120,25 @@ def _decode_token(self, token): This also tests the public keyset function. """ public_keyset = self.lti_consumer.get_public_keyset() - key_set = load_jwks(json.dumps(public_keyset)) - - return JWS().verify_compact(token, keys=key_set) + keyset = PyJWKSet.from_dict(public_keyset).keys + + for i in range(len(keyset)): + try: + message = jwt.decode( + token, + key=keyset[i].key, + algorithms=['RS256', 'RS512'], + options={ + 'verify_signature': True, + 'verify_aud': False + } + ) + return message + except Exception as token_error: + if i < len(keyset) - 1: + continue + exc_info = sys.exc_info() + raise jwt.InvalidTokenError(exc_info[2]) from token_error @ddt.data( ({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True), @@ -558,7 +576,7 @@ def test_access_token_invalid_jwt(self): """ request_data = _generate_token_request_data("invalid_jwt", "") - with self.assertRaises(exceptions.MalformedJwtToken): + with self.assertRaises(jwt.exceptions.InvalidTokenError): self.lti_consumer.access_token(request_data) def test_access_token_no_acs(self): @@ -686,11 +704,11 @@ def setUp(self): lti_launch_url=LAUNCH_URL, client_id=CLIENT_ID, deployment_id=DEPLOYMENT_ID, - rsa_key=RSA_KEY, + rsa_key=RSA_PRIVATE_KEY, rsa_key_id=RSA_KEY_ID, redirect_uris=REDIRECT_URIS, # Use the same key for testing purposes - tool_key=RSA_KEY + tool_key=RSA_PUBLIC_KEY ) self.preflight_response = {} @@ -930,11 +948,11 @@ def setUp(self): lti_launch_url=LAUNCH_URL, client_id=CLIENT_ID, deployment_id=DEPLOYMENT_ID, - rsa_key=RSA_KEY, + rsa_key=RSA_PRIVATE_KEY, rsa_key_id=RSA_KEY_ID, redirect_uris=REDIRECT_URIS, # Use the same key for testing purposes - tool_key=RSA_KEY + tool_key=RSA_PUBLIC_KEY ) self.preflight_response = {} diff --git a/lti_consumer/lti_1p3/tests/test_key_handlers.py b/lti_consumer/lti_1p3/tests/test_key_handlers.py index 43b5ffc5..bc5710aa 100644 --- a/lti_consumer/lti_1p3/tests/test_key_handlers.py +++ b/lti_consumer/lti_1p3/tests/test_key_handlers.py @@ -3,9 +3,12 @@ """ import json +import math +import time from unittest.mock import patch import ddt +import jwt from Cryptodome.PublicKey import RSA from django.test.testcases import TestCase from jwkest import BadSignature @@ -131,18 +134,17 @@ def test_empty_rsa_key(self): {'keys': []} ) - # pylint: disable=unused-argument - @patch('time.time', return_value=1000) - def test_validate_and_decode(self, mock_time): + def test_validate_and_decode(self): """ Test validate and decode with all parameters. """ + expiration = 1000 signed_token = self.key_handler.encode_and_sign( { "iss": "test-issuer", "aud": "test-aud", }, - expiration=1000 + expiration=expiration ) self.assertEqual( @@ -150,14 +152,12 @@ def test_validate_and_decode(self, mock_time): { "iss": "test-issuer", "aud": "test-aud", - "iat": 1000, - "exp": 2000 + "iat": int(math.floor(time.time())), + "exp": int(math.floor(time.time()) + expiration), } ) - # pylint: disable=unused-argument - @patch('time.time', return_value=1000) - def test_validate_and_decode_expired(self, mock_time): + def test_validate_and_decode_expired(self): """ Test validate and decode with all parameters. """ @@ -166,7 +166,7 @@ def test_validate_and_decode_expired(self, mock_time): expiration=-10 ) - with self.assertRaises(exceptions.TokenSignatureExpired): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode(signed_token) def test_validate_and_decode_invalid_iss(self): @@ -175,7 +175,7 @@ def test_validate_and_decode_invalid_iss(self): """ signed_token = self.key_handler.encode_and_sign({"iss": "wrong"}) - with self.assertRaises(exceptions.InvalidClaimValue): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode(signed_token, iss="right") def test_validate_and_decode_invalid_aud(self): @@ -184,14 +184,14 @@ def test_validate_and_decode_invalid_aud(self): """ signed_token = self.key_handler.encode_and_sign({"aud": "wrong"}) - with self.assertRaises(exceptions.InvalidClaimValue): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode(signed_token, aud="right") def test_validate_and_decode_no_jwt(self): """ Test validate and decode with invalid JWT. """ - with self.assertRaises(exceptions.MalformedJwtToken): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode("1.2.3") def test_validate_and_decode_no_keys(self): @@ -199,10 +199,10 @@ def test_validate_and_decode_no_keys(self): Test validate and decode when no keys are available. """ signed_token = self.key_handler.encode_and_sign({}) - # Changing the KID so it doesn't match - self.key_handler.key.kid = "invalid_kid" - with self.assertRaises(exceptions.NoSuitableKeys): + self.key_handler.key = None + + with self.assertRaises(exceptions.RsaKeyNotSet): self.key_handler.validate_and_decode(signed_token) @@ -217,12 +217,10 @@ def setUp(self): self.rsa_key_id = "1" # Generate RSA and save exports - rsa_key = RSA.generate(2048) - self.key = RSAKey( - key=rsa_key, - kid=self.rsa_key_id - ) - self.public_key = rsa_key.publickey().export_key() + rsa_key = RSA.generate(2048).export_key('PEM') + algo_obj = jwt.get_algorithm_by_name('RS256') + self.key = algo_obj.prepare_key(rsa_key) + self.public_key = self.key.public_key() # Key handler self.key_handler = None @@ -272,9 +270,7 @@ def test_get_keyset_with_pub_key(self): self.rsa_key_id ) - # pylint: disable=unused-argument - @patch('time.time', return_value=1000) - def test_validate_and_decode(self, mock_time): + def test_validate_and_decode(self): """ Check that the validate and decode works. """ @@ -283,7 +279,7 @@ def test_validate_and_decode(self, mock_time): message = { "test": "test_message", "iat": 1000, - "exp": 1200, + "exp": int(math.floor(time.time()) + 1000), } signed = create_jwt(self.key, message) @@ -291,9 +287,7 @@ def test_validate_and_decode(self, mock_time): decoded_message = self.key_handler.validate_and_decode(signed) self.assertEqual(decoded_message, message) - # pylint: disable=unused-argument - @patch('time.time', return_value=1000) - def test_validate_and_decode_expired(self, mock_time): + def test_validate_and_decode_expired(self): """ Check that the validate and decode raises when signature expires. """ @@ -307,7 +301,7 @@ def test_validate_and_decode_expired(self, mock_time): signed = create_jwt(self.key, message) # Decode and check results - with self.assertRaises(exceptions.TokenSignatureExpired): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode(signed) def test_validate_and_decode_no_keys(self): @@ -324,14 +318,13 @@ def test_validate_and_decode_no_keys(self): signed = create_jwt(self.key, message) # Decode and check results - with self.assertRaises(exceptions.NoSuitableKeys): + with self.assertRaises(jwt.InvalidTokenError): key_handler.validate_and_decode(signed) - @patch("lti_consumer.lti_1p3.key_handlers.JWS.verify_compact") - def test_validate_and_decode_bad_signature(self, mock_verify_compact): - mock_verify_compact.side_effect = BadSignature() - - key_handler = ToolKeyHandler() + @patch("lti_consumer.lti_1p3.key_handlers.jwt.decode") + def test_validate_and_decode_bad_signature(self, mock_jwt_decode): + mock_jwt_decode.side_effect = Exception() + self._setup_key_handler() message = { "test": "test_message", @@ -340,6 +333,5 @@ def test_validate_and_decode_bad_signature(self, mock_verify_compact): } signed = create_jwt(self.key, message) - # Decode and check results - with self.assertRaises(exceptions.BadJwtSignature): - key_handler.validate_and_decode(signed) + with self.assertRaises(jwt.InvalidTokenError): + self.key_handler.validate_and_decode(signed) diff --git a/lti_consumer/lti_1p3/tests/utils.py b/lti_consumer/lti_1p3/tests/utils.py index 3a76d162..3aae56ca 100644 --- a/lti_consumer/lti_1p3/tests/utils.py +++ b/lti_consumer/lti_1p3/tests/utils.py @@ -1,12 +1,14 @@ """ Test utils """ -from jwkest.jws import JWS +import jwt def create_jwt(key, message): """ Uses private key to create a JWS from a dict. """ - jws = JWS(message, alg="RS256", cty="JWT") - return jws.sign_compact([key]) + token = jwt.encode( + message, key, algorithm='RS256' + ) + return token diff --git a/lti_consumer/plugin/views.py b/lti_consumer/plugin/views.py index 7c211647..773ac2a4 100644 --- a/lti_consumer/plugin/views.py +++ b/lti_consumer/plugin/views.py @@ -2,8 +2,10 @@ LTI consumer plugin passthrough views """ import logging +import sys import urllib +import jwt from django.conf import settings from django.contrib.auth import get_user_model from django.core.exceptions import PermissionDenied, ValidationError @@ -16,7 +18,6 @@ from django.views.decorators.http import require_http_methods from django_filters.rest_framework import DjangoFilterBackend from edx_django_utils.cache import TieredCache, get_cache_key -from jwkest.jwt import JWT, BadSyntax from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import UsageKey from rest_framework import status, viewsets @@ -475,20 +476,22 @@ def access_token_endpoint( )) ) return JsonResponse(token) - - # Handle errors and return a proper response - except MissingRequiredClaim: - # Missing request attributes - return JsonResponse({"error": "invalid_request"}, status=HTTP_400_BAD_REQUEST) - except (MalformedJwtToken, TokenSignatureExpired): - # Triggered when an invalid grant token is used - return JsonResponse({"error": "invalid_grant"}, status=HTTP_400_BAD_REQUEST) - except (NoSuitableKeys, UnknownClientId): - # Client ID is not registered in the block or - # isn't possible to validate token using available keys. - return JsonResponse({"error": "invalid_client"}, status=HTTP_400_BAD_REQUEST) - except UnsupportedGrantType: - return JsonResponse({"error": "unsupported_grant_type"}, status=HTTP_400_BAD_REQUEST) + except Exception as token_error: + exc_info = sys.exc_info() + + # Handle errors and return a proper response + if exc_info[0] == MissingRequiredClaim: + # Missing request attributes + return JsonResponse({"error": "invalid_request"}, status=HTTP_400_BAD_REQUEST) + elif exc_info[0] in (MalformedJwtToken, TokenSignatureExpired, jwt.InvalidTokenError): + # Triggered when a invalid grant token is used + return JsonResponse({"error": "invalid_grant"}, status=HTTP_400_BAD_REQUEST) + elif exc_info[0] == UnsupportedGrantType: + return JsonResponse({"error": "unsupported_grant_type"}, status=HTTP_400_BAD_REQUEST) + else: + # Client ID is not registered in the block or + # isn't possible to validate token using available keys. + return JsonResponse({"error": "invalid_client"}, status=HTTP_400_BAD_REQUEST) # Post from external tool that doesn't @@ -868,13 +871,12 @@ def start_proctoring_assessment_endpoint(request): token = request.POST.get('JWT') try: - jwt = JWT().unpack(token) - except BadSyntax: + decoded_jwt = jwt.decode(token, options={'verify_signature': False}) + except Exception: return render(request, 'html/lti_proctoring_start_error.html', status=HTTP_400_BAD_REQUEST) - jwt_payload = jwt.payload() - iss = jwt_payload.get('iss') - resource_link_id = jwt_payload.get('https://purl.imsglobal.org/spec/lti/claim/resource_link', {}).get('id') + iss = decoded_jwt.get('iss') + resource_link_id = decoded_jwt.get('https://purl.imsglobal.org/spec/lti/claim/resource_link', {}).get('id') try: lti_config = LtiConfiguration.objects.get(lti_1p3_client_id=iss) diff --git a/lti_consumer/tests/unit/plugin/test_proctoring.py b/lti_consumer/tests/unit/plugin/test_proctoring.py index 67da1001..5f4e8167 100644 --- a/lti_consumer/tests/unit/plugin/test_proctoring.py +++ b/lti_consumer/tests/unit/plugin/test_proctoring.py @@ -137,8 +137,8 @@ def test_valid_token(self): def test_unparsable_token(self): """Tests that a call to the start_assessment_endpoint with an unparsable token results in a 400 response.""" - with patch("lti_consumer.plugin.views.JWT.unpack") as mock_jwt_unpack_method: - mock_jwt_unpack_method.side_effect = BadSyntax(value="", msg="") + with patch("lti_consumer.plugin.views.jwt.decode") as mock_jwt_decode_method: + mock_jwt_decode_method.side_effect = Exception response = self.client.post( self.url, diff --git a/lti_consumer/tests/unit/plugin/test_views.py b/lti_consumer/tests/unit/plugin/test_views.py index a6226448..0b7bca3a 100644 --- a/lti_consumer/tests/unit/plugin/test_views.py +++ b/lti_consumer/tests/unit/plugin/test_views.py @@ -5,7 +5,7 @@ from unittest.mock import patch, Mock import ddt - +import jwt from django.test.testcases import TestCase from django.urls import reverse from edx_django_utils.cache import TieredCache, get_cache_key @@ -674,8 +674,11 @@ def setUp(self): ) self.addCleanup(get_lti_consumer_patcher.stop) self._mock_xblock_handler = get_lti_consumer_patcher.start() - # Generate RSA - self.key = RSAKey(key=RSA.generate(2048), kid="1") + # Generate RSA and save exports + rsa_key = RSA.generate(2048).export_key('PEM') + algo_obj = jwt.get_algorithm_by_name('RS256') + self.key = algo_obj.prepare_key(rsa_key) + self.public_key = self.key.public_key() def get_body(self, token, **overrides): """ diff --git a/lti_consumer/tests/unit/test_lti_xblock.py b/lti_consumer/tests/unit/test_lti_xblock.py index 153ca812..ee3e3faa 100644 --- a/lti_consumer/tests/unit/test_lti_xblock.py +++ b/lti_consumer/tests/unit/test_lti_xblock.py @@ -2,6 +2,7 @@ Unit tests for LtiConsumerXBlock """ import json +import jwt import logging import string from datetime import timedelta @@ -1935,11 +1936,11 @@ def setUp(self): self.rsa_key_id = "1" # Generate RSA and save exports rsa_key = RSA.generate(2048) - self.key = RSAKey( - key=rsa_key, - kid=self.rsa_key_id - ) - self.public_key = rsa_key.publickey().export_key() + pem = rsa_key.export_key('PEM') + algo_obj = jwt.get_algorithm_by_name('RS256') + self.key = algo_obj.prepare_key(pem) + self.public_key = rsa_key.public_key().export_key('PEM') + self.xblock_attributes = { 'lti_version': 'lti_1p3', diff --git a/requirements/base.in b/requirements/base.in index 491efaa0..13f6cdb3 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -9,6 +9,7 @@ oauthlib mako lazy XBlock +pyjwt pycryptodomex pyjwkest edx-opaque-keys[django] diff --git a/requirements/base.txt b/requirements/base.txt index f267a79e..91fe7b91 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -116,6 +116,8 @@ pycryptodomex==3.21.0 # pyjwkest pyjwkest==1.4.2 # via -r requirements/base.in +pyjwt==2.6.0 + # via -r requirements/base.in pymongo==4.10.1 # via edx-opaque-keys pynacl==1.5.0 diff --git a/requirements/ci.txt b/requirements/ci.txt index bd2feec9..28787688 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -333,6 +333,8 @@ pygments==2.18.0 # rich pyjwkest==1.4.2 # via -r requirements/test.txt +pyjwt==2.6.0 + # via -r requirements/test.txt pylint==3.3.2 # via # -r requirements/test.txt diff --git a/requirements/dev.txt b/requirements/dev.txt index 720e79c3..067f4ede 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -170,6 +170,8 @@ pycryptodomex==3.21.0 # pyjwkest pyjwkest==1.4.2 # via -r requirements/base.txt +pyjwt==2.6.0 + # via -r requirements/base.txt pymongo==4.10.1 # via # -r requirements/base.txt diff --git a/requirements/quality.txt b/requirements/quality.txt index ec3758d4..09f92685 100644 --- a/requirements/quality.txt +++ b/requirements/quality.txt @@ -208,6 +208,8 @@ pygments==2.18.0 # via rich pyjwkest==1.4.2 # via -r requirements/base.txt +pyjwt==2.6.0 + # via -r requirements/base.txt pylint==3.3.2 # via # -r requirements/quality.in diff --git a/requirements/test.txt b/requirements/test.txt index 734dd5a5..55688b48 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -250,6 +250,8 @@ pygments==2.18.0 # rich pyjwkest==1.4.2 # via -r requirements/base.txt +pyjwt==2.6.0 + # via -r requirements/base.txt pylint==3.3.2 # via # edx-lint From 65ffe2943efa4b570250f3f6aa2d60b69deed843 Mon Sep 17 00:00:00 2001 From: Muhammad Umar Khan Date: Mon, 6 Nov 2023 14:10:04 +0500 Subject: [PATCH 2/4] squash! fix quality issues --- lti_consumer/lti_1p3/key_handlers.py | 30 ++++++++++----------- lti_consumer/lti_1p3/tests/test_consumer.py | 16 +++++------ lti_consumer/tests/unit/test_lti_xblock.py | 1 - 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/lti_consumer/lti_1p3/key_handlers.py b/lti_consumer/lti_1p3/key_handlers.py index 7993d2e6..b5fac3dd 100644 --- a/lti_consumer/lti_1p3/key_handlers.py +++ b/lti_consumer/lti_1p3/key_handlers.py @@ -112,11 +112,11 @@ def validate_and_decode(self, token): for i in range(len(key_set)): try: message = jwt.decode( - token, - key=key_set[i], - algorithms=['RS256', 'RS512',], - options={'verify_signature': True} - ) + token, + key=key_set[i], + algorithms=['RS256', 'RS512',], + options={'verify_signature': True} + ) return message except Exception: if i == len(key_set) - 1: @@ -201,16 +201,16 @@ def validate_and_decode(self, token, iss=None, aud=None): raise exceptions.RsaKeyNotSet() try: message = jwt.decode( - token, - key=self.key.public_key(), - audience=aud, - issuer=iss, - algorithms=['RS256', 'RS512'], - options={ - 'verify_signature': True, - 'verify_aud': True if aud else False - } - ) + token, + key=self.key.public_key(), + audience=aud, + issuer=iss, + algorithms=['RS256', 'RS512'], + options={ + 'verify_signature': True, + 'verify_aud': True if aud else False + } + ) return message except Exception as token_error: diff --git a/lti_consumer/lti_1p3/tests/test_consumer.py b/lti_consumer/lti_1p3/tests/test_consumer.py index fa7fb1af..f86144a0 100644 --- a/lti_consumer/lti_1p3/tests/test_consumer.py +++ b/lti_consumer/lti_1p3/tests/test_consumer.py @@ -125,14 +125,14 @@ def _decode_token(self, token): for i in range(len(keyset)): try: message = jwt.decode( - token, - key=keyset[i].key, - algorithms=['RS256', 'RS512'], - options={ - 'verify_signature': True, - 'verify_aud': False - } - ) + token, + key=keyset[i].key, + algorithms=['RS256', 'RS512'], + options={ + 'verify_signature': True, + 'verify_aud': False + } + ) return message except Exception as token_error: if i < len(keyset) - 1: diff --git a/lti_consumer/tests/unit/test_lti_xblock.py b/lti_consumer/tests/unit/test_lti_xblock.py index ee3e3faa..c2dde8f5 100644 --- a/lti_consumer/tests/unit/test_lti_xblock.py +++ b/lti_consumer/tests/unit/test_lti_xblock.py @@ -1941,7 +1941,6 @@ def setUp(self): self.key = algo_obj.prepare_key(pem) self.public_key = rsa_key.public_key().export_key('PEM') - self.xblock_attributes = { 'lti_version': 'lti_1p3', 'lti_1p3_launch_url': 'http://tool.example/launch', From 8d63c769ff359c7f31e6b0dd4e5aede2e17a5e0a Mon Sep 17 00:00:00 2001 From: Muhammad Umar Khan Date: Thu, 11 Jan 2024 13:47:46 +0500 Subject: [PATCH 3/4] fix: remove useless tests --- lti_consumer/lti_1p3/key_handlers.py | 82 +++++++++++-------- lti_consumer/lti_1p3/tests/test_consumer.py | 33 ++++---- .../lti_1p3/tests/test_key_handlers.py | 74 ++++------------- lti_consumer/lti_1p3/tests/utils.py | 2 +- lti_consumer/plugin/views.py | 19 +++-- .../tests/unit/plugin/test_proctoring.py | 6 -- lti_consumer/tests/unit/plugin/test_views.py | 11 ++- .../tests/unit/plugin/test_views_lti_ags.py | 6 -- .../plugin/test_views_lti_deep_linking.py | 9 -- .../tests/unit/plugin/test_views_lti_nrps.py | 6 -- lti_consumer/tests/unit/test_lti_xblock.py | 64 +++++++++------ lti_consumer/tests/unit/test_models.py | 6 -- 12 files changed, 137 insertions(+), 181 deletions(-) diff --git a/lti_consumer/lti_1p3/key_handlers.py b/lti_consumer/lti_1p3/key_handlers.py index b5fac3dd..beb6eedf 100644 --- a/lti_consumer/lti_1p3/key_handlers.py +++ b/lti_consumer/lti_1p3/key_handlers.py @@ -7,13 +7,13 @@ import copy import json import math -import time import sys +import time import logging import jwt -from Cryptodome.PublicKey import RSA from edx_django_utils.monitoring import function_trace +from jwt.api_jwk import PyJWK from . import exceptions @@ -52,7 +52,9 @@ def __init__(self, public_key=None, keyset_url=None): try: # Import Key and save to internal state algo_obj = jwt.get_algorithm_by_name('RS256') - self.public_key = algo_obj.prepare_key(public_key) + public_key = algo_obj.prepare_key(public_key) + public_jwk = json.loads(algo_obj.to_jwk(public_key)) + self.public_key = PyJWK.from_dict(public_jwk) except ValueError as err: log.warning( 'An error was encountered while loading the LTI tool\'s key from the public key. ' @@ -82,15 +84,16 @@ def _get_keyset(self, kid=None): 'The RSA keys could not be loaded.' ) raise exceptions.NoSuitableKeys() from err - keyset.extend(keys) + keyset.extend(keys.keys) + + if self.public_key and kid: + # Fill in key id of stored key. + # This is needed because if the JWS is signed with a + # key with a kid, pyjwkest doesn't match them with + # keys without kid (kid=None) and fails verification + self.public_key.kid = kid if self.public_key: - if kid: - # Fill in key id of stored key. - # This is needed because if the JWS is signed with a - # key with a kid, pyjwkest doesn't match them with - # keys without kid (kid=None) and fails verification - self.public_key.kid = kid # Add to keyset keyset.append(self.public_key) @@ -105,25 +108,29 @@ def validate_and_decode(self, token): The authorization server decodes the JWT and MUST validate the values for the iss, sub, exp, aud and jti claims. """ - try: - key_set = self._get_keyset() - if not key_set: - raise exceptions.NoSuitableKeys() - for i in range(len(key_set)): - try: - message = jwt.decode( - token, - key=key_set[i], - algorithms=['RS256', 'RS512',], - options={'verify_signature': True} - ) - return message - except Exception: - if i == len(key_set) - 1: - raise - except Exception as token_error: - exc_info = sys.exc_info() - raise jwt.InvalidTokenError(exc_info[2]) from token_error + key_set = self._get_keyset() + + for i, obj in enumerate(key_set): + try: + if hasattr(obj.key, 'public_key'): + key = obj.key.public_key() + else: + key = obj.key + message = jwt.decode( + token, + key, + algorithms=['RS256', 'RS512',], + options={ + 'verify_signature': True, + 'verify_aud': False + } + ) + return message + except Exception: # pylint: disable=broad-except + if i == len(key_set) - 1: + raise + + raise exceptions.NoSuitableKeys() class PlatformKeyHandler: @@ -144,7 +151,10 @@ def __init__(self, key_pem, kid=None): # Import JWK from RSA key try: algo = jwt.get_algorithm_by_name('RS256') - self.key = algo.prepare_key(key_pem) + private_key = algo.prepare_key(key_pem) + private_jwk = json.loads(algo.to_jwk(private_key)) + private_jwk['kid'] = kid + self.key = PyJWK.from_dict(private_jwk) except ValueError as err: log.warning( 'An error was encountered while loading the LTI platform\'s key. ' @@ -175,7 +185,7 @@ def encode_and_sign(self, message, expiration=None): # The class instance that sets up the signing operation # An RS 256 key is required for LTI 1.3 - return jwt.encode(_message, self.key, algorithm="RS256") + return jwt.encode(_message, self.key.key, algorithm="RS256") def get_public_jwk(self): """ @@ -186,11 +196,11 @@ def get_public_jwk(self): # Only append to keyset if a key exists if self.key: algo_obj = jwt.get_algorithm_by_name('RS256') - public_key = algo_obj.prepare_key(self.key).public_key() + public_key = algo_obj.prepare_key(self.key.key).public_key() jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key))) return jwk - def validate_and_decode(self, token, iss=None, aud=None): + def validate_and_decode(self, token, iss=None, aud=None, exp=True): """ Check if a platform token is valid, and return allowed scopes. @@ -202,13 +212,15 @@ def validate_and_decode(self, token, iss=None, aud=None): try: message = jwt.decode( token, - key=self.key.public_key(), + key=self.key.key.public_key(), audience=aud, issuer=iss, algorithms=['RS256', 'RS512'], options={ 'verify_signature': True, - 'verify_aud': True if aud else False + 'verify_exp': bool(exp), + 'verify_iss': bool(iss), + 'verify_aud': bool(aud) } ) return message diff --git a/lti_consumer/lti_1p3/tests/test_consumer.py b/lti_consumer/lti_1p3/tests/test_consumer.py index f86144a0..7b9bac2e 100644 --- a/lti_consumer/lti_1p3/tests/test_consumer.py +++ b/lti_consumer/lti_1p3/tests/test_consumer.py @@ -8,7 +8,6 @@ import ddt import jwt -import sys from Cryptodome.PublicKey import RSA from django.conf import settings from django.test.testcases import TestCase @@ -115,30 +114,26 @@ def _get_lti_message( def _decode_token(self, token): """ - Checks for a valid signarute and decodes JWT signed LTI message + Checks for a valid signature and decodes JWT signed LTI message This also tests the public keyset function. """ public_keyset = self.lti_consumer.get_public_keyset() keyset = PyJWKSet.from_dict(public_keyset).keys - for i in range(len(keyset)): - try: - message = jwt.decode( - token, - key=keyset[i].key, - algorithms=['RS256', 'RS512'], - options={ - 'verify_signature': True, - 'verify_aud': False - } - ) - return message - except Exception as token_error: - if i < len(keyset) - 1: - continue - exc_info = sys.exc_info() - raise jwt.InvalidTokenError(exc_info[2]) from token_error + for obj in keyset: + message = jwt.decode( + token, + key=obj.key, + algorithms=['RS256', 'RS512'], + options={ + 'verify_signature': True, + 'verify_aud': False + } + ) + return message + + return exceptions.NoSuitableKeys() @ddt.data( ({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True), diff --git a/lti_consumer/lti_1p3/tests/test_key_handlers.py b/lti_consumer/lti_1p3/tests/test_key_handlers.py index bc5710aa..e7097f44 100644 --- a/lti_consumer/lti_1p3/tests/test_key_handlers.py +++ b/lti_consumer/lti_1p3/tests/test_key_handlers.py @@ -5,16 +5,14 @@ import json import math import time +from datetime import datetime, timezone from unittest.mock import patch import ddt import jwt from Cryptodome.PublicKey import RSA from django.test.testcases import TestCase -from jwkest import BadSignature -from jwkest.jwk import RSAKey, load_jwks -from jwkest.jws import JWS, NoSuitableSigningKeys, UnknownAlgorithm - +from jwt.api_jwk import PyJWK from lti_consumer.lti_1p3 import exceptions from lti_consumer.lti_1p3.key_handlers import PlatformKeyHandler, ToolKeyHandler @@ -39,16 +37,13 @@ def setUp(self): kid=self.rsa_key_id ) - def _decode_token(self, token): + def _decode_token(self, token, exp=True): """ - Checks for a valid signarute and decodes JWT signed LTI message + Checks for a valid signature and decodes JWT signed LTI message This also touches the public keyset method. """ - public_keyset = self.key_handler.get_public_jwk() - key_set = load_jwks(json.dumps(public_keyset)) - - return JWS().verify_compact(token, keys=key_set) + return self.key_handler.validate_and_decode(token, exp=exp) def test_encode_and_sign(self): """ @@ -59,7 +54,7 @@ def test_encode_and_sign(self): } signed_token = self.key_handler.encode_and_sign(message) self.assertEqual( - self._decode_token(signed_token), + self._decode_token(signed_token, exp=False), message ) @@ -72,10 +67,10 @@ def test_encode_and_sign_with_exp(self, mock_time): message = { "test": "test" } - + expiration = int(datetime.now(tz=timezone.utc).timestamp()) signed_token = self.key_handler.encode_and_sign( message, - expiration=1000 + expiration=expiration ) self.assertEqual( @@ -83,34 +78,10 @@ def test_encode_and_sign_with_exp(self, mock_time): { "test": "test", "iat": 1000, - "exp": 2000 + "exp": expiration + 1000 } ) - def test_encode_and_sign_no_suitable_keys(self): - """ - Test if an exception is raised when there are no suitable keys when signing the JWT. - """ - message = { - "test": "test" - } - - with patch('lti_consumer.lti_1p3.key_handlers.JWS.sign_compact', side_effect=NoSuitableSigningKeys): - with self.assertRaises(exceptions.NoSuitableKeys): - self.key_handler.encode_and_sign(message) - - def test_encode_and_sign_unknown_algorithm(self): - """ - Test if an exception is raised when the signing algorithm is unknown when signing the JWT. - """ - message = { - "test": "test" - } - - with patch('lti_consumer.lti_1p3.key_handlers.JWS.sign_compact', side_effect=UnknownAlgorithm): - with self.assertRaises(exceptions.MalformedJwtToken): - self.key_handler.encode_and_sign(message) - def test_invalid_rsa_key(self): """ Check that class raises when trying to import invalid RSA Key. @@ -217,10 +188,14 @@ def setUp(self): self.rsa_key_id = "1" # Generate RSA and save exports - rsa_key = RSA.generate(2048).export_key('PEM') + rsa_key = RSA.generate(2048) algo_obj = jwt.get_algorithm_by_name('RS256') - self.key = algo_obj.prepare_key(rsa_key) - self.public_key = self.key.public_key() + private_key = algo_obj.prepare_key(rsa_key.export_key()) + private_jwk = json.loads(algo_obj.to_jwk(private_key)) + private_jwk['kid'] = self.rsa_key_id + self.key = PyJWK.from_dict(private_jwk) + + self.public_key = rsa_key.publickey().export_key() # Key handler self.key_handler = None @@ -318,20 +293,5 @@ def test_validate_and_decode_no_keys(self): signed = create_jwt(self.key, message) # Decode and check results - with self.assertRaises(jwt.InvalidTokenError): + with self.assertRaises(exceptions.NoSuitableKeys): key_handler.validate_and_decode(signed) - - @patch("lti_consumer.lti_1p3.key_handlers.jwt.decode") - def test_validate_and_decode_bad_signature(self, mock_jwt_decode): - mock_jwt_decode.side_effect = Exception() - self._setup_key_handler() - - message = { - "test": "test_message", - "iat": 1000, - "exp": 1200, - } - signed = create_jwt(self.key, message) - - with self.assertRaises(jwt.InvalidTokenError): - self.key_handler.validate_and_decode(signed) diff --git a/lti_consumer/lti_1p3/tests/utils.py b/lti_consumer/lti_1p3/tests/utils.py index 3aae56ca..607b6f37 100644 --- a/lti_consumer/lti_1p3/tests/utils.py +++ b/lti_consumer/lti_1p3/tests/utils.py @@ -9,6 +9,6 @@ def create_jwt(key, message): Uses private key to create a JWS from a dict. """ token = jwt.encode( - message, key, algorithm='RS256' + message, key.key, algorithm='RS256' ) return token diff --git a/lti_consumer/plugin/views.py b/lti_consumer/plugin/views.py index 773ac2a4..a9967e4a 100644 --- a/lti_consumer/plugin/views.py +++ b/lti_consumer/plugin/views.py @@ -476,22 +476,27 @@ def access_token_endpoint( )) ) return JsonResponse(token) - except Exception as token_error: + except Exception: # pylint: disable=broad-except exc_info = sys.exc_info() + # import pdb; pdb.set_trace() # Handle errors and return a proper response if exc_info[0] == MissingRequiredClaim: # Missing request attributes return JsonResponse({"error": "invalid_request"}, status=HTTP_400_BAD_REQUEST) - elif exc_info[0] in (MalformedJwtToken, TokenSignatureExpired, jwt.InvalidTokenError): + elif exc_info[0] in (MalformedJwtToken, TokenSignatureExpired, jwt.exceptions.DecodeError): # Triggered when a invalid grant token is used return JsonResponse({"error": "invalid_grant"}, status=HTTP_400_BAD_REQUEST) - elif exc_info[0] == UnsupportedGrantType: - return JsonResponse({"error": "unsupported_grant_type"}, status=HTTP_400_BAD_REQUEST) - else: + elif exc_info[0] in (NoSuitableKeys, UnknownClientId, + jwt.exceptions.InvalidSignatureError, + KeyError, AttributeError): # Client ID is not registered in the block or # isn't possible to validate token using available keys. return JsonResponse({"error": "invalid_client"}, status=HTTP_400_BAD_REQUEST) + elif exc_info[0] == UnsupportedGrantType: + return JsonResponse({"error": "unsupported_grant_type"}, status=HTTP_400_BAD_REQUEST) + else: + return JsonResponse({"error": "unidentified_error"}, status=HTTP_400_BAD_REQUEST) # Post from external tool that doesn't @@ -572,7 +577,7 @@ def deep_linking_response_endpoint(request, lti_config_id=None): status=400 ) # Bad JWT message, invalid token, or any other message validation issues - except (Lti1p3Exception, PermissionDenied) as exc: + except (Lti1p3Exception, PermissionDenied, jwt.exceptions.DecodeError) as exc: log.warning( "Permission on LTI Config %r denied for user %r: %s", lti_config, @@ -872,7 +877,7 @@ def start_proctoring_assessment_endpoint(request): try: decoded_jwt = jwt.decode(token, options={'verify_signature': False}) - except Exception: + except Exception: # pylint: disable=broad-except return render(request, 'html/lti_proctoring_start_error.html', status=HTTP_400_BAD_REQUEST) iss = decoded_jwt.get('iss') diff --git a/lti_consumer/tests/unit/plugin/test_proctoring.py b/lti_consumer/tests/unit/plugin/test_proctoring.py index 5f4e8167..4cd5ea3b 100644 --- a/lti_consumer/tests/unit/plugin/test_proctoring.py +++ b/lti_consumer/tests/unit/plugin/test_proctoring.py @@ -9,8 +9,6 @@ from django.contrib.auth import get_user_model from django.test.testcases import TestCase from edx_django_utils.cache import TieredCache, get_cache_key -from jwkest.jwk import RSAKey -from jwkest.jwt import BadSyntax from lti_consumer.data import Lti1p3LaunchData, Lti1p3ProctoringLaunchData from lti_consumer.lti_1p3.exceptions import (BadJwtSignature, InvalidClaimValue, MalformedJwtToken, @@ -45,10 +43,6 @@ def setUp(self): # Set up a public key - private key pair that allows encoding and decoding a Tool JWT. self.rsa_key_id = str(uuid.uuid4()) self.private_key = RSA.generate(2048) - self.key = RSAKey( - key=self.private_key, - kid=self.rsa_key_id - ) self.public_key = self.private_key.publickey().export_key().decode() self.lti_config.lti_1p3_tool_public_key = self.public_key diff --git a/lti_consumer/tests/unit/plugin/test_views.py b/lti_consumer/tests/unit/plugin/test_views.py index 0b7bca3a..1efa781c 100644 --- a/lti_consumer/tests/unit/plugin/test_views.py +++ b/lti_consumer/tests/unit/plugin/test_views.py @@ -9,9 +9,9 @@ from django.test.testcases import TestCase from django.urls import reverse from edx_django_utils.cache import TieredCache, get_cache_key +from jwt.api_jwk import PyJWK from Cryptodome.PublicKey import RSA -from jwkest.jwk import RSAKey from opaque_keys.edx.keys import UsageKey from lti_consumer.data import Lti1p3LaunchData, Lti1p3ProctoringLaunchData from lti_consumer.models import LtiConfiguration, LtiDlContentItem @@ -675,10 +675,13 @@ def setUp(self): self.addCleanup(get_lti_consumer_patcher.stop) self._mock_xblock_handler = get_lti_consumer_patcher.start() # Generate RSA and save exports - rsa_key = RSA.generate(2048).export_key('PEM') + rsa_key = RSA.generate(2048) algo_obj = jwt.get_algorithm_by_name('RS256') - self.key = algo_obj.prepare_key(rsa_key) - self.public_key = self.key.public_key() + private_key = algo_obj.prepare_key(rsa_key.export_key()) + private_jwk = json.loads(algo_obj.to_jwk(private_key)) + private_jwk['kid'] = 1 + self.key = PyJWK.from_dict(private_jwk) + self.public_key = rsa_key.public_key().export_key('PEM') def get_body(self, token, **overrides): """ diff --git a/lti_consumer/tests/unit/plugin/test_views_lti_ags.py b/lti_consumer/tests/unit/plugin/test_views_lti_ags.py index 7a9e850f..5bb7973d 100644 --- a/lti_consumer/tests/unit/plugin/test_views_lti_ags.py +++ b/lti_consumer/tests/unit/plugin/test_views_lti_ags.py @@ -9,7 +9,6 @@ import ddt from django.urls import reverse from django.utils import timezone -from jwkest.jwk import RSAKey from rest_framework.test import APITransactionTestCase @@ -26,12 +25,7 @@ def setUp(self): super().setUp() # Create custom LTI Block - self.rsa_key_id = "1" rsa_key = RSA.generate(2048) - self.key = RSAKey( - key=rsa_key, - kid=self.rsa_key_id - ) self.public_key = rsa_key.publickey().export_key() self.xblock_attributes = { diff --git a/lti_consumer/tests/unit/plugin/test_views_lti_deep_linking.py b/lti_consumer/tests/unit/plugin/test_views_lti_deep_linking.py index a0ddc8b4..3c131fe6 100644 --- a/lti_consumer/tests/unit/plugin/test_views_lti_deep_linking.py +++ b/lti_consumer/tests/unit/plugin/test_views_lti_deep_linking.py @@ -6,7 +6,6 @@ import re import ddt from Cryptodome.PublicKey import RSA -from jwkest.jwk import RSAKey from rest_framework.test import APITransactionTestCase from rest_framework.exceptions import ValidationError @@ -37,14 +36,6 @@ def setUp(self): # Create custom LTI Block rsa_key = RSA.import_key(self.lti_config.lti_1p3_private_key) - self.key = RSAKey( - # Using the same key ID as client id - # This way we can easily serve multiple public - # keys on the same endpoint and keep all - # LTI 1.3 blocks working - kid=self.lti_config.lti_1p3_private_key_id, - key=rsa_key - ) self.public_key = rsa_key.publickey().export_key() self.xblock_attributes = { diff --git a/lti_consumer/tests/unit/plugin/test_views_lti_nrps.py b/lti_consumer/tests/unit/plugin/test_views_lti_nrps.py index 352e4533..e6ae56cf 100644 --- a/lti_consumer/tests/unit/plugin/test_views_lti_nrps.py +++ b/lti_consumer/tests/unit/plugin/test_views_lti_nrps.py @@ -3,7 +3,6 @@ """ from unittest.mock import Mock, patch from Cryptodome.PublicKey import RSA -from jwkest.jwk import RSAKey from rest_framework.test import APITransactionTestCase from rest_framework.reverse import reverse @@ -113,12 +112,7 @@ def setUp(self): super().setUp() # Create custom LTI Block - self.rsa_key_id = "1" rsa_key = RSA.generate(2048) - self.key = RSAKey( - key=rsa_key, - kid=self.rsa_key_id - ) self.public_key = rsa_key.publickey().export_key() self.xblock_attributes = { diff --git a/lti_consumer/tests/unit/test_lti_xblock.py b/lti_consumer/tests/unit/test_lti_xblock.py index c2dde8f5..6a23cf51 100644 --- a/lti_consumer/tests/unit/test_lti_xblock.py +++ b/lti_consumer/tests/unit/test_lti_xblock.py @@ -2,7 +2,6 @@ Unit tests for LtiConsumerXBlock """ import json -import jwt import logging import string from datetime import timedelta @@ -10,14 +9,15 @@ from unittest.mock import Mock, PropertyMock, patch import ddt +import jwt from Cryptodome.PublicKey import RSA from django.conf import settings as dj_settings from django.test import override_settings from django.test.testcases import TestCase from django.utils import timezone -from jwkest.jwk import RSAKey, KEYS from xblock.validation import Validation +from jwt.api_jwk import PyJWK, PyJWKSet from lti_consumer.exceptions import LtiError from lti_consumer.api import config_id_for_block @@ -1936,9 +1936,11 @@ def setUp(self): self.rsa_key_id = "1" # Generate RSA and save exports rsa_key = RSA.generate(2048) - pem = rsa_key.export_key('PEM') algo_obj = jwt.get_algorithm_by_name('RS256') - self.key = algo_obj.prepare_key(pem) + private_key = algo_obj.prepare_key(rsa_key.export_key()) + private_jwk = json.loads(algo_obj.to_jwk(private_key)) + private_jwk['kid'] = self.rsa_key_id + self.key = PyJWK.from_dict(private_jwk) self.public_key = rsa_key.public_key().export_key('PEM') self.xblock_attributes = { @@ -2019,8 +2021,8 @@ def test_access_token_invalid_client(self): self.xblock.lti_1p3_tool_public_key = '' self.xblock.save() - jwt = create_jwt(self.key, {}) - request = make_jwt_request(jwt) + jwt_token = create_jwt(self.key, {}) + request = make_jwt_request(jwt_token) response = self.xblock.lti_1p3_access_token(request) self.assertEqual(response.status_code, 400) self.assertJSONEqual(response.content, {'error': 'invalid_client'}) @@ -2029,8 +2031,8 @@ def test_access_token(self): """ Test request with valid JWT. """ - jwt = create_jwt(self.key, {}) - request = make_jwt_request(jwt) + jwt_token = create_jwt(self.key, {}) + request = make_jwt_request(jwt_token) response = self.xblock.lti_1p3_access_token(request) self.assertEqual(response.status_code, 200) @@ -2123,10 +2125,15 @@ def setUp(self): 'lti_1p3_tool_keyset_url': "http://tool.example/keyset", }) - self.key = RSAKey(key=RSA.generate(2048), kid="1") + rsa_key = RSA.generate(2048).export_key('PEM') + self.algo_obj = jwt.get_algorithm_by_name('RS256') + private_key = self.algo_obj.prepare_key(rsa_key) + private_jwk = json.loads(self.algo_obj.to_jwk(private_key)) + private_jwk['kid'] = '1' + self.key = PyJWK.from_dict(private_jwk) - jwt = create_jwt(self.key, {}) - self.request = make_jwt_request(jwt) + jwt_token = create_jwt(self.key, {}) + self.request = make_jwt_request(jwt_token) patcher = patch( 'lti_consumer.plugin.compat.load_enough_xblock', @@ -2139,37 +2146,44 @@ def make_keyset(self, keys): """ Builds a keyset object with the given keys. """ - jwks = KEYS() - jwks._keys = keys # pylint: disable=protected-access - return jwks + keys_dict = {'keys': []} + for key in keys: + keys_dict['keys'].append(key._jwk_data) # pylint: disable=protected-access + return PyJWKSet.from_dict(keys_dict) - @patch("lti_consumer.lti_1p3.key_handlers.load_jwks_from_url") - def test_access_token_using_keyset_url(self, load_jwks_from_url): + @patch("lti_consumer.lti_1p3.key_handlers.jwt.PyJWKClient.get_jwk_set") + def test_access_token_using_keyset_url(self, get_jwk_set): """ Test request using the provider's keyset URL instead of a public key. """ - load_jwks_from_url.return_value = self.make_keyset([self.key]) + # import pdb; pdb.set_trace() + get_jwk_set.return_value = self.make_keyset([self.key]) response = self.xblock.lti_1p3_access_token(self.request) - load_jwks_from_url.assert_called_once_with("http://tool.example/keyset") + get_jwk_set.assert_called_once() self.assertEqual(response.status_code, 200) - @patch("lti_consumer.lti_1p3.key_handlers.load_jwks_from_url") - def test_access_token_using_keyset_url_with_empty_keys(self, load_jwks_from_url): + @patch("lti_consumer.lti_1p3.key_handlers.jwt.PyJWKClient.get_jwk_set") + def test_access_token_using_keyset_url_with_empty_keys(self, get_jwk_set): """ Test request where the provider's keyset URL returns an empty list of keys. """ - load_jwks_from_url.return_value = self.make_keyset([]) + # get_jwk_set.return_value = self.make_keyset([]) + get_jwk_set.side_effect = jwt.exceptions.PyJWKSetError response = self.xblock.lti_1p3_access_token(self.request) self.assertEqual(response.status_code, 400) self.assertJSONEqual(response.content, {"error": "invalid_client"}) - @patch("lti_consumer.lti_1p3.key_handlers.load_jwks_from_url") - def test_access_token_using_keyset_url_with_wrong_keys(self, load_jwks_from_url): + @patch("lti_consumer.lti_1p3.key_handlers.jwt.PyJWKClient.get_jwk_set") + def test_access_token_using_keyset_url_with_wrong_keys(self, get_jwk_set): """ Test request where the provider's keyset URL returns wrong keys. """ - key = RSAKey(key=RSA.generate(2048), kid="2") - load_jwks_from_url.return_value = self.make_keyset([key]) + rsa_key = RSA.generate(2048).export_key('PEM') + private_key = self.algo_obj.prepare_key(rsa_key) + private_jwk = json.loads(self.algo_obj.to_jwk(private_key)) + private_jwk['kid'] = 2 + key = PyJWK.from_dict(private_jwk) + get_jwk_set.return_value = self.make_keyset([key]) response = self.xblock.lti_1p3_access_token(self.request) self.assertEqual(response.status_code, 400) self.assertJSONEqual(response.content, {"error": "invalid_client"}) diff --git a/lti_consumer/tests/unit/test_models.py b/lti_consumer/tests/unit/test_models.py index 78d10536..f49b2cb0 100644 --- a/lti_consumer/tests/unit/test_models.py +++ b/lti_consumer/tests/unit/test_models.py @@ -11,7 +11,6 @@ from django.test.testcases import TestCase from django.utils import timezone from edx_django_utils.cache import RequestCache -from jwkest.jwk import RSAKey from ccx_keys.locator import CCXBlockUsageLocator from opaque_keys.edx.locator import CourseLocator @@ -32,13 +31,8 @@ class TestLtiConfigurationModel(TestCase): def setUp(self): super().setUp() - self.rsa_key_id = "1" # Generate RSA and save exports rsa_key = RSA.generate(2048) - self.key = RSAKey( - key=rsa_key, - kid=self.rsa_key_id - ) self.public_key = rsa_key.publickey().export_key() self.xblock_attributes = { From 3bee2478f19ae0c86cd812bdd982d893a5263997 Mon Sep 17 00:00:00 2001 From: Alie Langston Date: Thu, 9 Jan 2025 09:13:45 -0500 Subject: [PATCH 4/4] docs: update init and changelog files --- CHANGELOG.rst | 4 ++++ lti_consumer/__init__.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d7d8d7cf..9b58fc29 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,10 @@ Please See the `releases tab