Skip to content

Commit

Permalink
feat: generic XML adapter (#391)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Oct 21, 2023
1 parent cd2c6b1 commit d2e4f00
Show file tree
Hide file tree
Showing 13 changed files with 382 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[flake8]
ignore = E203, E266, E501, W503, F403, F401
max-line-length = 79
max-line-length = 90
max-complexity = 18
select = B,C,E,F,W,T4,B9
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Next
====

- Add new cost model ``NetworkAPICostModel`` (#381)
- Add a generic XML adapter (#391)

Version 1.2.7 - 2023-08-14
==========================
Expand Down
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ Queries like this are supported by `adapters <https://shillelagh.readthedocs.io/
CSV File/API ``/path/to/file.csv``; ``http(s)://*`` ``/home/user/sample_data.csv``
Datasette API ``http(s)://*`` ``https://global-power-plants.datasettes.com/global-power-plants/global-power-plants``
Generic JSON API ``http(s)://*`` ``https://api.stlouisfed.org/fred/series?series_id=GNPCA&api_key=XXX&file_type=json#$.seriess[*]``
Generic XML API ``http(s)://*`` ``https://api.congress.gov/v3/bill/118?format=xml&offset=0&limit=2&api_key=XXX#.//bill``
GitHub API ``https://api.github.com/repos/${owner}/{$repo}/pulls`` ``https://api.github.com/repos/apache/superset/pulls``
GSheets API ``https://docs.google.com/spreadsheets/d/${id}/edit#gid=${sheet_id}`` ``https://docs.google.com/spreadsheets/d/1LcWZMsdCl92g7nA-D6qGRqg1T5TiHyuKJUY1u9XAnsk/edit#gid=0``
HTML table API ``http(s)://*`` ``https://en.wikipedia.org/wiki/List_of_countries_and_dependencies_by_population``
Expand Down Expand Up @@ -134,6 +135,7 @@ You also need to install optional dependencies, depending on the adapter you wan
$ pip install 'shillelagh[console]' # to use the CLI
$ pip install 'shillelagh[datasetteapi]' # for Datasette
$ pip install 'shillelagh[genericjsonapi]' # for Generic JSON
$ pip install 'shillelagh[genericxmlapi]' # for Generic XML
$ pip install 'shillelagh[githubapi]' # for GitHub
$ pip install 'shillelagh[gsheetsapi]' # for GSheets
$ pip install 'shillelagh[htmltableapi]' # for HTML tables
Expand Down
16 changes: 16 additions & 0 deletions docs/adapters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,19 @@ Or via query parameters:
SELECT * FROM "https://api.example.com/?_s_headers=(X-Auth-Token:SECRET)"
Note that if passing the headers via query parameters the dictionary should be serialized using `RISON <https://pypi.org/project/prison/>`_.

Generic XML
===========

The generic XML adapter is based on the generic JSON; the only difference is that it takes XML responses and uses XPath to extract the data. The XML response is converted into a JSON equivalent payload that takes in consideration only text. For example, this XML:

.. code-block:: xml
<root>
<foo>bar</foo>
<baz>
<qux>quux</qux>o
</baz>
</root>
Would get mapped to two columns, ``foo`` and ``baz``, with values ``bar`` and ``{"qux": "quux"}`` respectively.
19 changes: 19 additions & 0 deletions examples/generic_xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
A simple example showing the generic XML.
"""
import sys

from shillelagh.backends.apsw.db import connect

if __name__ == "__main__":
API_KEY = sys.argv[1]

connection = connect(":memory:")
cursor = connection.cursor()

SQL = f"""
SELECT congress, type, latestAction FROM
"https://api.congress.gov/v3/bill/118?format=xml&offset=0&limit=2&api_key={API_KEY}#.//bill"
"""
for row in cursor.execute(SQL):
print(row)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
# See configuration details in https://github.com/pypa/setuptools_scm
version_scheme = "no-guess-dev"

[tool.flake8]
max-line-length = 90
2 changes: 2 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ codespell==2.1.0
# via shillelagh
coverage[toml]==6.4.2
# via pytest-cov
defusedxml==0.7.1
# via shillelagh
dill==0.3.6
# via
# pylint
Expand Down
7 changes: 7 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ testing =
beautifulsoup4>=4.11.1
boto3>=1.24.28
codespell>=2.1.0
defusedxml>=0.7.1
dill>=0.3.6
freezegun>=1.1.0
google-auth>=1.23.0
Expand Down Expand Up @@ -135,6 +136,11 @@ genericjsonapi =
prison>=0.2.1
requests-cache>=0.7.1
yarl>=1.8.1
genericxmlapi =
defusedxml>=0.7.1
prison>=0.2.1
requests-cache>=0.7.1
yarl>=1.8.1
githubapi =
jsonpath-python>=1.0.5
gsheetsapi =
Expand All @@ -160,6 +166,7 @@ shillelagh.adapter =
csvfile = shillelagh.adapters.file.csvfile:CSVFile
datasetteapi = shillelagh.adapters.api.datasette:DatasetteAPI
genericjsonapi = shillelagh.adapters.api.generic_json:GenericJSONAPI
genericxmlapi = shillelagh.adapters.api.generic_xml:GenericXMLAPI
githubapi = shillelagh.adapters.api.github:GitHubAPI
gsheetsapi = shillelagh.adapters.api.gsheets.adapter:GSheetsAPI
htmltableapi = shillelagh.adapters.api.html_table:HTMLTableAPI
Expand Down
55 changes: 26 additions & 29 deletions src/shillelagh/adapters/api/generic_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,26 @@
# pylint: disable=invalid-name

import logging
from datetime import timedelta
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import prison
import requests_cache
from jsonpath import JSONPath
from yarl import URL

from shillelagh.adapters.base import Adapter
from shillelagh.exceptions import ProgrammingError
from shillelagh.fields import Field
from shillelagh.fields import Field, Order
from shillelagh.filters import Filter
from shillelagh.lib import SimpleCostModel, analyze, flatten
from shillelagh.lib import SimpleCostModel, analyze, flatten, get_session
from shillelagh.typing import Maybe, RequestedOrder, Row

_logger = logging.getLogger(__name__)

SUPPORTED_PROTOCOLS = {"http", "https"}
AVERAGE_NUMBER_OF_ROWS = 100
CACHE_EXPIRATION = 180
REQUEST_HEADERS_KEY = "_s_headers"


def get_session(request_headers: Dict[str, str]) -> requests_cache.CachedSession:
"""
Return a cached session.
"""
session = requests_cache.CachedSession(
cache_name="generic_json_cache",
backend="sqlite",
expire_after=CACHE_EXPIRATION,
)
session.headers.update(request_headers)

return session
CACHE_EXPIRATION = timedelta(minutes=3)


class GenericJSONAPI(Adapter):
Expand All @@ -53,8 +39,12 @@ class GenericJSONAPI(Adapter):
supports_offset = False
supports_requested_columns = True

@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
content_type = "application/json"
default_path = "$[*]"
cache_name = "generic_json_cache"

@classmethod
def supports(cls, uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
parsed = URL(uri)
if parsed.scheme not in SUPPORTED_PROTOCOLS:
return False
Expand All @@ -69,15 +59,18 @@ def supports(uri: str, fast: bool = True, **kwargs: Any) -> Optional[bool]:
else:
request_headers = kwargs.get("request_headers", {})

session = get_session(request_headers)
session = get_session(request_headers, cls.cache_name, CACHE_EXPIRATION)
response = session.head(str(parsed))
return "application/json" in response.headers.get("content-type", "")
return cls.content_type in response.headers.get("content-type", "")

@staticmethod
def parse_uri(uri: str) -> Union[Tuple[str, str], Tuple[str, str, Dict[str, str]]]:
@classmethod
def parse_uri(
cls,
uri: str,
) -> Union[Tuple[str, str], Tuple[str, str, Dict[str, str]]]:
parsed = URL(uri)

path = parsed.fragment or "$[*]"
path = parsed.fragment or cls.default_path
parsed = parsed.with_fragment("")

if REQUEST_HEADERS_KEY in parsed.query:
Expand All @@ -92,15 +85,19 @@ def parse_uri(uri: str) -> Union[Tuple[str, str], Tuple[str, str, Dict[str, str]
def __init__(
self,
uri: str,
path: str = "$[*]",
path: Optional[str] = None,
request_headers: Optional[Dict[str, str]] = None,
):
super().__init__()

self.uri = uri
self.path = path
self.path = path or self.default_path

self._session = get_session(request_headers or {})
self._session = get_session(
request_headers or {},
self.cache_name,
CACHE_EXPIRATION,
)

self._set_columns()

Expand All @@ -113,7 +110,7 @@ def _set_columns(self) -> None:
self.columns = {
column_name: types[column_name](
filters=[],
order=order[column_name],
order=order.get(column_name, Order.NONE),
exact=False,
)
for column_name in column_names
Expand Down
85 changes: 85 additions & 0 deletions src/shillelagh/adapters/api/generic_xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
An adapter for fetching XML data.
"""

import logging
import xml.etree.ElementTree as ET
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple

from defusedxml import ElementTree as DET

from shillelagh.adapters.api.generic_json import GenericJSONAPI
from shillelagh.exceptions import ProgrammingError
from shillelagh.filters import Filter
from shillelagh.lib import flatten
from shillelagh.typing import RequestedOrder, Row

_logger = logging.getLogger(__name__)


def element_to_dict(element: ET.Element) -> Any:
"""
Convert XML element to a dictionary, recursively.
This uses a super simple algorithm that focuses on text and ignores attributes.
"""
if element.text and element.text.strip():
return element.text.strip()

result: Dict[str, Any] = {}
for child in element:
child_data = element_to_dict(child)
if child.tag in result:
# Convert to a list if multiple elements with the same tag exist
if not isinstance(result[child.tag], list):
result[child.tag] = [result[child.tag], child_data]
else:
result[child.tag].append(child_data)
else:
result.update({child.tag: child_data})

return result


class GenericXMLAPI(GenericJSONAPI):

"""
An adapter for fetching XML data.
"""

safe = True

supports_limit = False
supports_offset = False
supports_requested_columns = True

content_type = "xml" # works with text/xml and application/xml
default_path = "*"
cache_name = "generic_xml_cache"

def get_data( # pylint: disable=unused-argument, too-many-arguments
self,
bounds: Dict[str, Filter],
order: List[Tuple[str, RequestedOrder]],
limit: Optional[int] = None,
offset: Optional[int] = None,
requested_columns: Optional[Set[str]] = None,
**kwargs: Any,
) -> Iterator[Row]:
response = self._session.get(self.uri)
payload = response.content.decode("utf-8")
if not response.ok:
raise ProgrammingError(f"Error: {payload}")

root = DET.fromstring(payload)
result = root.findall(self.path)
for i, element in enumerate(result):
row = element_to_dict(element)
row = {
k: v
for k, v in row.items()
if requested_columns is None or k in requested_columns
}
row["rowid"] = i
_logger.debug(row)
yield flatten(row)
25 changes: 24 additions & 1 deletion src/shillelagh/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import marshal
import math
import operator
from datetime import timedelta
from typing import (
Any,
Callable,
Expand All @@ -20,6 +21,7 @@
)

import apsw
import requests_cache
from packaging.version import Version

from shillelagh.adapters.base import Adapter
Expand All @@ -39,6 +41,7 @@
from shillelagh.typing import RequestedOrder, Row

DELETED = range(-1, 0)
CACHE_EXPIRATION = timedelta(minutes=3)


class RowIDManager:
Expand Down Expand Up @@ -168,7 +171,7 @@ def analyze( # pylint: disable=too-many-branches
for column_name, value in row.items():
# determine order
if i > 0:
previous = previous_row[column_name]
previous = previous_row.get(column_name)
order[column_name] = update_order(
current_order=order.get(column_name, Order.NONE),
previous=previous,
Expand Down Expand Up @@ -594,3 +597,23 @@ def best_index_object_available() -> bool:
Check if support for best index object is available.
"""
return bool(Version(apsw.apswversion()) >= Version("3.41.0.0"))


def get_session(
request_headers: Dict[str, str],
cache_name: str,
expire_after: timedelta = CACHE_EXPIRATION,
) -> requests_cache.CachedSession: # E: line too long (81 > 79 characters)
"""
Return a cached session.
"""
session = requests_cache.CachedSession(
cache_name=cache_name,
backend="sqlite",
expire_after=requests_cache.DO_NOT_CACHE
if expire_after == timedelta(seconds=-1)
else expire_after.total_seconds(),
)
session.headers.update(request_headers)

return session
Loading

0 comments on commit d2e4f00

Please sign in to comment.