Skip to content

Commit

Permalink
Merge pull request #17 from altescy/use-manifest
Browse files Browse the repository at this point in the history
Add some features  that makes use of manifest files
  • Loading branch information
altescy authored Oct 26, 2023
2 parents 54f7725 + e20fbd4 commit 017de38
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 17 deletions.
3 changes: 3 additions & 0 deletions queuery_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
timeout: int = 300,
enable_cast: bool = False,
session: Optional[Session] = None,
use_manifest: Optional[bool] = None,
) -> None:
endpoint = endpoint or os.getenv("QUEUERY_ENDPOINT")
if endpoint is None:
Expand All @@ -33,6 +34,7 @@ def __init__(
self._timeout = timeout
self._enable_cast = enable_cast
self._session = session or Session()
self._use_manifest = use_manifest

@property
def _auth(self) -> Optional[Tuple[str, str]]:
Expand Down Expand Up @@ -78,6 +80,7 @@ def get_body(self, qid: int) -> Response:
response=body,
enable_cast=self._enable_cast,
session=self._session,
use_manifest=self._use_manifest,
)

def wait_for(self, qid: int) -> Response:
Expand Down
2 changes: 2 additions & 0 deletions queuery_client/queuery_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ def __init__(
timeout: int = 300,
enable_cast: bool = False,
session: Optional[Session] = None,
use_manifest: Optional[bool] = None,
) -> None:
self._client = Client(
endpoint=endpoint,
timeout=timeout,
enable_cast=enable_cast,
session=session,
use_manifest=use_manifest,
)

def run(self, sql: str) -> Response:
Expand Down
71 changes: 60 additions & 11 deletions queuery_client/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@
import dataclasses
import gzip
from io import StringIO
from typing import Any, Dict, Iterator, List, Literal, Optional, Union, overload
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload

from requests import Session

from queuery_client.cast import cast_row
from queuery_client.util import SizedIterator

try:
import pandas
except ModuleNotFoundError:
pandas = None


T = TypeVar("T")


@dataclasses.dataclass
class ResponseBody:
id: int
Expand All @@ -35,41 +39,62 @@ def __init__(
response: ResponseBody,
enable_cast: bool = False,
session: Optional[Session] = None,
use_manifest: Optional[bool] = None,
):
if enable_cast and use_manifest is False:
raise ValueError("enable_cast is not available when use_manifest is False.")

self._response = response
self._data_file_urls = response.data_file_urls
self._cursor = 0
self._parser = csv.reader
self._session = Session()
self._enable_cast = enable_cast
self._use_manifest = use_manifest or enable_cast
self._manifest: Optional[Dict[str, Any]] = None

def __iter__(self) -> Iterator[List[Any]]:
for url in self._data_file_urls:
for row in self._open(url):
if self._enable_cast:
yield cast_row(row, self.fetch_manifest())
else:
yield row
def get_iterator() -> Iterator[List[Any]]:
for url in self._data_file_urls:
for row in self._open(url):
if self._enable_cast:
yield cast_row(row, self.fetch_manifest())
else:
yield row

if self._use_manifest:
record_count = self.fetch_record_count()
return SizedIterator(get_iterator(), record_count)

return get_iterator()

def _open(self, url: str) -> List[List[str]]:
data = self._session.get(url).content
response = gzip.decompress(data).decode()
reader = csv.reader(StringIO(response), escapechar="\\")

self._cursor += 1
return list(reader)

def fetch_manifest(self, force: bool = False) -> Dict[str, Any]:
if not self._use_manifest:
raise RuntimeError("Manifest file is not available.")
if self._manifest is None or force:
if not self._response.manifest_file_url:
raise RuntimeError("Response does not contain manifest_file_url.")
raise RuntimeError(
"Manifest is not available because response does not contain manifest_file_url."
)

manifest = self._session.get(self._response.manifest_file_url).json()
assert isinstance(manifest, dict)
self._manifest = manifest
return self._manifest

def fetch_record_count(self) -> int:
manifest = self.fetch_manifest()
return int(manifest["meta"]["record_count"])

def fetch_column_names(self) -> List[str]:
manifest = self.fetch_manifest()
return [x["name"] for x in manifest["schema"]["elements"]]

@overload
def read(self) -> List[List[Any]]:
...
Expand All @@ -94,6 +119,30 @@ def read(
"pandas is not availabe. Please make sure that "
"pandas is successfully installed to use use_pandas option."
)

if self._use_manifest:
return pandas.DataFrame(elems, columns=self.fetch_column_names())

return pandas.DataFrame(elems)

return elems

def map(self, target: Union[Type[T], Callable[..., T]]) -> Iterator[T]:
column_names = self.fetch_column_names() if self._use_manifest else None

def convert_to_args(row: List[Any]) -> Tuple[List[Any], Dict[str, Any]]:
if column_names is None:
return row, {}
return [], {name: value for name, value in zip(column_names, row)}

def map_to_target(row: List[Any]) -> T:
args, kwargs = convert_to_args(row)
return target(*args, **kwargs)

iterator = map(map_to_target, self)

if self._use_manifest:
record_count = self.fetch_record_count()
return SizedIterator(iterator, record_count)

return iterator
26 changes: 26 additions & 0 deletions queuery_client/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Generic, Iterator, TypeVar

T = TypeVar("T")


class SizedIterator(Generic[T]):
"""
A wrapper for an iterator that knows its size.
Args:
iterator: The iterator.
size: The size of the iterator.
"""

def __init__(self, iterator: Iterator[T], size: int):
self.iterator = iterator
self.size = size

def __iter__(self) -> Iterator[T]:
return self.iterator

def __next__(self) -> T:
return next(self.iterator)

def __len__(self) -> int:
return self.size
54 changes: 48 additions & 6 deletions tests/test_response.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import gzip
import json
from typing import Any, Dict
Expand Down Expand Up @@ -44,12 +45,15 @@ def test_response_with_type_cast() -> None:

manifest_response = MockResponse(
b"""
{"schema": {
"elements": [
{"name": "id", "type": {"base": "integer"}},
{"name": "title", "type": {"base": "character varying"}}
]
}}
{
"schema": {
"elements": [
{"name": "id", "type": {"base": "integer"}},
{"name": "title", "type": {"base": "character varying"}}
]
},
"meta": {"record_count": 2}
}
""",
200,
)
Expand All @@ -60,3 +64,41 @@ def test_response_with_type_cast() -> None:
with mock.patch("requests.Session.get", return_value=data_response):
data = response.read()
assert data == [[1, "test_recipe1"], [2, "test_recipe2"]]


def test_response_with_map() -> None:
@dataclasses.dataclass
class Item:
id: int
title: str

response_body = ResponseBody(
id=1,
data_file_urls=["https://queuery.example.com/data"],
error=None,
status="success",
manifest_file_url="https://queuery.example.com/manifest",
)
response = Response(response_body, enable_cast=True)

manifest_response = MockResponse(
b"""
{
"schema": {
"elements": [
{"name": "id", "type": {"base": "integer"}},
{"name": "title", "type": {"base": "character varying"}}
]
},
"meta": {"record_count": 2}
}
""",
200,
)
with mock.patch("requests.Session.get", return_value=manifest_response):
response.fetch_manifest()

data_response = MockResponse(gzip.compress(b'"1","test_recipe1"\n"2","test_recipe2"'), 200)
with mock.patch("requests.Session.get", return_value=data_response):
data = list(response.map(Item))
assert data == [Item(id=1, title="test_recipe1"), Item(id=2, title="test_recipe2")]

0 comments on commit 017de38

Please sign in to comment.