Skip to content

Commit

Permalink
feat(genericjsonapi): add url_configs keyword arg
Browse files Browse the repository at this point in the history
  • Loading branch information
vergenzt committed Oct 15, 2024
1 parent 4c0017a commit c42165c
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 53 deletions.
23 changes: 23 additions & 0 deletions docs/adapters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,29 @@ Or via query parameters:
Note that if passing the headers via query parameters the dictionary should be serialized using `RISON <https://pypi.org/project/prison/>`_.

To ensure secret header values aren't sent to the wrong URL, you may wish to specify that certain request headers should only be sent to certain URL patterns. You may do so via the ``url_configs`` keyword argument, which should contain a map from URL patterns to adapter kwarg dicts. A URL pattern is determined to match a URL if:

1. the origins (scheme+host+port) are equal;
2. the pattern's slash-separated path parts are a prefix of the URL's path parts; and
3. the pattern's query parameter key-value tuples are a subset of the URL's query parameter key-value tuples.

Headers specified in config keys ``request_headers`` and ``url_configs.*.request_headers`` and in URI query param ``_s_headers`` are merged in that order, with later values overriding earlier ones.

For example, with the following config:

.. literalinclude:: /../tests/adapters/api/generic_json_test.py
   :start-after: START DOC: url_configs_test_config
   :end-before: END DOC: url_configs_test_config
   :dedent:

here are some example URLs alongside their resulting ``Authorization`` and ``X-Source`` headers:

.. literalinclude:: /../tests/adapters/api/generic_json_test.py
   :start-after: START DOC: url_configs_test_cases
   :end-before: END DOC: url_configs_test_cases
   :dedent:


Generic XML
===========

Expand Down
123 changes: 70 additions & 53 deletions src/shillelagh/adapters/api/generic_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@

import logging
from collections.abc import Iterator
from copy import deepcopy
from datetime import timedelta
from typing import Any, Optional, Union
from typing import Any, Optional, TypedDict

import jsonpath
import prison
from requests_cache import CachedSession
from yarl import URL

from shillelagh.adapters.base import Adapter
from shillelagh.exceptions import ProgrammingError
from shillelagh.fields import Field, Order
from shillelagh.filters import Filter
from shillelagh.lib import SimpleCostModel, analyze, flatten, get_session
from shillelagh.lib import (
SimpleCostModel,
analyze,
flatten,
get_session,
seq_startswith,
)
from shillelagh.typing import Maybe, RequestedOrder, Row

_logger = logging.getLogger(__name__)
Expand All @@ -28,6 +36,15 @@
CACHE_EXPIRATION = timedelta(minutes=3)


class URLConfig(TypedDict, total=False):
    """
    Adapter options that may be set globally or for a single URL pattern.

    Every key is optional (``total=False``).
    """

    # headers sent with each request made by the adapter's session
    request_headers: dict[str, str]
    # cache lifetime, in seconds
    cache_expiration: float


class Config(URLConfig, total=False):
    """
    Adapter keyword arguments: the ``URLConfig`` keys plus per-URL overrides.
    """

    # maps a URL pattern string to the options applied to URLs matching it
    url_configs: dict[str, URLConfig]


class GenericJSONAPI(Adapter):
"""
An adapter for fetching JSON data.
Expand All @@ -45,71 +62,71 @@ class GenericJSONAPI(Adapter):

@classmethod
def supports(cls, uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
parsed = URL(uri)
if parsed.scheme not in SUPPORTED_PROTOCOLS:
parsed_uri = URL(uri)
if parsed_uri.scheme not in SUPPORTED_PROTOCOLS:
return False
if fast:
return Maybe

if REQUEST_HEADERS_KEY in parsed.query:
request_headers = prison.loads(parsed.query[REQUEST_HEADERS_KEY])
parsed = parsed.with_query(
{k: v for k, v in parsed.query.items() if k != REQUEST_HEADERS_KEY},
)
else:
request_headers = kwargs.get("request_headers", {})

cache_expiration = kwargs.get(
"cache_expiration",
CACHE_EXPIRATION.total_seconds(),
)
session = get_session(
request_headers,
cls.cache_name,
timedelta(seconds=cache_expiration),
)
response = session.head(str(parsed))
parsed_uri, session = cls._get_session(parsed_uri, **kwargs)
response = session.head(str(parsed_uri))
return cls.content_type in response.headers.get("content-type", "")

@classmethod
def parse_uri(
cls,
uri: str,
) -> Union[tuple[str, str], tuple[str, str, dict[str, str]]]:
def parse_uri(cls, uri: str) -> tuple[str, str]:
parsed = URL(uri)
return (str(parsed.with_fragment("")), parsed.fragment or cls.default_path)

path = parsed.fragment or cls.default_path
parsed = parsed.with_fragment("")

if REQUEST_HEADERS_KEY in parsed.query:
request_headers = prison.loads(parsed.query[REQUEST_HEADERS_KEY])
parsed = parsed.with_query(
{k: v for k, v in parsed.query.items() if k != REQUEST_HEADERS_KEY},
)
return str(parsed), path, request_headers
def __init__(self, uri: str, path: str, **kwargs):
super().__init__()

return str(parsed), path
# may be decorated with json path (as fragment) or headers (as query param)
self.path = path
self.uri, self._session = self._get_session(URL(uri), **kwargs)

def __init__(
self,
uri: str,
path: Optional[str] = None,
request_headers: Optional[dict[str, str]] = None,
cache_expiration: float = CACHE_EXPIRATION.total_seconds(),
):
super().__init__()
self._set_columns()

self.uri = uri
self.path = path or self.default_path
@classmethod
def _get_session(cls, url: URL, **kwargs) -> tuple[URL, CachedSession]:
    """
    Build the session used to request ``url`` from the adapter kwargs.

    Request headers are merged from three sources, with later sources
    overriding earlier ones:

    1. the top-level ``request_headers`` kwarg;
    2. ``request_headers`` of every ``url_configs`` entry whose pattern
       matches ``url``, in the iteration order of ``url_configs``;
    3. the headers passed in the ``REQUEST_HEADERS_KEY`` query parameter
       (RISON-encoded).

    A pattern matches when its origin equals the URL's origin, its path
    parts are a prefix of the URL's path parts, and its query key-value
    pairs are a subset of the URL's query key-value pairs.

    :param url: The URL being accessed, possibly carrying headers in the
        ``REQUEST_HEADERS_KEY`` query parameter.
    :param kwargs: Adapter keyword arguments (see ``Config``).
    :return: The URL with the headers query parameter removed, and the
        session configured with the merged headers and cache expiration.
    """
    # work on a copy so the ``pop`` calls below never mutate the caller's config
    config: Config = deepcopy(kwargs)  # type: ignore (need PEP 692 in Python 3.12+)

    url_config: URLConfig = {
        "request_headers": config.pop("request_headers", {}),
        "cache_expiration": config.pop(
            "cache_expiration",
            CACHE_EXPIRATION.total_seconds(),
        ),
    }

    # strip the headers from the query string so they are not sent in the URL
    mutable_query = url.query.copy()
    query_request_header_dicts = [
        prison.loads(q) for q in mutable_query.popall(REQUEST_HEADERS_KEY, [])
    ]
    url = url.with_query(mutable_query)

    if url_configs := config.pop("url_configs", None):
        url_query_items = set(url.query.items())
        for url_pat_str, url_pat_config in url_configs.items():
            url_pat = URL(url_pat_str)

            # Compare (key, value) pairs, not bare values: otherwise a pattern
            # ``?param=yes`` would also match ``?other=yes`` and could send
            # secret headers to the wrong URL.
            if (
                url.origin() == url_pat.origin()
                and seq_startswith(url.parts, url_pat.parts)
                and url_query_items >= set(url_pat.query.items())
            ):
                url_config["request_headers"].update(
                    url_pat_config.pop("request_headers", {}),
                )
                url_config.update(url_pat_config)

    # apply query headers last
    for query_request_header_dict in query_request_header_dicts:
        url_config["request_headers"].update(query_request_header_dict)

    return url, get_session(
        url_config["request_headers"],
        cls.cache_name,
        timedelta(seconds=url_config["cache_expiration"]),
    )

self._set_columns()

def _set_columns(self) -> None:
rows = list(self.get_data({}, []))
column_names = list(rows[0].keys()) if rows else []
Expand Down Expand Up @@ -140,7 +157,7 @@ def get_data( # pylint: disable=unused-argument, too-many-arguments
) -> Iterator[Row]:
response = self._session.get(self.uri)
if not response.ok:
raise ProgrammingError(f'Error: {response.text}')
raise ProgrammingError(f"Error: {response.text}")

payload = response.json()

Expand Down
4 changes: 4 additions & 0 deletions src/shillelagh/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,10 @@ def flatten(row: Row) -> Row:
}


def seq_startswith(t1: Sequence[Any], t2: Sequence[Any]) -> bool:
    """
    Check whether sequence ``t1`` starts with the elements of ``t2``.

    The comparison is element-wise, so the two sequences do not need to be
    of the same type (eg, a list can be a prefix of a tuple). An empty ``t2``
    is a prefix of any sequence.
    """
    if len(t2) > len(t1):
        return False
    # ``zip`` stops at the end of ``t2``, the shorter sequence; the identity
    # check mirrors how built-in containers compare their elements.
    return all(a is b or a == b for a, b in zip(t1, t2))


def best_index_object_available() -> bool:
"""
Check if support for best index object is available.
Expand Down
68 changes: 68 additions & 0 deletions tests/adapters/api/generic_json_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re

import pytest
from requests_mock import ANY
from requests_mock.mocker import Mocker
from yarl import URL

Expand Down Expand Up @@ -211,6 +212,73 @@ def test_request_headers(requests_mock: Mocker) -> None:
assert data.last_request.headers["foo"] == "bar"


@pytest.mark.parametrize(
    ("url", "expected_auth", "expected_source"),
    # START DOC: url_configs_test_cases
    [
        ("https://api1.example.com/path", "xyz", "api1"),
        ("https://api1.example.com/path/1", "xyz", "api1"),
        ("https://api1.example.com/", "NOPE", "base"),
        ("https://api1.example.com:8080/path", "NOPE", "base"),
        ("https://example.com/path", "NOPE", "base"),
        ("http://api1.example.com/path", "NOPE", "base"),
        ("https://api2.example.com/?param=yes", "abc", "api2"),
        ("https://api2.example.com?param=yes", "abc", "api2"),
        ("https://api2.example.com?q=123&param=yes", "abc", "api2"),
        ("https://api2.example.com?param=no", "NOPE", "base"),
        ("https://api2.example.com?q=123", "NOPE", "base"),
        ("https://api2.example.com?q=123&param=yes&_s_headers=(Authorization:mine)", "mine", "api2"),
    ]
    # END DOC: url_configs_test_cases
)
def test_request_headers_in_url_configs(
    requests_mock: Mocker, url: str, expected_auth: str, expected_source: str
) -> None:
    """
    Test that ``url_configs`` request headers are applied only to matching URLs.

    Each case requests ``url`` and checks the ``Authorization`` and ``X-Source``
    headers that were actually sent. The code between the ``START DOC`` and
    ``END DOC`` markers is included verbatim in the documentation.
    """
    # START DOC: url_configs_test_config
    adapter_kwargs = {
        "genericjsonapi": {
            "cache_expiration": -1,
            "request_headers": {
                "Authorization": "NOPE",
                "X-Source": "base",
            },
            "url_configs": {
                "https://api1.example.com/path": {
                    "request_headers": {
                        "Authorization": "xyz",
                        "X-Source": "api1",
                    },
                },
                "https://api2.example.com?param=yes": {
                    "request_headers": {
                        "Authorization": "abc",
                        "X-Source": "api2",
                    },
                },
            },
        }
    }
    connection = connect(":memory:", adapter_kwargs=adapter_kwargs)
    # END DOC: url_configs_test_config
    cursor = connection.cursor()

    # every request is answered with a single JSON row echoing the request
    # headers, so the SELECT below reads back the headers that were sent
    requests_mock.register_uri(
        method=ANY,
        url=ANY,
        status_code=200,
        headers={"Content-Type": "application/json"},
        json=lambda req, _ctx: [dict(req.headers)],
    )

    assert GenericJSONAPI.supports(url, fast=False, **adapter_kwargs["genericjsonapi"])
    assert cursor.execute(
        f'SELECT "Authorization", "X-Source" FROM "{url}"'
    ).fetchone() == (
        expected_auth,
        expected_source
    )


def test_request_headers_in_url(requests_mock: Mocker) -> None:
"""
Test passing requests headers.
Expand Down
21 changes: 21 additions & 0 deletions tests/lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
is_not_null,
is_null,
serialize,
seq_startswith,
uncombine_args_kwargs,
unescape_identifier,
unescape_string,
Expand Down Expand Up @@ -371,6 +372,12 @@ def test_deserialize() -> None:
{"b": "TEST", "c": 20.0},
(0, "TEST", {"c": 20.0}),
),
(
'def func(a: int = 0, b: str = "test", *, c: float = 10.0) -> None: pass',
(),
{"b": "TEST"},
(0, "TEST", {"c": 10.0}),
),
],
)
def test_combine_uncombine_args_kwargs(
Expand Down Expand Up @@ -546,3 +553,17 @@ def test_get_session_namespaced(mocker: MockerFixture) -> None:
backend="sqlite",
expire_after=10,
)


@pytest.mark.parametrize(
    ("seq1", "seq2", "result"),
    [
        ((1, 2, 3), (), True),
        ((1, 2, 3), (1, 2), True),
        ((1, 2, 3), (1, 2, 3, 4, 5), False),
        ((), (), True),
        ((), (1, 2), False),
    ],
)
def test_seq_startswith(
    seq1: Tuple[int, ...], seq2: Tuple[int, ...], result: bool
) -> None:
    """
    Test ``seq_startswith`` with empty, shorter and longer prefixes.
    """
    assert seq_startswith(seq1, seq2) == result

0 comments on commit c42165c

Please sign in to comment.