diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 6bd12667..48ac3948 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -14,10 +14,12 @@ import time import urllib import uuid -from typing import Dict, Optional +from contextlib import nullcontext as does_not_raise +from typing import Any, Dict, Optional from unittest import mock from urllib.parse import urlparse +import gssapi import httpretty import pytest import requests @@ -865,6 +867,73 @@ def test_extra_credential_value_encoding(mock_get_and_post): assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=bar+%E7%9A%84" +class MockGssapiCredentials: + def __init__(self, name: gssapi.Name, usage: str): + self.name = name + self.usage = usage + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, MockGssapiCredentials): + return False + return ( + self.name == other.name, + self.usage == other.usage, + ) + + +@pytest.fixture +def mock_gssapi_creds(monkeypatch): + monkeypatch.setattr("gssapi.Credentials", MockGssapiCredentials) + + +def _gssapi_uname(spn: str): + return gssapi.Name(spn, gssapi.NameType.user) + + +def _gssapi_sname(principal: str): + return gssapi.Name(principal, gssapi.NameType.hostbased_service) + + +@pytest.mark.parametrize( + "options, expected_credentials, expected_hostname, expected_exception", + [ + ( + {}, None, None, does_not_raise(), + ), + ( + {"hostname_override": "foo"}, None, "foo", does_not_raise(), + ), + ( + {"service_name": "bar"}, None, None, + pytest.raises(ValueError, match=r"must be used together with hostname_override"), + ), + ( + {"hostname_override": "foo", "service_name": "bar"}, None, _gssapi_sname("bar@foo"), does_not_raise(), + ), + ( + {"principal": "foo"}, MockGssapiCredentials(_gssapi_uname("foo"), "initial"), None, does_not_raise(), + ), + ] +) +def test_authentication_gssapi_init_arguments( + options, + expected_credentials, + expected_hostname, + expected_exception, + mock_gssapi_creds, + monkeypatch, +): + auth = GSSAPIAuthentication(**options) + + session = requests.Session() + + with expected_exception: + auth.set_http_session(session) + + assert session.auth.target_name == expected_hostname + assert session.auth.creds == expected_credentials + + class RetryRecorder(object): def __init__(self, error=None, result=None): self.__name__ = "RetryRecorder"