diff --git a/docs/adapters.rst b/docs/adapters.rst index bfbae712..6da2a5d1 100644 --- a/docs/adapters.rst +++ b/docs/adapters.rst @@ -460,6 +460,29 @@ Or via query parameters: Note that if passing the headers via query parameters the dictionary should be serialized using `RISON `_. +To ensure secret header values aren't sent to the wrong URL, you may wish to specify that certain request headers should only be sent to certain URL patterns. You may do so via the ``url_configs`` keyword argument, which should contain a map from URL patterns to adapter kwarg dicts. A URL pattern is determined to match a URL if: + + 1. the origins (scheme+host+port) are equal; + 2. the pattern's slash-separated path parts are a prefix of the URL's path parts; and + 3. the pattern's query parameter key-value tuples are a subset of the URL's query parameter key-value tuples. + +Headers specified in config keys ``request_headers`` and ``url_configs.*.request_headers`` and in URI query param ``_s_headers`` are merged in that order, with later values overriding earlier ones. + +For example, with the following config: + +.. literalinclude:: /../tests/adapters/api/generic_json_test.py + :start-after: START DOC: url_configs_test_config + :end-before: END DOC: url_configs_test_config + :dedent: + +here are some example URLs alongside their resulting ``Authorization`` and ``X-Source`` headers: + +.. literalinclude:: /../tests/adapters/api/generic_json_test.py + :start-after: START DOC: url_configs_test_cases + :end-before: END DOC: url_configs_test_cases + :dedent: + + Generic XML =========== diff --git a/src/shillelagh/adapters/api/generic_json.py b/src/shillelagh/adapters/api/generic_json.py index 6c9d9c06..b8be2452 100644 --- a/src/shillelagh/adapters/api/generic_json.py +++ b/src/shillelagh/adapters/api/generic_json.py @@ -6,18 +6,26 @@ import logging from collections.abc import Iterator +from copy import deepcopy from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Optional, TypedDict import jsonpath import prison +from requests_cache import CachedSession from yarl import URL from shillelagh.adapters.base import Adapter from shillelagh.exceptions import ProgrammingError from shillelagh.fields import Field, Order from shillelagh.filters import Filter -from shillelagh.lib import SimpleCostModel, analyze, flatten, get_session +from shillelagh.lib import ( + SimpleCostModel, + analyze, + flatten, + get_session, + seq_startswith, +) from shillelagh.typing import Maybe, RequestedOrder, Row _logger = logging.getLogger(__name__) @@ -28,6 +36,15 @@ CACHE_EXPIRATION = timedelta(minutes=3) +class URLConfig(TypedDict, total=False): + request_headers: dict[str, str] + cache_expiration: float + + +class Config(URLConfig, total=False): + url_configs: dict[str, URLConfig] + + class GenericJSONAPI(Adapter): """ An adapter for fetching JSON data. @@ -45,71 +62,71 @@ class GenericJSONAPI(Adapter): @classmethod def supports(cls, uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]: - parsed = URL(uri) - if parsed.scheme not in SUPPORTED_PROTOCOLS: + parsed_uri = URL(uri) + if parsed_uri.scheme not in SUPPORTED_PROTOCOLS: return False if fast: return Maybe - if REQUEST_HEADERS_KEY in parsed.query: - request_headers = prison.loads(parsed.query[REQUEST_HEADERS_KEY]) - parsed = parsed.with_query( - {k: v for k, v in parsed.query.items() if k != REQUEST_HEADERS_KEY}, - ) - else: - request_headers = kwargs.get("request_headers", {}) - - cache_expiration = kwargs.get( - "cache_expiration", - CACHE_EXPIRATION.total_seconds(), - ) - session = get_session( - request_headers, - cls.cache_name, - timedelta(seconds=cache_expiration), - ) - response = session.head(str(parsed)) + parsed_uri, session = cls._get_session(parsed_uri, **kwargs) + response = session.head(str(parsed_uri)) return cls.content_type in response.headers.get("content-type", "") @classmethod - def parse_uri( - cls, - uri: str, - ) -> Union[tuple[str, str], tuple[str, str, dict[str, str]]]: + def parse_uri(cls, uri: str) -> tuple[str, str]: parsed = URL(uri) + return (str(parsed.with_fragment("")), parsed.fragment or cls.default_path) - path = parsed.fragment or cls.default_path - parsed = parsed.with_fragment("") - - if REQUEST_HEADERS_KEY in parsed.query: - request_headers = prison.loads(parsed.query[REQUEST_HEADERS_KEY]) - parsed = parsed.with_query( - {k: v for k, v in parsed.query.items() if k != REQUEST_HEADERS_KEY}, - ) - return str(parsed), path, request_headers + def __init__(self, uri: str, path: str, **kwargs): + super().__init__() - return str(parsed), path + # may be decorated with json path (as fragment) or headers (as query param) + self.path = path + self.uri, self._session = self._get_session(URL(uri), **kwargs) - def __init__( - self, - uri: str, - path: Optional[str] = None, - request_headers: Optional[dict[str, str]] = None, - cache_expiration: float = CACHE_EXPIRATION.total_seconds(), - ): - super().__init__() + self._set_columns() - self.uri = uri - self.path = path or self.default_path + @classmethod + def _get_session(cls, url: URL, **kwargs) -> tuple[URL, CachedSession]: + config: Config = deepcopy(kwargs) # type: ignore (need PEP 692 in Python 3.12+) + + url_config: URLConfig = { + "request_headers": config.pop("request_headers", {}), + "cache_expiration": config.pop( + "cache_expiration", CACHE_EXPIRATION.total_seconds() + ), + } - self._session = get_session( - request_headers or {}, - self.cache_name, - timedelta(seconds=cache_expiration), + mutable_query = url.query.copy() + query_request_header_dicts = [ + prison.loads(q) for q in mutable_query.popall(REQUEST_HEADERS_KEY, []) + ] + url = url.with_query(mutable_query) + + if url_configs := config.pop("url_configs", None): + for url_pat_str, url_pat_config in url_configs.items(): + url_pat = URL(url_pat_str) + + if ( + url.origin() == url_pat.origin() + and seq_startswith(url.parts, url_pat.parts) + and set(url.query.values()) >= set(url_pat.query.values()) + ): + url_config["request_headers"].update( + url_pat_config.pop("request_headers", {}) + ) + url_config.update(url_pat_config) + + # apply query headers last + for query_request_header_dict in query_request_header_dicts: + url_config["request_headers"].update(query_request_header_dict) + + return url, get_session( + url_config["request_headers"], + cls.cache_name, + timedelta(seconds=url_config["cache_expiration"]), ) - self._set_columns() - def _set_columns(self) -> None: rows = list(self.get_data({}, [])) column_names = list(rows[0].keys()) if rows else [] @@ -140,7 +157,7 @@ def get_data( # pylint: disable=unused-argument, too-many-arguments ) -> Iterator[Row]: response = self._session.get(self.uri) if not response.ok: - raise ProgrammingError(f'Error: {response.text}') + raise ProgrammingError(f"Error: {response.text}") payload = response.json() diff --git a/src/shillelagh/lib.py b/src/shillelagh/lib.py index 3e7a18a6..c513ce2c 100644 --- a/src/shillelagh/lib.py +++ b/src/shillelagh/lib.py @@ -627,6 +627,10 @@ def flatten(row: Row) -> Row: } +def seq_startswith(t1: Sequence[Any], t2: Sequence[Any]) -> bool: + return t1[: len(t2)] == t2 + + def best_index_object_available() -> bool: """ Check if support for best index object is available. diff --git a/tests/adapters/api/generic_json_test.py b/tests/adapters/api/generic_json_test.py index 71535ed5..edd52623 100644 --- a/tests/adapters/api/generic_json_test.py +++ b/tests/adapters/api/generic_json_test.py @@ -5,6 +5,7 @@ import re import pytest +from requests_mock import ANY from requests_mock.mocker import Mocker from yarl import URL @@ -211,6 +212,73 @@ def test_request_headers(requests_mock: Mocker) -> None: assert data.last_request.headers["foo"] == "bar" +@pytest.mark.parametrize( + ("url", "expected_auth", "expected_source"), + # START DOC: url_configs_test_cases + [ + ("https://api1.example.com/path", "xyz", "api1"), + ("https://api1.example.com/path/1", "xyz", "api1"), + ("https://api1.example.com/", "NOPE", "base"), + ("https://api1.example.com:8080/path", "NOPE", "base"), + ("https://example.com/path", "NOPE", "base"), + ("http://api1.example.com/path", "NOPE", "base"), + ("https://api2.example.com/?param=yes", "abc", "api2"), + ("https://api2.example.com?param=yes", "abc", "api2"), + ("https://api2.example.com?q=123¶m=yes", "abc", "api2"), + ("https://api2.example.com?param=no", "NOPE", "base"), + ("https://api2.example.com?q=123", "NOPE", "base"), + ("https://api2.example.com?q=123¶m=yes&_s_headers=(Authorization:mine)", "mine", "api2"), + ] + # END DOC: url_configs_test_cases +) +def test_request_headers_in_url_configs( + requests_mock: Mocker, url: str, expected_auth: str, expected_source: str +) -> None: + # START DOC: url_configs_test_config + adapter_kwargs = { + "genericjsonapi": { + "cache_expiration": -1, + "request_headers": { + "Authorization": "NOPE", + "X-Source": "base", + }, + "url_configs": { + "https://api1.example.com/path": { + "request_headers": { + "Authorization": "xyz", + "X-Source": "api1", + }, + }, + "https://api2.example.com?param=yes": { + "request_headers": { + "Authorization": "abc", + "X-Source": "api2", + }, + }, + }, + } + } + connection = connect(":memory:", adapter_kwargs=adapter_kwargs) + # END DOC: url_configs_test_config + cursor = connection.cursor() + + requests_mock.register_uri( + method=ANY, + url=ANY, + status_code=200, + headers={"Content-Type": "application/json"}, + json=lambda req, _ctx: [dict(req.headers)], + ) + + assert GenericJSONAPI.supports(url, fast=False, **adapter_kwargs["genericjsonapi"]) + assert cursor.execute( + f'SELECT "Authorization", "X-Source" FROM "{url}"' + ).fetchone() == ( + expected_auth, + expected_source + ) + + def test_request_headers_in_url(requests_mock: Mocker) -> None: """ Test passing requests headers. diff --git a/tests/lib_test.py b/tests/lib_test.py index a8a7b534..cad63f97 100644 --- a/tests/lib_test.py +++ b/tests/lib_test.py @@ -37,6 +37,7 @@ is_not_null, is_null, serialize, + seq_startswith, uncombine_args_kwargs, unescape_identifier, unescape_string, @@ -371,6 +372,12 @@ def test_deserialize() -> None: {"b": "TEST", "c": 20.0}, (0, "TEST", {"c": 20.0}), ), + ( + 'def func(a: int = 0, b: str = "test", *, c: float = 10.0) -> None: pass', + (), + {"b": "TEST"}, + (0, "TEST", {"c": 10.0}), + ), ], ) def test_combine_uncombine_args_kwargs( @@ -546,3 +553,17 @@ def test_get_session_namespaced(mocker: MockerFixture) -> None: backend="sqlite", expire_after=10, ) + + +@pytest.mark.parametrize( + ("seq1", "seq2", "result"), + [ + ((1, 2, 3), (), True), + ((1, 2, 3), (1, 2), True), + ((1, 2, 3), (1, 2, 3, 4, 5), False), + ((), (), True), + ((), (1, 2), False), + ], +) +def test_seq_startswith(seq1: Tuple[int], seq2: Tuple[int], result: bool): + assert seq_startswith(seq1, seq2) == result