From 56cd2c06826005f888c6eeea05fd06c44944f0ec Mon Sep 17 00:00:00 2001 From: Robert Raposa Date: Thu, 23 Jan 2020 16:32:09 -0500 Subject: [PATCH] restore oauth_uri variable and fix caching Restores the deleted `oauth_uri` variable of `OAuthAPIClient`. This is a fix for our ProctorTrack integration, which was relying on the `oauth_uri` variable as part of the contract. This may or may not be a long-term solution. Additionally, since `OAuthAPIClient` can be used to retrieve tokens against multiple oauth endpoints, we add the oauth_url to the cache key for the token. BOM-1184 --- edx_rest_api_client/__version__.py | 2 +- edx_rest_api_client/client.py | 19 ++++++-- edx_rest_api_client/tests/test_client.py | 60 ++++++++++++++++-------- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/edx_rest_api_client/__version__.py b/edx_rest_api_client/__version__.py index 4eb28e3..b7a5531 100644 --- a/edx_rest_api_client/__version__.py +++ b/edx_rest_api_client/__version__.py @@ -1 +1 @@ -__version__ = '3.0.0' +__version__ = '3.0.1' diff --git a/edx_rest_api_client/client.py b/edx_rest_api_client/client.py index 6083740..b6eaac9 100644 --- a/edx_rest_api_client/client.py +++ b/edx_rest_api_client/client.py @@ -134,10 +134,12 @@ def get_and_cache_oauth_access_token(url, client_id, client_secret, token_type=' tuple: Tuple containing (access token string, expiration datetime). """ - cache_key = 'edx_rest_api_client.access_token.{}.{}.{}'.format( + oauth_url = _get_oauth_url(url) + cache_key = 'edx_rest_api_client.access_token.{}.{}.{}.{}'.format( token_type, grant_type, client_id, + oauth_url, ) cached_response = TieredCache.get_cached_response(cache_key) @@ -151,7 +153,7 @@ def get_and_cache_oauth_access_token(url, client_id, client_secret, token_type=' # Get a new access token if no unexpired access token was found in the cache. oauth_access_token_response = get_oauth_access_token( - _get_oauth_url(url), + oauth_url, client_id, client_secret, grant_type=grant_type, @@ -176,6 +178,11 @@ class OAuthAPIClient(requests.Session): See https://github.com/edx/edx-django-utils/blob/master/edx_django_utils/cache/README.rst#tieredcache """ + + # If the oauth_uri is set, it will be appended to the base_url. + # Also, if oauth_uri does not end with `/oauth2/access_token`, it will be adjusted as necessary to do so. + oauth_uri = None + def __init__(self, base_url, client_id, client_secret, **kwargs): """ Args: @@ -189,12 +196,12 @@ def __init__(self, base_url, client_id, client_secret, **kwargs): """ super(OAuthAPIClient, self).__init__(**kwargs) self.headers['user-agent'] = USER_AGENT + self.auth = SuppliedJwtAuth(None) + self._base_url = base_url.rstrip('/') self._client_id = client_id self._client_secret = client_secret - self.auth = SuppliedJwtAuth(None) - def _ensure_authentication(self): """ Ensures that the Session's auth.token is set with an unexpired token. @@ -203,8 +210,10 @@ def _ensure_authentication(self): requests.RequestException if there is a problem retrieving the access token. """ + oauth_url = self._base_url if not self.oauth_uri else self._base_url + self.oauth_uri + oauth_access_token_response = get_and_cache_oauth_access_token( - self._base_url, + oauth_url, self._client_id, self._client_secret, grant_type='client_credentials' diff --git a/edx_rest_api_client/tests/test_client.py b/edx_rest_api_client/tests/test_client.py index 3dacbd3..6a66f63 100644 --- a/edx_rest_api_client/tests/test_client.py +++ b/edx_rest_api_client/tests/test_client.py @@ -24,6 +24,7 @@ URL = 'http://example.com/api/v2' OAUTH_URL = "http://test-auth.com/oauth2/access_token" +OAUTH_URL_2 = "http://test-auth.com/edx/oauth2/access_token" SIGNING_KEY = 'edx' USERNAME = 'edx' FULL_NAME = 'édx äpp' @@ -180,18 +181,18 @@ def test_token_caching(self): """ Test that tokens are cached based on client, token_type, and grant_type """ - tokens = ['cred4', 'cred3', 'cred2', 'cred1'] + tokens = [ + 'auth2-cred4', 'auth2-cred3', 'auth2-cred2', 'auth2-cred1', + 'auth1-cred4', 'auth1-cred3', 'auth1-cred2', 'auth1-cred1', + ] def auth_callback(request): # pylint: disable=unused-argument resp = {'expires_in': 60} resp['access_token'] = 'no-more-credentials' if not tokens else tokens.pop() return (200, {}, json.dumps(resp)) - responses.add_callback( - responses.POST, OAUTH_URL, - callback=auth_callback, - content_type='application/json', - ) + responses.add_callback(responses.POST, OAUTH_URL, callback=auth_callback, content_type='application/json') + responses.add_callback(responses.POST, OAUTH_URL_2, callback=auth_callback, content_type='application/json') kwargs_list = [ {'client_id': 'test-id-1', 'token_type': "jwt", 'grant_type': 'client_credentials'}, @@ -200,24 +201,38 @@ def auth_callback(request): # pylint: disable=unused-argument {'client_id': 'test-id-1', 'token_type': "jwt", 'grant_type': 'refresh_token'}, ] - # initial requests should call the mock client and get the correct credentials + # initial requests to OAUTH_URL should call the mock client and get the correct credentials for index, kwargs in enumerate(kwargs_list): - token_response = self._get_and_cache_oauth_access_token(**kwargs) - expected_token = 'cred{}'.format(index + 1) + token_response = self._get_and_cache_oauth_access_token(OAUTH_URL, **kwargs) + expected_token = 'auth1-cred{}'.format(index + 1) self.assertEqual(token_response[0], expected_token) self.assertEqual(len(responses.calls), 4) - # second set of requests should return the same credentials without making any new mock calls + # initial requests to OAUTH_URL_2 should call the mock client and get the correct credentials for index, kwargs in enumerate(kwargs_list): - token_response = self._get_and_cache_oauth_access_token(**kwargs) - expected_token = 'cred{}'.format(index + 1) + token_response = self._get_and_cache_oauth_access_token(OAUTH_URL_2, **kwargs) + expected_token = 'auth2-cred{}'.format(index + 1) self.assertEqual(token_response[0], expected_token) - self.assertEqual(len(responses.calls), 4) + self.assertEqual(len(responses.calls), 8) + + # second set of requests to OAUTH_URL should return the same credentials without making any new mock calls + for index, kwargs in enumerate(kwargs_list): + token_response = self._get_and_cache_oauth_access_token(OAUTH_URL, **kwargs) + expected_token = 'auth1-cred{}'.format(index + 1) + self.assertEqual(token_response[0], expected_token) + self.assertEqual(len(responses.calls), 8) - def _get_and_cache_oauth_access_token(self, client_id, token_type, grant_type): + # second set of requests to OAUTH_URL_2 should return the same credentials without making any new mock calls + for index, kwargs in enumerate(kwargs_list): + token_response = self._get_and_cache_oauth_access_token(OAUTH_URL_2, **kwargs) + expected_token = 'auth2-cred{}'.format(index + 1) + self.assertEqual(token_response[0], expected_token) + self.assertEqual(len(responses.calls), 8) + + def _get_and_cache_oauth_access_token(self, auth_url, client_id, token_type, grant_type): refresh_token = 'test-refresh-token' if grant_type == 'refresh_token' else None return get_and_cache_oauth_access_token( - OAUTH_URL, client_id, 'test-secret', token_type=token_type, grant_type=grant_type, + auth_url, client_id, 'test-secret', token_type=token_type, grant_type=grant_type, refresh_token=refresh_token, ) @@ -237,15 +252,22 @@ def setUp(self): @responses.activate @ddt.data( - 'http://testing.test', - 'http://testing.test/oauth2', + ('http://testing.test', None, 'http://testing.test/oauth2/access_token'), + ('http://testing.test', '/edx', 'http://testing.test/edx/oauth2/access_token'), + ('http://testing.test', '/edx/oauth2', 'http://testing.test/edx/oauth2/access_token'), + ('http://testing.test', '/edx/oauth2/access_token', 'http://testing.test/edx/oauth2/access_token'), + ('http://testing.test/oauth2', None, 'http://testing.test/oauth2/access_token'), + ('http://testing.test/test', '/edx/oauth2/access_token', 'http://testing.test/test/edx/oauth2/access_token'), ) - def test_automatic_auth(self, client_base_url): + @ddt.unpack + def test_automatic_auth(self, client_base_url, custom_oauth_uri, expected_oauth_url): """ Test that the JWT token is automatically set """ client_session = OAuthAPIClient(client_base_url, self.client_id, self.client_secret) - self._mock_auth_api(self.base_url + '/oauth2/access_token', 200, {'access_token': 'abcd', 'expires_in': 60}) + client_session.oauth_uri = custom_oauth_uri + + self._mock_auth_api(expected_oauth_url, 200, {'access_token': 'abcd', 'expires_in': 60}) self._mock_auth_api(self.base_url + '/endpoint', 200, {'status': 'ok'}) response = client_session.post(self.base_url + '/endpoint', data={'test': 'ok'}) self.assertIn('client_id=%s' % self.client_id, responses.calls[0].request.body)