Skip to content

Commit

Permalink
Merge pull request #85 from nansencenter/hotfix_store_token
Browse files Browse the repository at this point in the history
Hotfix store OAuth2 token
  • Loading branch information
aperrin66 authored Sep 25, 2023
2 parents 5f04ca0 + eca6e77 commit c2dd225
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 14 deletions.
63 changes: 49 additions & 14 deletions geospaas_processing/downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
"""
import errno
import ftplib
import hashlib
import logging
import os
import os.path
import pickle
import re
import shutil
from urllib.parse import urlparse
Expand Down Expand Up @@ -163,32 +165,64 @@ class HTTPDownloader(Downloader):
CHUNK_SIZE = 1024 * 1024

@classmethod
def build_oauth2_authentication(cls, username, password, token_url, client_id,
totp_secret=None):
"""Creates an OAuth2 object usable by `requests` methods"""
def get_oauth2_token(cls, username, password, token_url, client, totp_secret=None):
"""Try to get a token from Redis. If this fails, fetch one from the URL"""
token = None

LOGGER.debug("Attempting to get an OAuth2 token")
if Redis is not None and utils.REDIS_HOST and utils.REDIS_PORT: # cache available
cache = Redis(host=utils.REDIS_HOST, port=utils.REDIS_PORT)
key_hash = hashlib.sha1(bytes(token_url + username, encoding='utf-8')).hexdigest()
LOGGER.debug("Trying to retrieve OAuth2 token from the cache")
raw_token = cache.get(key_hash)
if raw_token is None: # did not get the token from the cache
token = cls.fetch_oauth2_token(username, password, token_url, client, totp_secret)
LOGGER.debug("Got OAuth2 token from URL")
cache.set(key_hash, pickle.dumps(token), ex=token['expires_in'])
LOGGER.debug("Stored Oauth2 token in the cache")
else: # successfully got the token from the cache
token = pickle.loads(raw_token)
LOGGER.debug("Got OAuth2 token from the cache")
else: # cache not available
LOGGER.debug("Cache not available, getting OAuth2 token from URL")
token = cls.fetch_oauth2_token(username, password, token_url, client, totp_secret)

return token


@classmethod
def fetch_oauth2_token(cls, username, password, token_url, client, totp_secret=None):
"""Fetches a new token from the URL"""
# TOTP passwords are valid for 30 seconds, so we retry a few
# times in case we get unlucky and the password expires between
# the generation of the password and the authentication request
session_args = {
'token_url': token_url,
'username': username,
'password': password,
'client_id': client.client_id,
}
retries = 5
while retries > 0:
client = oauthlib.oauth2.LegacyApplicationClient(client_id=client_id)
session_args = {
'token_url': token_url,
'username': username,
'password': password,
'client_id': client_id,
}
if totp_secret:
session_args['totp'] = pyotp.TOTP(totp_secret).now()
try:
if totp_secret:
session_args['totp'] = pyotp.TOTP(totp_secret).now()
token = requests_oauthlib.OAuth2Session(client=client).fetch_token(**session_args)
except oauthlib.oauth2.rfc6749.errors.InvalidGrantError:
retries -= 1
if retries > 0:
if retries > 0:
continue
else:
raise
return requests_oauthlib.OAuth2(client_id=client_id, client=client, token=token)
return token

@classmethod
def build_oauth2_authentication(cls, username, password, token_url, client_id,
totp_secret=None):
"""Creates an OAuth2 object usable by `requests` methods"""
client = oauthlib.oauth2.LegacyApplicationClient(client_id=client_id)
token = cls.get_oauth2_token(username, password, token_url, client, totp_secret)
return requests_oauthlib.OAuth2(client_id=client_id, client=client, token=token)

@classmethod
def get_auth(cls, kwargs):
Expand All @@ -211,6 +245,7 @@ def get_auth(cls, kwargs):

@classmethod
def get_request_parameters(cls, kwargs):
"""Retrieve and check request parameters from kwargs"""
parameters = kwargs.get('request_parameters', {})
if isinstance(parameters, dict):
return parameters
Expand Down
36 changes: 36 additions & 0 deletions tests/test_downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import os.path
import pickle
import tempfile
import unittest
import unittest.mock as mock
Expand Down Expand Up @@ -298,6 +299,41 @@ def test_get_oauth2_auth_with_totp(self):
mock_build_auth.assert_called_with('username', 'password', 'token_url', 'client_id',
totp_secret='totp_secret')

def test_get_oauth2_token_with_cache(self):
"""Test getting an OAuth2 token from the cache"""
fake_token = {
'access_token': 'foo',
'expires_in': 36000,
'refresh_expires_in': 28800,
'refresh_token': 'foo',
'token_type': 'bearer',
'not-before-policy': 0,
'session_state': 'd82c2e20-f690-474f-9d4f-51d68d2d042e',
'expires_at': 1616444581.1169086
}
pickled_fake_token = pickle.dumps(fake_token)
with mock.patch('geospaas_processing.downloaders.utils.REDIS_HOST', 'test'), \
mock.patch('geospaas_processing.downloaders.utils.REDIS_PORT', '6379'), \
mock.patch('geospaas_processing.downloaders.Redis') as mock_redis, \
mock.patch('geospaas_processing.downloaders.HTTPDownloader.fetch_oauth2_token',
return_value=fake_token) as mock_fetch_token:

with self.subTest('Cache present, no token'):
mock_redis.return_value.get.return_value = None
result = downloaders.HTTPDownloader.get_oauth2_token(
'foo', 'bar', 'baz', 'qux', 'quux')
mock_redis.return_value.set.assert_called_with(
'fd05b6f4dcd0c72512ea0cf6e1c94a6689353678',
pickled_fake_token,
ex=36000)
self.assertEqual(result, fake_token)

with self.subTest('Cache present with token'):
mock_redis.return_value.get.return_value = pickled_fake_token
result = downloaders.HTTPDownloader.get_oauth2_token(
'foo', 'bar', 'baz', 'qux', 'quux')
self.assertEqual(result, fake_token)

def test_get_basic_auth(self):
"""Test getting a basic authentication from get_auth()"""
self.assertEqual(
Expand Down

0 comments on commit c2dd225

Please sign in to comment.