-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add some unit tests to datacosmos client
- Loading branch information
1 parent
1efe6ed
commit 0f1479f
Showing
4 changed files
with
222 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,141 @@ | ||
"""Datacosmos client for interacting with the Datacosmos API. | ||
Provides an authenticated HTTP client and convenience methods for HTTP requests. | ||
""" | ||
|
||
import logging | ||
import os | ||
from datetime import datetime, timedelta, timezone | ||
from typing import Any, Optional | ||
from typing import Optional, Any | ||
|
||
import requests | ||
from oauthlib.oauth2 import BackendApplicationClient | ||
from requests_oauthlib import OAuth2Session | ||
from oauthlib.oauth2 import BackendApplicationClient | ||
from requests.exceptions import RequestException | ||
|
||
from config.config import Config | ||
|
||
|
||
class DatacosmosClient: | ||
"""DatacosmosClient handles authenticated interactions with the Datacosmos API. | ||
""" | ||
DatacosmosClient handles authenticated interactions with the Datacosmos API. | ||
Automatically manages token refreshing and provides HTTP convenience methods. | ||
""" | ||
|
||
def __init__( | ||
self, config: Optional[Config] = None, config_file: str = "config/config.yaml" | ||
): | ||
"""Initialize the DatacosmosClient. | ||
""" | ||
Initialize the DatacosmosClient. | ||
If no configuration is provided, it will load from the specified YAML file | ||
or fall back to environment variables. | ||
""" | ||
self.logger = logging.getLogger(__name__) | ||
self.logger.setLevel(logging.INFO) | ||
|
||
self.config = config or self._load_config(config_file) | ||
self.token = None | ||
self.token_expiry = None | ||
self._http_client = self._authenticate_and_initialize_client() | ||
|
||
def _load_config(self, config_file: str) -> Config: | ||
"""Load configuration from the YAML file. Fall back to environment variables if the file is missing.""" | ||
if os.path.exists(config_file): | ||
return Config.from_yaml(config_file) | ||
return Config.from_env() | ||
""" | ||
Load configuration from the YAML file. Fall back to environment variables if the file is missing. | ||
""" | ||
try: | ||
if os.path.exists(config_file): | ||
self.logger.info(f"Loading configuration from {config_file}") | ||
return Config.from_yaml(config_file) | ||
self.logger.info("Loading configuration from environment variables") | ||
return Config.from_env() | ||
except Exception as e: | ||
self.logger.error(f"Failed to load configuration: {e}") | ||
raise | ||
|
||
def _authenticate_and_initialize_client(self) -> requests.Session: | ||
"""Authenticate and initialize the HTTP client with a valid token.""" | ||
client = BackendApplicationClient(client_id=self.config.client_id) | ||
oauth_session = OAuth2Session(client=client) | ||
|
||
# Fetch the token using client credentials | ||
token_response = oauth_session.fetch_token( | ||
token_url=self.config.token_url, | ||
client_id=self.config.client_id, | ||
client_secret=self.config.client_secret, | ||
audience=self.config.audience, | ||
) | ||
|
||
self.token = token_response["access_token"] | ||
self.token_expiry = datetime.now(timezone.utc) + timedelta( | ||
seconds=token_response.get("expires_in", 3600) | ||
) | ||
|
||
# Initialize the HTTP session with the Authorization header | ||
http_client = requests.Session() | ||
http_client.headers.update({"Authorization": f"Bearer {self.token}"}) | ||
|
||
return http_client | ||
""" | ||
Authenticate and initialize the HTTP client with a valid token. | ||
""" | ||
try: | ||
self.logger.info("Authenticating with the token endpoint") | ||
client = BackendApplicationClient(client_id=self.config.client_id) | ||
oauth_session = OAuth2Session(client=client) | ||
|
||
# Fetch the token using client credentials | ||
token_response = oauth_session.fetch_token( | ||
token_url=self.config.token_url, | ||
client_id=self.config.client_id, | ||
client_secret=self.config.client_secret, | ||
audience=self.config.audience, | ||
) | ||
|
||
self.token = token_response["access_token"] | ||
self.token_expiry = datetime.now(timezone.utc) + timedelta( | ||
seconds=token_response.get("expires_in", 3600) | ||
) | ||
self.logger.info("Authentication successful, token obtained") | ||
|
||
# Initialize the HTTP session with the Authorization header | ||
http_client = requests.Session() | ||
http_client.headers.update({"Authorization": f"Bearer {self.token}"}) | ||
return http_client | ||
except RequestException as e: | ||
self.logger.error(f"Request failed during authentication: {e}") | ||
raise | ||
except Exception as e: | ||
self.logger.error(f"Unexpected error during authentication: {e}") | ||
raise | ||
|
||
def _refresh_token_if_needed(self): | ||
"""Refresh the token if it has expired.""" | ||
""" | ||
Refresh the token if it has expired. | ||
""" | ||
if not self.token or self.token_expiry <= datetime.now(timezone.utc): | ||
self.logger.info("Token expired or missing, refreshing token") | ||
self._http_client = self._authenticate_and_initialize_client() | ||
|
||
def get_http_client(self) -> requests.Session: | ||
"""Return the authenticated HTTP client, refreshing the token if necessary.""" | ||
""" | ||
Return the authenticated HTTP client, refreshing the token if necessary. | ||
""" | ||
self._refresh_token_if_needed() | ||
return self._http_client | ||
|
||
def request( | ||
self, method: str, url: str, *args: Any, **kwargs: Any | ||
) -> requests.Response: | ||
"""Send an HTTP request using the authenticated session.""" | ||
def request(self, method: str, url: str, *args: Any, **kwargs: Any) -> requests.Response: | ||
""" | ||
Send an HTTP request using the authenticated session. | ||
Logs request and response details. | ||
""" | ||
self._refresh_token_if_needed() | ||
return self._http_client.request(method, url, *args, **kwargs) | ||
try: | ||
self.logger.info(f"Making {method.upper()} request to {url}") | ||
response = self._http_client.request(method, url, *args, **kwargs) | ||
response.raise_for_status() | ||
self.logger.info(f"Request to {url} succeeded with status {response.status_code}") | ||
return response | ||
except RequestException as e: | ||
self.logger.error(f"HTTP request failed: {e}") | ||
raise | ||
except Exception as e: | ||
self.logger.error(f"Unexpected error during HTTP request: {e}") | ||
Check failure on line 116 in datacosmos/client.py GitHub Actions / Blackdatacosmos/client.py#L95-L116
|
||
raise | ||
|
||
def get(self, url: str, *args: Any, **kwargs: Any) -> requests.Response: | ||
"""Send a GET request using the authenticated session.""" | ||
""" | ||
Send a GET request using the authenticated session. | ||
""" | ||
return self.request("GET", url, *args, **kwargs) | ||
|
||
def post(self, url: str, *args: Any, **kwargs: Any) -> requests.Response: | ||
"""Send a POST request using the authenticated session.""" | ||
""" | ||
Send a POST request using the authenticated session. | ||
""" | ||
return self.request("POST", url, *args, **kwargs) | ||
|
||
def put(self, url: str, *args: Any, **kwargs: Any) -> requests.Response: | ||
"""Send a PUT request using the authenticated session.""" | ||
""" | ||
Send a PUT request using the authenticated session. | ||
""" | ||
return self.request("PUT", url, *args, **kwargs) | ||
|
||
def delete(self, url: str, *args: Any, **kwargs: Any) -> requests.Response: | ||
"""Send a DELETE request using the authenticated session.""" | ||
""" | ||
Send a DELETE request using the authenticated session. | ||
""" | ||
return self.request("DELETE", url, *args, **kwargs) |
53 changes: 53 additions & 0 deletions
53
tests/unit/datacosmos/client/test_client_authentication.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from unittest.mock import patch, MagicMock | ||
from datacosmos.client import DatacosmosClient | ||
from config.config import Config | ||
|
||
|
||
@patch("datacosmos.client.OAuth2Session.fetch_token") | ||
@patch("datacosmos.client.DatacosmosClient._authenticate_and_initialize_client", autospec=True) | ||
def test_client_authentication(mock_auth_client, mock_fetch_token): | ||
""" | ||
Test that the client correctly fetches a token during authentication. | ||
""" | ||
# Mock the token response from OAuth2Session | ||
mock_fetch_token.return_value = { | ||
Check failure on line 13 in tests/unit/datacosmos/client/test_client_authentication.py GitHub Actions / Blacktests/unit/datacosmos/client/test_client_authentication.py#L2-L13
|
||
"access_token": "mock-access-token", | ||
"expires_in": 3600, | ||
} | ||
|
||
# Simulate _authenticate_and_initialize_client calling fetch_token | ||
def mock_authenticate_and_initialize_client(self): | ||
# Call the real fetch_token (simulated by the mock) | ||
token_response = mock_fetch_token( | ||
token_url=self.config.token_url, | ||
client_id=self.config.client_id, | ||
client_secret=self.config.client_secret, | ||
audience=self.config.audience, | ||
) | ||
self.token = token_response["access_token"] | ||
self.token_expiry = "mock-expiry" | ||
|
||
# Attach the side effect to the mock | ||
mock_auth_client.side_effect = mock_authenticate_and_initialize_client | ||
|
||
# Create a mock configuration | ||
config = Config( | ||
client_id="test-client-id", | ||
client_secret="test-client-secret", | ||
token_url="https://mock.token.url/oauth/token", | ||
audience="https://mock.audience", | ||
) | ||
|
||
# Initialize the client | ||
client = DatacosmosClient(config=config) | ||
|
||
# Assertions | ||
assert client.token == "mock-access-token" | ||
assert client.token_expiry == "mock-expiry" | ||
mock_fetch_token.assert_called_once_with( | ||
token_url="https://mock.token.url/oauth/token", | ||
client_id="test-client-id", | ||
client_secret="test-client-secret", | ||
audience="https://mock.audience", | ||
) | ||
mock_auth_client.assert_called_once_with(client) |
28 changes: 28 additions & 0 deletions
28
tests/unit/datacosmos/client/test_client_initialization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from unittest.mock import patch, MagicMock | ||
from datacosmos.client import DatacosmosClient | ||
from config.config import Config | ||
|
||
|
||
@patch("datacosmos.client.DatacosmosClient._authenticate_and_initialize_client") | ||
@patch("os.path.exists", return_value=False) | ||
@patch("config.Config.from_env") | ||
def test_client_initialization(mock_from_env, mock_exists, mock_auth_client): | ||
""" | ||
Test that the client initializes correctly with environment variables and mocks the HTTP client. | ||
""" | ||
mock_config = Config( | ||
client_id="test-client-id", | ||
client_secret="test-client-secret", | ||
token_url="https://mock.token.url/oauth/token", | ||
audience="https://mock.audience", | ||
) | ||
mock_from_env.return_value = mock_config | ||
mock_auth_client.return_value = MagicMock() # Mock the HTTP client | ||
|
||
client = DatacosmosClient() | ||
|
||
assert client.config == mock_config | ||
assert client._http_client is not None # Ensure the HTTP client is mocked | ||
mock_exists.assert_called_once_with("config/config.yaml") | ||
mock_from_env.assert_called_once() | ||
mock_auth_client.assert_called_once() |
50 changes: 50 additions & 0 deletions
50
tests/unit/datacosmos/client/test_client_token_refreshing.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from unittest.mock import patch, MagicMock | ||
from datetime import datetime, timedelta, timezone | ||
from datacosmos.client import DatacosmosClient | ||
from config.config import Config | ||
|
||
|
||
@patch("datacosmos.client.DatacosmosClient._authenticate_and_initialize_client") | ||
def test_token_refreshing(mock_auth_client): | ||
""" | ||
Test that the client refreshes the token when it expires. | ||
""" | ||
# Mock the HTTP client returned by _authenticate_and_initialize_client | ||
mock_http_client = MagicMock() | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.json.return_value = {"message": "success"} | ||
mock_http_client.request.return_value = mock_response | ||
mock_auth_client.return_value = mock_http_client | ||
|
||
config = Config( | ||
client_id="test-client-id", | ||
client_secret="test-client-secret", | ||
token_url="https://mock.token.url/oauth/token", | ||
audience="https://mock.audience", | ||
) | ||
|
||
# Initialize the client (first call to _authenticate_and_initialize_client) | ||
client = DatacosmosClient(config=config) | ||
|
||
# Simulate expired token | ||
client.token_expiry = datetime.now(timezone.utc) - timedelta(seconds=1) | ||
|
||
# Make a GET request (should trigger token refresh) | ||
response = client.get("https://mock.api/some-endpoint", headers={"Authorization": f"Bearer {client.token}"}) | ||
|
||
# Assertions | ||
assert response.status_code == 200 | ||
assert response.json() == {"message": "success"} | ||
|
||
# Verify _authenticate_and_initialize_client was called twice: | ||
Check failure on line 40 in tests/unit/datacosmos/client/test_client_token_refreshing.py GitHub Actions / Blacktests/unit/datacosmos/client/test_client_token_refreshing.py#L29-L40
|
||
# 1. During initialization | ||
# 2. During token refresh | ||
assert mock_auth_client.call_count == 2 | ||
|
||
# Verify the request was made correctly | ||
mock_http_client.request.assert_called_once_with( | ||
"GET", | ||
"https://mock.api/some-endpoint", | ||
headers={"Authorization": f"Bearer {client.token}"}, | ||
) |