Skip to content

Commit

Permalink
restore oauth_uri variable and fix caching
Browse files Browse the repository at this point in the history
Restores the deleted `oauth_uri` variable of `OAuthAPIClient`. This is a
fix for our ProctorTrack integration, which was relying on the
`oauth_uri` variable as part of the contract. This may or may not be a
long-term solution.

Additionally, since `OAuthAPIClient` can be used to retrieve tokens
against multiple oauth endpoints, we add the oauth_url to the cache key
for the token.

BOM-1184
  • Loading branch information
robrap committed Jan 24, 2020
1 parent a0450c6 commit 56cd2c0
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 25 deletions.
2 changes: 1 addition & 1 deletion edx_rest_api_client/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '3.0.0'
__version__ = '3.0.1'
19 changes: 14 additions & 5 deletions edx_rest_api_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,12 @@ def get_and_cache_oauth_access_token(url, client_id, client_secret, token_type='
tuple: Tuple containing (access token string, expiration datetime).
"""
cache_key = 'edx_rest_api_client.access_token.{}.{}.{}'.format(
oauth_url = _get_oauth_url(url)
cache_key = 'edx_rest_api_client.access_token.{}.{}.{}.{}'.format(
token_type,
grant_type,
client_id,
oauth_url,
)
cached_response = TieredCache.get_cached_response(cache_key)

Expand All @@ -151,7 +153,7 @@ def get_and_cache_oauth_access_token(url, client_id, client_secret, token_type='

# Get a new access token if no unexpired access token was found in the cache.
oauth_access_token_response = get_oauth_access_token(
_get_oauth_url(url),
oauth_url,
client_id,
client_secret,
grant_type=grant_type,
Expand All @@ -176,6 +178,11 @@ class OAuthAPIClient(requests.Session):
See https://github.com/edx/edx-django-utils/blob/master/edx_django_utils/cache/README.rst#tieredcache
"""

# If the oauth_uri is set, it will be appended to the base_url.
# Also, if oauth_uri does not end with `/oauth2/access_token`, it will be adjusted as necessary to do so.
oauth_uri = None

def __init__(self, base_url, client_id, client_secret, **kwargs):
"""
Args:
Expand All @@ -189,12 +196,12 @@ def __init__(self, base_url, client_id, client_secret, **kwargs):
"""
super(OAuthAPIClient, self).__init__(**kwargs)
self.headers['user-agent'] = USER_AGENT
self.auth = SuppliedJwtAuth(None)

self._base_url = base_url.rstrip('/')
self._client_id = client_id
self._client_secret = client_secret

self.auth = SuppliedJwtAuth(None)

def _ensure_authentication(self):
"""
Ensures that the Session's auth.token is set with an unexpired token.
Expand All @@ -203,8 +210,10 @@ def _ensure_authentication(self):
requests.RequestException if there is a problem retrieving the access token.
"""
oauth_url = self._base_url if not self.oauth_uri else self._base_url + self.oauth_uri

oauth_access_token_response = get_and_cache_oauth_access_token(
self._base_url,
oauth_url,
self._client_id,
self._client_secret,
grant_type='client_credentials'
Expand Down
60 changes: 41 additions & 19 deletions edx_rest_api_client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

URL = 'http://example.com/api/v2'
OAUTH_URL = "http://test-auth.com/oauth2/access_token"
OAUTH_URL_2 = "http://test-auth.com/edx/oauth2/access_token"
SIGNING_KEY = 'edx'
USERNAME = 'edx'
FULL_NAME = 'édx äpp'
Expand Down Expand Up @@ -180,18 +181,18 @@ def test_token_caching(self):
"""
Test that tokens are cached based on client, token_type, and grant_type
"""
tokens = ['cred4', 'cred3', 'cred2', 'cred1']
tokens = [
'auth2-cred4', 'auth2-cred3', 'auth2-cred2', 'auth2-cred1',
'auth1-cred4', 'auth1-cred3', 'auth1-cred2', 'auth1-cred1',
]

def auth_callback(request): # pylint: disable=unused-argument
resp = {'expires_in': 60}
resp['access_token'] = 'no-more-credentials' if not tokens else tokens.pop()
return (200, {}, json.dumps(resp))

responses.add_callback(
responses.POST, OAUTH_URL,
callback=auth_callback,
content_type='application/json',
)
responses.add_callback(responses.POST, OAUTH_URL, callback=auth_callback, content_type='application/json')
responses.add_callback(responses.POST, OAUTH_URL_2, callback=auth_callback, content_type='application/json')

kwargs_list = [
{'client_id': 'test-id-1', 'token_type': "jwt", 'grant_type': 'client_credentials'},
Expand All @@ -200,24 +201,38 @@ def auth_callback(request): # pylint: disable=unused-argument
{'client_id': 'test-id-1', 'token_type': "jwt", 'grant_type': 'refresh_token'},
]

# initial requests should call the mock client and get the correct credentials
# initial requests to OAUTH_URL should call the mock client and get the correct credentials
for index, kwargs in enumerate(kwargs_list):
token_response = self._get_and_cache_oauth_access_token(**kwargs)
expected_token = 'cred{}'.format(index + 1)
token_response = self._get_and_cache_oauth_access_token(OAUTH_URL, **kwargs)
expected_token = 'auth1-cred{}'.format(index + 1)
self.assertEqual(token_response[0], expected_token)
self.assertEqual(len(responses.calls), 4)

# second set of requests should return the same credentials without making any new mock calls
# initial requests to OAUTH_URL_2 should call the mock client and get the correct credentials
for index, kwargs in enumerate(kwargs_list):
token_response = self._get_and_cache_oauth_access_token(**kwargs)
expected_token = 'cred{}'.format(index + 1)
token_response = self._get_and_cache_oauth_access_token(OAUTH_URL_2, **kwargs)
expected_token = 'auth2-cred{}'.format(index + 1)
self.assertEqual(token_response[0], expected_token)
self.assertEqual(len(responses.calls), 4)
self.assertEqual(len(responses.calls), 8)

# second set of requests to OAUTH_URL should return the same credentials without making any new mock calls
for index, kwargs in enumerate(kwargs_list):
token_response = self._get_and_cache_oauth_access_token(OAUTH_URL, **kwargs)
expected_token = 'auth1-cred{}'.format(index + 1)
self.assertEqual(token_response[0], expected_token)
self.assertEqual(len(responses.calls), 8)

def _get_and_cache_oauth_access_token(self, client_id, token_type, grant_type):
# second set of requests to OAUTH_URL_2 should return the same credentials without making any new mock calls
for index, kwargs in enumerate(kwargs_list):
token_response = self._get_and_cache_oauth_access_token(OAUTH_URL_2, **kwargs)
expected_token = 'auth2-cred{}'.format(index + 1)
self.assertEqual(token_response[0], expected_token)
self.assertEqual(len(responses.calls), 8)

def _get_and_cache_oauth_access_token(self, auth_url, client_id, token_type, grant_type):
refresh_token = 'test-refresh-token' if grant_type == 'refresh_token' else None
return get_and_cache_oauth_access_token(
OAUTH_URL, client_id, 'test-secret', token_type=token_type, grant_type=grant_type,
auth_url, client_id, 'test-secret', token_type=token_type, grant_type=grant_type,
refresh_token=refresh_token,
)

Expand All @@ -237,15 +252,22 @@ def setUp(self):

@responses.activate
@ddt.data(
'http://testing.test',
'http://testing.test/oauth2',
('http://testing.test', None, 'http://testing.test/oauth2/access_token'),
('http://testing.test', '/edx', 'http://testing.test/edx/oauth2/access_token'),
('http://testing.test', '/edx/oauth2', 'http://testing.test/edx/oauth2/access_token'),
('http://testing.test', '/edx/oauth2/access_token', 'http://testing.test/edx/oauth2/access_token'),
('http://testing.test/oauth2', None, 'http://testing.test/oauth2/access_token'),
('http://testing.test/test', '/edx/oauth2/access_token', 'http://testing.test/test/edx/oauth2/access_token'),
)
def test_automatic_auth(self, client_base_url):
@ddt.unpack
def test_automatic_auth(self, client_base_url, custom_oauth_uri, expected_oauth_url):
"""
Test that the JWT token is automatically set
"""
client_session = OAuthAPIClient(client_base_url, self.client_id, self.client_secret)
self._mock_auth_api(self.base_url + '/oauth2/access_token', 200, {'access_token': 'abcd', 'expires_in': 60})
client_session.oauth_uri = custom_oauth_uri

self._mock_auth_api(expected_oauth_url, 200, {'access_token': 'abcd', 'expires_in': 60})
self._mock_auth_api(self.base_url + '/endpoint', 200, {'status': 'ok'})
response = client_session.post(self.base_url + '/endpoint', data={'test': 'ok'})
self.assertIn('client_id=%s' % self.client_id, responses.calls[0].request.body)
Expand Down

0 comments on commit 56cd2c0

Please sign in to comment.