feat: generic XML adapter
betodealmeida committed Oct 21, 2023
1 parent cd2c6b1 commit 7f197ff
Showing 10 changed files with 363 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -1,5 +1,5 @@
[flake8]
ignore = E203, E266, E501, W503, F403, F401
max-line-length = 79
max-line-length = 90
max-complexity = 18
select = B,C,E,F,W,T4,B9
19 changes: 19 additions & 0 deletions examples/generic_xml.py
@@ -0,0 +1,19 @@
"""
A simple example showing the generic XML adapter.
"""
import sys

from shillelagh.backends.apsw.db import connect

if __name__ == "__main__":
API_KEY = sys.argv[1]

connection = connect(":memory:")
cursor = connection.cursor()

SQL = f"""
SELECT congress, type, latestAction FROM
"https://api.congress.gov/v3/bill/118?format=xml&offset=0&limit=2&api_key={API_KEY}#.//bill"
"""
for row in cursor.execute(SQL):
print(row)
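
The fragment after the "#" in the table URI (".//bill" above) is the XPath expression the adapter uses to select the repeating elements; when it is omitted, parse_uri falls back to the adapter's default_path of "*", i.e. every direct child of the root element. A minimal variant of the example above, assuming a hypothetical endpoint that serves one <row> element per record and that the genericxmlapi extra is installed:

from shillelagh.backends.apsw.db import connect

connection = connect(":memory:")
cursor = connection.cursor()

# the fragment is optional; without it every child of the root element becomes a row
SQL = 'SELECT * FROM "https://example.org/data.xml#.//row"'
for row in cursor.execute(SQL):
    print(row)
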
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -6,3 +6,6 @@ build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
# See configuration details in https://github.com/pypa/setuptools_scm
version_scheme = "no-guess-dev"

[tool.flake8]
max-line-length = 90
2 changes: 2 additions & 0 deletions requirements/test.txt
@@ -44,6 +44,8 @@ codespell==2.1.0
# via shillelagh
coverage[toml]==6.4.2
# via pytest-cov
defusedxml==0.7.1
# via shillelagh
dill==0.3.6
# via
# pylint
7 changes: 7 additions & 0 deletions setup.cfg
@@ -73,6 +73,7 @@ testing =
beautifulsoup4>=4.11.1
boto3>=1.24.28
codespell>=2.1.0
defusedxml>=0.7.1
dill>=0.3.6
freezegun>=1.1.0
google-auth>=1.23.0
@@ -135,6 +136,11 @@ genericjsonapi =
prison>=0.2.1
requests-cache>=0.7.1
yarl>=1.8.1
genericxmlapi =
defusedxml>=0.7.1
prison>=0.2.1
requests-cache>=0.7.1
yarl>=1.8.1
githubapi =
jsonpath-python>=1.0.5
gsheetsapi =
@@ -160,6 +166,7 @@ shillelagh.adapter =
csvfile = shillelagh.adapters.file.csvfile:CSVFile
datasetteapi = shillelagh.adapters.api.datasette:DatasetteAPI
genericjsonapi = shillelagh.adapters.api.generic_json:GenericJSONAPI
genericxmlapi = shillelagh.adapters.api.generic_xml:GenericXMLAPI
githubapi = shillelagh.adapters.api.github:GitHubAPI
gsheetsapi = shillelagh.adapters.api.gsheets.adapter:GSheetsAPI
htmltableapi = shillelagh.adapters.api.html_table:HTMLTableAPI
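
The genericxmlapi line above registers the adapter under the shillelagh.adapter entry-point group, which is how shillelagh discovers adapters at runtime, and the matching extra pulls in its dependencies (pip install "shillelagh[genericxmlapi]"). A quick sketch, assuming Python 3.10+ for the keyword form of entry_points, that lists everything registered in that group:

from importlib.metadata import entry_points

# every adapter registered under the "shillelagh.adapter" group, including
# genericxmlapi once the package is (re)installed with this setup.cfg
for entry_point in entry_points(group="shillelagh.adapter"):
    print(entry_point.name, "->", entry_point.value)
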
55 changes: 26 additions & 29 deletions src/shillelagh/adapters/api/generic_json.py
@@ -5,40 +5,26 @@
# pylint: disable=invalid-name

import logging
from datetime import timedelta
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import prison
import requests_cache
from jsonpath import JSONPath
from yarl import URL

from shillelagh.adapters.base import Adapter
from shillelagh.exceptions import ProgrammingError
from shillelagh.fields import Field
from shillelagh.fields import Field, Order
from shillelagh.filters import Filter
from shillelagh.lib import SimpleCostModel, analyze, flatten
from shillelagh.lib import SimpleCostModel, analyze, flatten, get_session
from shillelagh.typing import Maybe, RequestedOrder, Row

_logger = logging.getLogger(__name__)

SUPPORTED_PROTOCOLS = {"http", "https"}
AVERAGE_NUMBER_OF_ROWS = 100
CACHE_EXPIRATION = 180
REQUEST_HEADERS_KEY = "_s_headers"


def get_session(request_headers: Dict[str, str]) -> requests_cache.CachedSession:
"""
Return a cached session.
"""
session = requests_cache.CachedSession(
cache_name="generic_json_cache",
backend="sqlite",
expire_after=CACHE_EXPIRATION,
)
session.headers.update(request_headers)

return session
CACHE_EXPIRATION = timedelta(minutes=3)


class GenericJSONAPI(Adapter):
@@ -53,8 +39,12 @@ class GenericJSONAPI(Adapter):
supports_offset = False
supports_requested_columns = True

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
content_type = "application/json"
default_path = "$[*]"
cache_name = "generic_json_cache"

@classmethod
def supports(cls, uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
parsed = URL(uri)
if parsed.scheme not in SUPPORTED_PROTOCOLS:
return False
@@ -69,15 +59,18 @@ def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
else:
request_headers = kwargs.get("request_headers", {})

session = get_session(request_headers)
session = get_session(request_headers, cls.cache_name, CACHE_EXPIRATION)
response = session.head(str(parsed))
return "application/json" in response.headers.get("content-type", "")
return cls.content_type in response.headers.get("content-type", "")

@staticmethod
def parse_uri(uri: str) -> Union[Tuple[str, str], Tuple[str, str, Dict[str, str]]]:
@classmethod
def parse_uri(
cls,
uri: str,
) -> Union[Tuple[str, str], Tuple[str, str, Dict[str, str]]]:
parsed = URL(uri)

path = parsed.fragment or "$[*]"
path = parsed.fragment or cls.default_path
parsed = parsed.with_fragment("")

if REQUEST_HEADERS_KEY in parsed.query:
@@ -92,15 +85,19 @@ def parse_uri(uri: str) -> Union[Tuple[str, str], Tuple[str, str, Dict[str, str]
def __init__(
self,
uri: str,
path: str = "$[*]",
path: Optional[str] = None,
request_headers: Optional[Dict[str, str]] = None,
):
super().__init__()

self.uri = uri
self.path = path
self.path = path or self.default_path

self._session = get_session(request_headers or {})
self._session = get_session(
request_headers or {},
self.cache_name,
CACHE_EXPIRATION,
)

self._set_columns()

@@ -113,7 +110,7 @@ def _set_columns(self) -> None:
self.columns = {
column_name: types[column_name](
filters=[],
order=order[column_name],
order=order.get(column_name, Order.NONE),
exact=False,
)
for column_name in column_names
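
The changes above move get_session into shillelagh.lib and replace the hard-coded JSON values with the class attributes content_type, default_path, and cache_name, so that supports(), parse_uri(), and __init__() can be inherited unchanged by format-specific subclasses. A hypothetical sketch (not part of this commit) of another adapter built on the same hooks:

from typing import Any, Iterator

from shillelagh.adapters.api.generic_json import GenericJSONAPI
from shillelagh.typing import Row


class GenericTextAPI(GenericJSONAPI):
    """
    Hypothetical adapter exposing a plain-text endpoint as one row per line.
    """

    content_type = "text/plain"        # matched by the inherited supports()
    default_path = "*"                 # used by parse_uri() when the URI has no fragment
    cache_name = "generic_text_cache"  # separate on-disk cache for this adapter

    def get_data(self, *args: Any, **kwargs: Any) -> Iterator[Row]:
        response = self._session.get(self.uri)  # cached session built in __init__()
        for i, line in enumerate(response.text.splitlines()):
            yield {"rowid": i, "line": line}
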
85 changes: 85 additions & 0 deletions src/shillelagh/adapters/api/generic_xml.py
@@ -0,0 +1,85 @@
"""
An adapter for fetching XML data.
"""

import logging
import xml.etree.ElementTree as ET
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple

from defusedxml import ElementTree as DET

from shillelagh.adapters.api.generic_json import GenericJSONAPI
from shillelagh.exceptions import ProgrammingError
from shillelagh.filters import Filter
from shillelagh.lib import flatten
from shillelagh.typing import RequestedOrder, Row

_logger = logging.getLogger(__name__)


def element_to_dict(element: ET.Element) -> Any:
"""
Convert XML element to a dictionary, recursively.
This uses a super simple algorithm that focuses on text and ignores attributes.
"""
if element.text and element.text.strip():
return element.text.strip()

result: Dict[str, Any] = {}
for child in element:
child_data = element_to_dict(child)
if child.tag in result:
# Convert to a list if multiple elements with the same tag exist
if not isinstance(result[child.tag], list):
result[child.tag] = [result[child.tag], child_data]
else:
result[child.tag].append(child_data)
else:
result.update({child.tag: child_data})

return result


class GenericXMLAPI(GenericJSONAPI):

"""
An adapter for fetching XML data.
"""

safe = True

supports_limit = False
supports_offset = False
supports_requested_columns = True

content_type = "xml" # works with text/xml and application/xml
default_path = "*"
cache_name = "generic_xml_cache"

def get_data( # pylint: disable=unused-argument, too-many-arguments
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
requested_columns: Optional[Set[str]] = None,
**kwargs: Any,
) -> Iterator[Row]:
response = self._session.get(self.uri)
payload = response.content.decode("utf-8")
if not response.ok:
raise ProgrammingError(f"Error: {payload}")

root = DET.fromstring(payload)
result = root.findall(self.path)
for i, element in enumerate(result):
row = element_to_dict(element)
row = {
k: v
for k, v in row.items()
if requested_columns is None or k in requested_columns
}
row["rowid"] = i
_logger.debug(row)
yield flatten(row)
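
element_to_dict returns the stripped text for leaf elements and a nested dictionary otherwise, collapsing repeated child tags into lists; get_data then filters the requested columns and flattens each dictionary into a row. A small standalone sketch of the conversion, using a made-up payload:

from defusedxml import ElementTree as DET

from shillelagh.adapters.api.generic_xml import element_to_dict

payload = """
<bills>
  <bill><congress>118</congress><type>HR</type><number>1</number></bill>
  <bill><congress>118</congress><type>S</type><number>2</number></bill>
</bills>
"""

root = DET.fromstring(payload)
for element in root.findall(".//bill"):
    print(element_to_dict(element))
# {'congress': '118', 'type': 'HR', 'number': '1'}
# {'congress': '118', 'type': 'S', 'number': '2'}
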
25 changes: 24 additions & 1 deletion src/shillelagh/lib.py
@@ -6,6 +6,7 @@
import marshal
import math
import operator
from datetime import timedelta
from typing import (
Any,
Callable,
@@ -20,6 +21,7 @@
)

import apsw
import requests_cache
from packaging.version import Version

from shillelagh.adapters.base import Adapter
@@ -39,6 +41,7 @@
from shillelagh.typing import RequestedOrder, Row

DELETED = range(-1, 0)
CACHE_EXPIRATION = timedelta(minutes=3)


class RowIDManager:
@@ -168,7 +171,7 @@ def analyze(  # pylint: disable=too-many-branches
for column_name, value in row.items():
# determine order
if i > 0:
previous = previous_row[column_name]
previous = previous_row.get(column_name)
order[column_name] = update_order(
current_order=order.get(column_name, Order.NONE),
previous=previous,
@@ -594,3 +597,23 @@ def best_index_object_available() -> bool:
Check if support for best index object is available.
"""
return bool(Version(apsw.apswversion()) >= Version("3.41.0.0"))


def get_session(
request_headers: Dict[str, str],
cache_name: str,
expire_after: timedelta = CACHE_EXPIRATION,
) -> requests_cache.CachedSession:
"""
Return a cached session.
"""
session = requests_cache.CachedSession(
cache_name=cache_name,
backend="sqlite",
expire_after=requests_cache.DO_NOT_CACHE
if expire_after == timedelta(seconds=-1)
else expire_after.total_seconds(),
)
session.headers.update(request_headers)

return session
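
get_session now takes the cache name and expiration, so each adapter keeps its own SQLite-backed cache, and a timedelta of -1 second acts as a sentinel mapped to requests_cache.DO_NOT_CACHE. A quick sketch of both call styles (the header value and URL are placeholders):

from datetime import timedelta

from shillelagh.lib import get_session

# three-minute cache, stored by requests-cache in generic_xml_cache.sqlite
session = get_session(
    {"X-Api-Key": "placeholder"},
    "generic_xml_cache",
    timedelta(minutes=3),
)
response = session.get("https://example.org/data.xml")

# the -1 second sentinel disables caching entirely
uncached = get_session({}, "generic_xml_cache", timedelta(seconds=-1))
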
7 changes: 5 additions & 2 deletions tests/adapters/api/generic_json_test.py
@@ -2,9 +2,10 @@
Test the generic JSON adapter.
"""

from datetime import timedelta

import pytest
from pytest_mock import MockerFixture
from requests_cache import DO_NOT_CACHE
from requests_mock.mocker import Mocker
from yarl import URL

Expand All @@ -13,6 +14,8 @@
from shillelagh.exceptions import ProgrammingError
from shillelagh.typing import Maybe

DO_NOT_CACHE = timedelta(seconds=-1)

baseurl = URL("https://api.stlouisfed.org/fred/series")


@@ -174,7 +177,7 @@ def test_request_headers(mocker: MockerFixture, requests_mock: Mocker) -> None:
)

# for datassette and other probing adapters
requests_mock.head("https://exmaple.org/-/versions.json", status_code=404)
requests_mock.head("https://example.org/-/versions.json", status_code=404)

url = URL("https://example.org/")
data = requests_mock.head(str(url), headers={"content-type": "application/json"})