From ecaa89d1d0efe0c1f5b90ab24d2eeb59f22f344b Mon Sep 17 00:00:00 2001 From: Adrien Perrin Date: Mon, 25 Sep 2023 13:27:51 +0000 Subject: [PATCH 1/2] store oauth2 token in cache to avoid re-sending the token request for every download --- geospaas_processing/downloaders.py | 63 +++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/geospaas_processing/downloaders.py b/geospaas_processing/downloaders.py index da69d75..6e61dc9 100644 --- a/geospaas_processing/downloaders.py +++ b/geospaas_processing/downloaders.py @@ -9,9 +9,11 @@ """ import errno import ftplib +import hashlib import logging import os import os.path +import pickle import re import shutil from urllib.parse import urlparse @@ -163,32 +165,64 @@ class HTTPDownloader(Downloader): CHUNK_SIZE = 1024 * 1024 @classmethod - def build_oauth2_authentication(cls, username, password, token_url, client_id, - totp_secret=None): - """Creates an OAuth2 object usable by `requests` methods""" + def get_oauth2_token(cls, username, password, token_url, client, totp_secret=None): + """Try to get a token from Redis. If this fails, fetch one from the URL""" + token = None + + LOGGER.debug("Attempting to get an OAuth2 token") + if Redis is not None and utils.REDIS_HOST and utils.REDIS_PORT: # cache available + cache = Redis(host=utils.REDIS_HOST, port=utils.REDIS_PORT) + key_hash = hashlib.sha1(bytes(token_url + username, encoding='utf-8')).hexdigest() + LOGGER.debug("Trying to retrieve OAuth2 token from the cache") + raw_token = cache.get(key_hash) + if raw_token is None: # did not get the token from the cache + token = cls.fetch_oauth2_token(username, password, token_url, client, totp_secret) + LOGGER.debug("Got OAuth2 token from URL") + cache.set(key_hash, pickle.dumps(token), ex=token['expires_in']) + LOGGER.debug("Stored Oauth2 token in the cache") + else: # successfully got the token from the cache + token = pickle.loads(raw_token) + LOGGER.debug("Got OAuth2 token from the cache") + else: # cache not available + LOGGER.debug("Cache not available, getting OAuth2 token from URL") + token = cls.fetch_oauth2_token(username, password, token_url, client, totp_secret) + + return token + + + @classmethod + def fetch_oauth2_token(cls, username, password, token_url, client, totp_secret=None): + """Fetches a new token from the URL""" # TOTP passwords are valid for 30 seconds, so we retry a few # times in case we get unlucky and the password expires between # the generation of the password and the authentication request + session_args = { + 'token_url': token_url, + 'username': username, + 'password': password, + 'client_id': client.client_id, + } retries = 5 while retries > 0: - client = oauthlib.oauth2.LegacyApplicationClient(client_id=client_id) - session_args = { - 'token_url': token_url, - 'username': username, - 'password': password, - 'client_id': client_id, - } + if totp_secret: + session_args['totp'] = pyotp.TOTP(totp_secret).now() try: - if totp_secret: - session_args['totp'] = pyotp.TOTP(totp_secret).now() token = requests_oauthlib.OAuth2Session(client=client).fetch_token(**session_args) except oauthlib.oauth2.rfc6749.errors.InvalidGrantError: retries -= 1 - if retries > 0: + if retries > 0: continue else: raise - return requests_oauthlib.OAuth2(client_id=client_id, client=client, token=token) + return token + + @classmethod + def build_oauth2_authentication(cls, username, password, token_url, client_id, + totp_secret=None): + """Creates an OAuth2 object usable by `requests` methods""" + client = oauthlib.oauth2.LegacyApplicationClient(client_id=client_id) + token = cls.get_oauth2_token(username, password, token_url, client, totp_secret) + return requests_oauthlib.OAuth2(client_id=client_id, client=client, token=token) @classmethod def get_auth(cls, kwargs): @@ -211,6 +245,7 @@ def get_auth(cls, kwargs): @classmethod def get_request_parameters(cls, kwargs): + """Retrieve and check request parameters from kwargs""" parameters = kwargs.get('request_parameters', {}) if isinstance(parameters, dict): return parameters From eca6e77edff9f48ed3710c12d1461037facd6ffc Mon Sep 17 00:00:00 2001 From: Adrien Perrin Date: Mon, 25 Sep 2023 14:30:34 +0000 Subject: [PATCH 2/2] add test for oauth2 token caching --- tests/test_downloaders.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_downloaders.py b/tests/test_downloaders.py index db7c07a..cac3618 100644 --- a/tests/test_downloaders.py +++ b/tests/test_downloaders.py @@ -5,6 +5,7 @@ import logging import os import os.path +import pickle import tempfile import unittest import unittest.mock as mock @@ -298,6 +299,41 @@ def test_get_oauth2_auth_with_totp(self): mock_build_auth.assert_called_with('username', 'password', 'token_url', 'client_id', totp_secret='totp_secret') + def test_get_oauth2_token_with_cache(self): + """Test getting an OAuth2 token from the cache""" + fake_token = { + 'access_token': 'foo', + 'expires_in': 36000, + 'refresh_expires_in': 28800, + 'refresh_token': 'foo', + 'token_type': 'bearer', + 'not-before-policy': 0, + 'session_state': 'd82c2e20-f690-474f-9d4f-51d68d2d042e', + 'expires_at': 1616444581.1169086 + } + pickled_fake_token = pickle.dumps(fake_token) + with mock.patch('geospaas_processing.downloaders.utils.REDIS_HOST', 'test'), \ + mock.patch('geospaas_processing.downloaders.utils.REDIS_PORT', '6379'), \ + mock.patch('geospaas_processing.downloaders.Redis') as mock_redis, \ + mock.patch('geospaas_processing.downloaders.HTTPDownloader.fetch_oauth2_token', + return_value=fake_token) as mock_fetch_token: + + with self.subTest('Cache present, no token'): + mock_redis.return_value.get.return_value = None + result = downloaders.HTTPDownloader.get_oauth2_token( + 'foo', 'bar', 'baz', 'qux', 'quux') + mock_redis.return_value.set.assert_called_with( + 'fd05b6f4dcd0c72512ea0cf6e1c94a6689353678', + pickled_fake_token, + ex=36000) + self.assertEqual(result, fake_token) + + with self.subTest('Cache present with token'): + mock_redis.return_value.get.return_value = pickled_fake_token + result = downloaders.HTTPDownloader.get_oauth2_token( + 'foo', 'bar', 'baz', 'qux', 'quux') + self.assertEqual(result, fake_token) + def test_get_basic_auth(self): """Test getting a basic authentication from get_auth()""" self.assertEqual(