Skip to content

Commit

Permalink
Allow users to set Hyper Process params from reader/writer (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd authored Sep 17, 2024
1 parent 6dd6101 commit 19f2237
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 16 deletions.
15 changes: 12 additions & 3 deletions src/pantab/_reader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib
import shutil
import tempfile
from typing import Literal, Union
from typing import Literal, Optional, Union

import pyarrow as pa
import tableauhyperapi as tab_api
Expand All @@ -15,10 +15,14 @@ def frame_from_hyper_query(
query: str,
*,
return_type: Literal["pandas", "polars", "pyarrow"] = "pandas",
process_params: Optional[dict[str, str]] = None,
):
"""See api.rst for documentation."""
if process_params is None:
process_params = {}

# Call native library to read tuples from result set
capsule = libpantab.read_from_hyper_query(str(source), query)
capsule = libpantab.read_from_hyper_query(str(source), query, process_params)
stream = pa.RecordBatchReader._import_from_c_capsule(capsule)
tbl = stream.read_all()

Expand All @@ -41,18 +45,22 @@ def frame_from_hyper(
*,
table: pantab_types.TableNameType,
return_type: Literal["pandas", "polars", "pyarrow"] = "pandas",
process_params: Optional[dict[str, str]] = None,
):
"""See api.rst for documentation"""
if isinstance(table, (str, tab_api.Name)) or not table.schema_name:
table = tab_api.TableName("public", table)

query = f"SELECT * FROM {table}"
return frame_from_hyper_query(source, query, return_type=return_type)
return frame_from_hyper_query(
source, query, return_type=return_type, process_params=process_params
)


def frames_from_hyper(
source: Union[str, pathlib.Path],
return_type: Literal["pandas", "polars", "pyarrow"] = "pandas",
process_params: Optional[dict[str, str]] = None,
):
"""See api.rst for documentation."""
result = {}
Expand All @@ -73,6 +81,7 @@ def frames_from_hyper(
source=source,
table=table,
return_type=return_type,
process_params=process_params,
)

return result
6 changes: 6 additions & 0 deletions src/pantab/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def frame_to_hyper(
not_null_columns: Optional[set[str]] = None,
json_columns: Optional[set[str]] = None,
geo_columns: Optional[set[str]] = None,
process_params: Optional[dict[str, str]] = None,
) -> None:
"""See api.rst for documentation"""
frames_to_hyper(
Expand All @@ -66,6 +67,7 @@ def frame_to_hyper(
not_null_columns=not_null_columns,
json_columns=json_columns,
geo_columns=geo_columns,
process_params=process_params,
)


Expand All @@ -77,6 +79,7 @@ def frames_to_hyper(
not_null_columns: Optional[set[str]] = None,
json_columns: Optional[set[str]] = None,
geo_columns: Optional[set[str]] = None,
process_params: Optional[dict[str, str]] = None,
) -> None:
"""See api.rst for documentation."""
_validate_table_mode(table_mode)
Expand All @@ -87,6 +90,8 @@ def frames_to_hyper(
json_columns = set()
if geo_columns is None:
geo_columns = set()
if process_params is None:
process_params = {}

tmp_db = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.hyper"

Expand All @@ -112,6 +117,7 @@ def convert_to_table_name(table: pantab_types.TableNameType):
not_null_columns=not_null_columns,
json_columns=json_columns,
geo_columns=geo_columns,
process_params=process_params,
)

# In Python 3.9+ we can just pass the path object, but due to bpo 32689
Expand Down
5 changes: 3 additions & 2 deletions src/pantab/libpantab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace nb = nanobind;
NB_MODULE(libpantab, m) { // NOLINT
m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"),
nb::arg("json_columns"), nb::arg("geo_columns"))
nb::arg("json_columns"), nb::arg("geo_columns"),
nb::arg("process_params"))
.def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"),
nb::arg("query"));
nb::arg("query"), nb::arg("process_params"));
PyDateTime_IMPORT;
}
15 changes: 11 additions & 4 deletions src/pantab/reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,19 @@ static auto ReleaseArrowStream(void *ptr) noexcept -> void {
/// because the former detects a schema from the hyper Result object
/// which does not hold nullability information
///
auto read_from_hyper_query(const std::string &path, const std::string &query)
auto read_from_hyper_query(
const std::string &path, const std::string &query,
std::unordered_map<std::string, std::string> &&process_params)
-> nb::capsule {
const std::unordered_map<std::string, std::string> params = {
{"log_config", ""}, {"default_database_version", "4"}};

if (!process_params.count("log_config"))
process_params["log_config"] = "";
if (!process_params.count("default_database_version"))
process_params["default_database_version"] = "4";

const hyperapi::HyperProcess hyper{
hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", std::move(params)};
hyperapi::Telemetry::DoNotSendUsageDataToTableau, "",
std::move(process_params)};
hyperapi::Connection connection(hyper.getEndpoint(), path);

auto hyperResult = connection.executeQuery(query);
Expand Down
5 changes: 4 additions & 1 deletion src/pantab/reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>

auto read_from_hyper_query(const std::string &path, const std::string &query)
auto read_from_hyper_query(
const std::string &path, const std::string &query,
std::unordered_map<std::string, std::string> &&process_params)
-> nanobind::capsule;
15 changes: 10 additions & 5 deletions src/pantab/writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,9 @@ using SchemaAndTableName = std::tuple<std::string, std::string>;
void write_to_hyper(
const std::map<SchemaAndTableName, nb::capsule> &dict_of_capsules,
const std::string &path, const std::string &table_mode,
nb::iterable not_null_columns, nb::iterable json_columns,
nb::iterable geo_columns) {
const nb::iterable not_null_columns, const nb::iterable json_columns,
const nb::iterable geo_columns,
std::unordered_map<std::string, std::string> &&process_params) {

std::set<std::string> not_null_set;
for (auto col : not_null_columns) {
Expand All @@ -601,10 +602,14 @@ void write_to_hyper(
geo_set.insert(colstr);
}

const std::unordered_map<std::string, std::string> params = {
{"log_config", ""}, {"default_database_version", "4"}};
if (!process_params.count("log_config"))
process_params["log_config"] = "";
if (!process_params.count("default_database_version"))
process_params["default_database_version"] = "4";

const hyperapi::HyperProcess hyper{
hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", std::move(params)};
hyperapi::Telemetry::DoNotSendUsageDataToTableau, "",
std::move(process_params)};

// TODO: we don't have separate table / database create modes in the API
// but probably should; for now we infer this from table mode
Expand Down
4 changes: 3 additions & 1 deletion src/pantab/writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <nanobind/stl/map.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/unordered_map.h>

namespace nb = nanobind;

Expand All @@ -13,4 +14,5 @@ void write_to_hyper(
const std::map<SchemaAndTableName, nanobind::capsule> &dict_of_capsules,
const std::string &path, const std::string &table_mode,
const nb::iterable not_null_columns, const nb::iterable json_columns,
const nb::iterable geo_columns);
const nb::iterable geo_columns,
std::unordered_map<std::string, std::string> &&process_params);
19 changes: 19 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pandas as pd
import pandas.testing as tm
import pytest
import tableauhyperapi as tab_api

import pantab as pt
Expand Down Expand Up @@ -145,3 +146,21 @@ def test_frames_from_hyper_doesnt_generate_hyperd_log(frame, tmp_hyper):
pt.frame_to_hyper(frame, tmp_hyper, table="test")
pt.frames_from_hyper(tmp_hyper)
assert not pathlib.Path("hyperd.log").is_file()


def test_reader_accepts_process_params(tmp_hyper):
    """Check that the reader runs when given a ``process_params`` mapping.

    A ten-row, single-column ``int8`` frame is written to ``tmp_hyper`` and
    then read back with ``default_database_version`` overridden. The test
    passes as long as the read does not raise.
    """
    values = list(range(10))
    written = pd.DataFrame(values, columns=["nums"]).astype("int8")
    pt.frame_to_hyper(written, tmp_hyper, table="test")

    # Override one of the settings that the reader would otherwise default.
    pt.frames_from_hyper(
        tmp_hyper,
        process_params={"default_database_version": "0"},
    )


def test_reader_invalid_process_params_raises(tmp_hyper):
    """Check that an unknown Hyper process parameter makes the reader raise.

    The test builds its own ten-row ``int8`` frame, so it does not request
    the shared ``frame`` fixture; asking for that fixture only to overwrite
    it immediately would set it up for nothing. This also matches the
    signature of the equivalent writer test.

    Raises (expected):
        RuntimeError: with a message naming the unrecognised setting.
    """
    frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8")
    pt.frame_to_hyper(frame, tmp_hyper, table="test")

    params = {"not_a_real_parameter": "0"}
    # ``match`` is applied as a regex search against the exception text.
    msg = r"No internal setting named 'not_a_real_parameter'"
    with pytest.raises(RuntimeError, match=msg):
        pt.frames_from_hyper(tmp_hyper, process_params=params)
15 changes: 15 additions & 0 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,18 @@ def test_eight_bit_int(tmp_hyper):
num_col = table_def.get_column_by_name("nums")
assert num_col is not None
assert num_col.type == tab_api.SqlType.small_int()


def test_writer_accepts_process_params(tmp_hyper):
    """Check that the writer runs when given a ``process_params`` mapping.

    A ten-row, single-column ``int8`` frame is written to ``tmp_hyper`` with
    ``default_database_version`` overridden. The test passes as long as the
    write does not raise.
    """
    values = list(range(10))
    to_write = pd.DataFrame(values, columns=["nums"]).astype("int8")
    pt.frame_to_hyper(
        to_write,
        tmp_hyper,
        table="test",
        process_params={"default_database_version": "0"},
    )


def test_writer_invalid_process_params_raises(tmp_hyper):
    """Check that an unknown Hyper process parameter makes the writer raise.

    Raises (expected):
        RuntimeError: with a message naming the unrecognised setting.
    """
    bad_params = {"not_a_real_parameter": "0"}
    values = list(range(10))
    to_write = pd.DataFrame(values, columns=["nums"]).astype("int8")

    # ``match`` is applied as a regex search against the exception text.
    expected = r"No internal setting named 'not_a_real_parameter'"
    with pytest.raises(RuntimeError, match=expected):
        pt.frame_to_hyper(
            to_write, tmp_hyper, table="test", process_params=bad_params
        )

0 comments on commit 19f2237

Please sign in to comment.