diff --git a/src/pantab/_reader.py b/src/pantab/_reader.py index 1968f881..59597538 100644 --- a/src/pantab/_reader.py +++ b/src/pantab/_reader.py @@ -1,7 +1,7 @@ import pathlib import shutil import tempfile -from typing import Literal, Union +from typing import Literal, Optional, Union import pyarrow as pa import tableauhyperapi as tab_api @@ -15,10 +15,14 @@ def frame_from_hyper_query( query: str, *, return_type: Literal["pandas", "polars", "pyarrow"] = "pandas", + process_params: Optional[dict[str, str]] = None, ): """See api.rst for documentation.""" + if process_params is None: + process_params = {} + # Call native library to read tuples from result set - capsule = libpantab.read_from_hyper_query(str(source), query) + capsule = libpantab.read_from_hyper_query(str(source), query, process_params) stream = pa.RecordBatchReader._import_from_c_capsule(capsule) tbl = stream.read_all() @@ -41,18 +45,22 @@ def frame_from_hyper( *, table: pantab_types.TableNameType, return_type: Literal["pandas", "polars", "pyarrow"] = "pandas", + process_params: Optional[dict[str, str]] = None, ): """See api.rst for documentation""" if isinstance(table, (str, tab_api.Name)) or not table.schema_name: table = tab_api.TableName("public", table) query = f"SELECT * FROM {table}" - return frame_from_hyper_query(source, query, return_type=return_type) + return frame_from_hyper_query( + source, query, return_type=return_type, process_params=process_params + ) def frames_from_hyper( source: Union[str, pathlib.Path], return_type: Literal["pandas", "polars", "pyarrow"] = "pandas", + process_params: Optional[dict[str, str]] = None, ): """See api.rst for documentation.""" result = {} @@ -73,6 +81,7 @@ def frames_from_hyper( source=source, table=table, return_type=return_type, + process_params=process_params, ) return result diff --git a/src/pantab/_writer.py b/src/pantab/_writer.py index 9e57d0bb..a44bb703 100644 --- a/src/pantab/_writer.py +++ b/src/pantab/_writer.py @@ -57,6 +57,7 @@ def 
frame_to_hyper( not_null_columns: Optional[set[str]] = None, json_columns: Optional[set[str]] = None, geo_columns: Optional[set[str]] = None, + process_params: Optional[dict[str, str]] = None, ) -> None: """See api.rst for documentation""" frames_to_hyper( @@ -66,6 +67,7 @@ def frame_to_hyper( not_null_columns=not_null_columns, json_columns=json_columns, geo_columns=geo_columns, + process_params=process_params, ) @@ -77,6 +79,7 @@ def frames_to_hyper( not_null_columns: Optional[set[str]] = None, json_columns: Optional[set[str]] = None, geo_columns: Optional[set[str]] = None, + process_params: Optional[dict[str, str]] = None, ) -> None: """See api.rst for documentation.""" _validate_table_mode(table_mode) @@ -87,6 +90,8 @@ def frames_to_hyper( json_columns = set() if geo_columns is None: geo_columns = set() + if process_params is None: + process_params = {} tmp_db = pathlib.Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.hyper" @@ -112,6 +117,7 @@ def convert_to_table_name(table: pantab_types.TableNameType): not_null_columns=not_null_columns, json_columns=json_columns, geo_columns=geo_columns, + process_params=process_params, ) # In Python 3.9+ we can just pass the path object, but due to bpo 32689 diff --git a/src/pantab/libpantab.cpp b/src/pantab/libpantab.cpp index dd2a9b0d..5f31b8cd 100644 --- a/src/pantab/libpantab.cpp +++ b/src/pantab/libpantab.cpp @@ -9,8 +9,9 @@ namespace nb = nanobind; NB_MODULE(libpantab, m) { // NOLINT m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"), nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"), - nb::arg("json_columns"), nb::arg("geo_columns")) + nb::arg("json_columns"), nb::arg("geo_columns"), + nb::arg("process_params")) .def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"), - nb::arg("query")); + nb::arg("query"), nb::arg("process_params")); PyDateTime_IMPORT; } diff --git a/src/pantab/reader.cpp b/src/pantab/reader.cpp index ee0b63ab..30a3e7a2 100644 --- 
a/src/pantab/reader.cpp +++ b/src/pantab/reader.cpp @@ -431,12 +431,19 @@ static auto ReleaseArrowStream(void *ptr) noexcept -> void { /// because the former detects a schema from the hyper Result object /// which does not hold nullability information /// -auto read_from_hyper_query(const std::string &path, const std::string &query) +auto read_from_hyper_query( + const std::string &path, const std::string &query, + std::unordered_map<std::string, std::string> &&process_params) -> nb::capsule { - const std::unordered_map<std::string, std::string> params = { - {"log_config", ""}, {"default_database_version", "4"}}; + + if (!process_params.count("log_config")) + process_params["log_config"] = ""; + if (!process_params.count("default_database_version")) + process_params["default_database_version"] = "4"; + const hyperapi::HyperProcess hyper{ - hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", std::move(params)}; + hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", + std::move(process_params)}; hyperapi::Connection connection(hyper.getEndpoint(), path); auto hyperResult = connection.executeQuery(query); diff --git a/src/pantab/reader.hpp b/src/pantab/reader.hpp index 7d451bf4..5469aede 100644 --- a/src/pantab/reader.hpp +++ b/src/pantab/reader.hpp @@ -2,6 +2,9 @@ #include <nanobind/nanobind.h> #include <string> +#include <unordered_map> -auto read_from_hyper_query(const std::string &path, const std::string &query) +auto read_from_hyper_query( + const std::string &path, const std::string &query, + std::unordered_map<std::string, std::string> &&process_params) -> nanobind::capsule; diff --git a/src/pantab/writer.cpp b/src/pantab/writer.cpp index d604e66b..9f24a308 100644 --- a/src/pantab/writer.cpp +++ b/src/pantab/writer.cpp @@ -580,8 +580,9 @@ using SchemaAndTableName = std::tuple<std::string, std::string>; void write_to_hyper( const std::map<SchemaAndTableName, nb::capsule> &dict_of_capsules, const std::string &path, const std::string &table_mode, - nb::iterable not_null_columns, nb::iterable json_columns, - nb::iterable geo_columns) { + const nb::iterable not_null_columns, const nb::iterable json_columns, + const nb::iterable geo_columns, + 
std::unordered_map<std::string, std::string> &&process_params) { std::set<std::string> not_null_set; for (auto col : not_null_columns) { @@ -601,10 +602,14 @@ void write_to_hyper( geo_set.insert(colstr); } - const std::unordered_map<std::string, std::string> params = { - {"log_config", ""}, {"default_database_version", "4"}}; + if (!process_params.count("log_config")) + process_params["log_config"] = ""; + if (!process_params.count("default_database_version")) + process_params["default_database_version"] = "4"; + const hyperapi::HyperProcess hyper{ - hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", std::move(params)}; + hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", + std::move(process_params)}; // TODO: we don't have separate table / database create modes in the API // but probably should; for now we infer this from table mode diff --git a/src/pantab/writer.hpp b/src/pantab/writer.hpp index 8e5c5a77..66426727 100644 --- a/src/pantab/writer.hpp +++ b/src/pantab/writer.hpp @@ -4,6 +4,7 @@ #include <map> #include <nanobind/nanobind.h> #include <string> +#include <unordered_map> namespace nb = nanobind; @@ -13,4 +14,5 @@ void write_to_hyper( const std::map<SchemaAndTableName, nb::capsule> &dict_of_capsules, const std::string &path, const std::string &table_mode, const nb::iterable not_null_columns, const nb::iterable json_columns, - const nb::iterable geo_columns); + const nb::iterable geo_columns, + std::unordered_map<std::string, std::string> &&process_params); diff --git a/tests/test_reader.py b/tests/test_reader.py index 08e14d40..561c053e 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -2,6 +2,7 @@ import pandas as pd import pandas.testing as tm +import pytest import tableauhyperapi as tab_api import pantab as pt @@ -145,3 +146,21 @@ def test_frames_from_hyper_doesnt_generate_hyperd_log(frame, tmp_hyper): pt.frame_to_hyper(frame, tmp_hyper, table="test") pt.frames_from_hyper(tmp_hyper) assert not pathlib.Path("hyperd.log").is_file() + + +def test_reader_accepts_process_params(tmp_hyper): + frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8") + pt.frame_to_hyper(frame, tmp_hyper, 
table="test") + + params = {"default_database_version": "0"} + pt.frames_from_hyper(tmp_hyper, process_params=params) + + +def test_reader_invalid_process_params_raises(frame, tmp_hyper): + frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8") + pt.frame_to_hyper(frame, tmp_hyper, table="test") + + params = {"not_a_real_parameter": "0"} + msg = r"No internal setting named 'not_a_real_parameter'" + with pytest.raises(RuntimeError, match=msg): + pt.frames_from_hyper(tmp_hyper, process_params=params) diff --git a/tests/test_writer.py b/tests/test_writer.py index 97e6b271..d4cdd7dc 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -380,3 +380,18 @@ def test_eight_bit_int(tmp_hyper): num_col = table_def.get_column_by_name("nums") assert num_col is not None assert num_col.type == tab_api.SqlType.small_int() + + +def test_writer_accepts_process_params(tmp_hyper): + frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8") + params = {"default_database_version": "0"} + pt.frame_to_hyper(frame, tmp_hyper, table="test", process_params=params) + + +def test_writer_invalid_process_params_raises(tmp_hyper): + frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8") + params = {"not_a_real_parameter": "0"} + + msg = r"No internal setting named 'not_a_real_parameter'" + with pytest.raises(RuntimeError, match=msg): + pt.frame_to_hyper(frame, tmp_hyper, table="test", process_params=params)