diff --git a/CMakeLists.txt b/CMakeLists.txt index 75ce0298..5e8c177e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,11 +20,11 @@ list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") find_package(nanobind CONFIG REQUIRED) if(WIN32) - set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-windows-x86_64-release-main.0.0.18825.rd90b6d69.zip") + set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-windows-x86_64-release-main.0.0.19691.r2d7e5bc8.zip") elseif(APPLE) - set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-macos-x86_64-release-main.0.0.18825.rd90b6d69.zip") + set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-macos-x86_64-release-main.0.0.19691.r2d7e5bc8.zip") else() - set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-linux-x86_64-release-main.0.0.18825.rd90b6d69.zip") + set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-linux-x86_64-release-main.0.0.19691.r2d7e5bc8.zip") endif() include(FetchContent) diff --git a/environment.yml b/environment.yml index 2b516b33..41db238c 100644 --- a/environment.yml +++ b/environment.yml @@ -7,6 +7,7 @@ dependencies: - isort - mypy - nanobind + - narwhals - pandas - pandas-stubs - pip diff --git a/pyproject.toml b/pyproject.toml index bb678623..05037d6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ build = "cp39-*64 cp310-*64 cp311-*64 cp312-*64" skip = "*musllinux*" test-command = "pytest {project}/tests" -test-requires = ["pytest", "pandas>=2.0.0", "polars", "numpy"] +test-requires = ["pytest", "pandas>=2.0.0", "polars<1.3.0", "narwhals", "numpy"] [tool.ruff] line-length = 88 diff --git a/src/pantab/reader.cpp b/src/pantab/reader.cpp index 41e7fcb3..9100e163 100644 --- a/src/pantab/reader.cpp +++ b/src/pantab/reader.cpp @@ -49,7 +49,7 @@ class OidReadHelper : public ReadHelper { } }; -class FloatReadHelper : public ReadHelper { +template class FloatReadHelper : public ReadHelper { using ReadHelper::ReadHelper; auto Read(const hyperapi::Value &value) -> void override { @@ -59,7 +59,7 @@ class FloatReadHelper : public ReadHelper { } return; } - if (ArrowArrayAppendDouble(array_, value.get())) { + if (ArrowArrayAppendDouble(array_, value.get())) { throw std::runtime_error("ArrowAppendDouble failed"); }; } @@ -258,8 +258,10 @@ static auto MakeReadHelper(const ArrowSchemaView *schema_view, return std::unique_ptr(new IntegralReadHelper(array)); case NANOARROW_TYPE_UINT32: return std::unique_ptr(new OidReadHelper(array)); + case NANOARROW_TYPE_FLOAT: + return std::unique_ptr(new FloatReadHelper(array)); case NANOARROW_TYPE_DOUBLE: - return std::unique_ptr(new FloatReadHelper(array)); + return std::unique_ptr(new FloatReadHelper(array)); case NANOARROW_TYPE_LARGE_BINARY: return std::unique_ptr(new BytesReadHelper(array)); case NANOARROW_TYPE_LARGE_STRING: @@ -291,6 +293,7 @@ static auto GetArrowTypeFromHyper(const hyperapi::SqlType &sqltype) case hyperapi::TypeTag::Int : return NANOARROW_TYPE_INT32; case hyperapi::TypeTag::BigInt : return NANOARROW_TYPE_INT64; case hyperapi::TypeTag::Oid : return NANOARROW_TYPE_UINT32; + case hyperapi::TypeTag::Float : return NANOARROW_TYPE_FLOAT; case hyperapi::TypeTag::Double : return NANOARROW_TYPE_DOUBLE; case hyperapi::TypeTag::Geography : case hyperapi::TypeTag:: Bytes : return NANOARROW_TYPE_LARGE_BINARY; diff --git a/src/pantab/writer.cpp b/src/pantab/writer.cpp index 2064b473..b2caf511 100644 --- a/src/pantab/writer.cpp +++ b/src/pantab/writer.cpp @@ -28,6 +28,7 @@ static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema, case NANOARROW_TYPE_UINT32: return hyperapi::SqlType::oid(); case NANOARROW_TYPE_FLOAT: + return hyperapi::SqlType::real(); case NANOARROW_TYPE_DOUBLE: return hyperapi::SqlType::doublePrecision(); case NANOARROW_TYPE_BOOL: @@ -541,7 +542,7 @@ void write_to_hyper( } const std::unordered_map params = { - {"log_config", ""}}; + {"log_config", ""}, {"default_database_version", "4"}}; const hyperapi::HyperProcess hyper{ hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", std::move(params)}; diff --git a/tests/conftest.py b/tests/conftest.py index 38d7acc8..c0a6b026 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import datetime import pathlib +import narwhals as nw import numpy as np import pandas as pd import pandas.testing as tm @@ -246,7 +247,7 @@ def basic_dataframe(): "int16_limits": np.int16, "int32_limits": np.int32, "int64_limits": np.int64, - "float32_limits": np.float64, + "float32_limits": np.float32, "float64_limits": np.float64, "oid": "UInt32", "non-ascii": "string", @@ -322,9 +323,9 @@ def roundtripped_pyarrow(): ("Int16", pa.int16()), ("Int32", pa.int32()), ("Int64", pa.int64()), - ("float32", pa.float64()), + ("float32", pa.float32()), ("float64", pa.float64()), - ("Float32", pa.float64()), + ("Float32", pa.float32()), ("Float64", pa.float64()), ("bool", pa.bool_()), ("boolean", pa.bool_()), @@ -336,7 +337,7 @@ def roundtripped_pyarrow(): ("int16_limits", pa.int16()), ("int32_limits", pa.int32()), ("int64_limits", pa.int64()), - ("float32_limits", pa.float64()), + ("float32_limits", pa.float32()), ("float64_limits", pa.float64()), ("oid", pa.uint32()), ("non-ascii", pa.large_string()), @@ -363,9 +364,9 @@ def roundtripped_pandas(): "Int16": "int16[pyarrow]", "Int32": "int32[pyarrow]", "Int64": "int64[pyarrow]", - "float32": "double[pyarrow]", + "float32": "float[pyarrow]", "float64": "double[pyarrow]", - "Float32": "double[pyarrow]", + "Float32": "float[pyarrow]", "Float64": "double[pyarrow]", "bool": "boolean[pyarrow]", "boolean": "boolean[pyarrow]", @@ -375,7 +376,7 @@ def roundtripped_pandas(): "int16_limits": "int16[pyarrow]", "int32_limits": "int32[pyarrow]", "int64_limits": "int64[pyarrow]", - "float32_limits": "double[pyarrow]", + "float32_limits": "float[pyarrow]", "float64_limits": "double[pyarrow]", "oid": "uint32[pyarrow]", "non-ascii": "large_string[pyarrow]", @@ -478,15 +479,9 @@ def empty_like(frame): raise NotImplementedError("empty_like not implemented for type") @staticmethod + @nw.narwhalify def drop_columns(frame, columns): - if isinstance(frame, pd.DataFrame): - return frame.drop(columns=columns, errors="ignore") - elif isinstance(frame, pa.Table): - return frame.drop_columns(columns) - elif isinstance(frame, pl.DataFrame): - return frame.drop(columns, strict=False) - else: - raise NotImplementedError("drop_columns not implemented for type") + return frame.drop(columns) @staticmethod def select_columns(frame, columns): @@ -500,22 +495,9 @@ def select_columns(frame, columns): raise NotImplementedError("select_columns not implemented for type") @staticmethod + @nw.narwhalify def cast_column_to_type(frame, column, type_): - if isinstance(frame, pd.DataFrame): - frame[column] = frame[column].astype(type_) - return frame - elif isinstance(frame, pa.Table): - schema = pa.schema([pa.field(column, type_)]) - return frame.cast(schema) - elif isinstance(frame, pl.DataFrame): - # hacky :-( - if type_ == "int64": - frame = frame.cast({column: pl.Int64()}) - elif type_ == "float": - frame = frame.cast({column: pl.Float64()}) - return frame - else: - raise NotImplementedError("cast_column_to_type not implemented for type") + return frame.with_columns(nw.col(column).cast(type_)) @staticmethod def add_non_writeable_column(frame): diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 8ca6c888..4c61ad7d 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -6,8 +6,10 @@ def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode, compat): return_type, expected = roundtripped - if not (isinstance(frame, pa.Table) and return_type == "pyarrow"): + if isinstance(frame, pa.Table) and return_type != "pyarrow": frame = compat.drop_columns(frame, ["interval"]) + + if return_type == "pyarrow" and not isinstance(frame, pa.Table): expected = compat.drop_columns(expected, ["interval"]) # Write twice; depending on mode this should either overwrite or duplicate entries @@ -40,8 +42,10 @@ def test_multiple_tables( frame, roundtripped, tmp_hyper, table_name, table_mode, compat ): return_type, expected = roundtripped - if not (isinstance(frame, pa.Table) and return_type == "pyarrow"): + if isinstance(frame, pa.Table) and return_type != "pyarrow": frame = compat.drop_columns(frame, ["interval"]) + + if return_type == "pyarrow" and not isinstance(frame, pa.Table): expected = compat.drop_columns(expected, ["interval"]) # Write twice; depending on mode this should either overwrite or duplicate entries @@ -78,8 +82,10 @@ def test_empty_roundtrip( frame, roundtripped, tmp_hyper, table_name, table_mode, compat ): return_type, expected = roundtripped - if not (isinstance(frame, pa.Table) and return_type == "pyarrow"): + if isinstance(frame, pa.Table) and return_type != "pyarrow": frame = compat.drop_columns(frame, ["interval"]) + + if return_type == "pyarrow" and not isinstance(frame, pa.Table): expected = compat.drop_columns(expected, ["interval"]) # object case is by definition vague, so lets punt that for now diff --git a/tests/test_writer.py b/tests/test_writer.py index 077a8c76..7eb4f539 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -1,6 +1,7 @@ import datetime import re +import narwhals as nw import pandas as pd import pyarrow as pa import pytest @@ -24,7 +25,8 @@ def test_bad_table_mode_raises(frame, tmp_hyper): @pytest.mark.parametrize( - "new_dtype,hyper_type_name", [("int64", "BIGINT"), ("float", "DOUBLE PRECISION")] + "new_dtype,hyper_type_name", + [(nw.Int64, "BIGINT"), (nw.Float64, "DOUBLE PRECISION")], ) def test_append_mode_raises_column_dtype_mismatch( new_dtype, hyper_type_name, frame, tmp_hyper, table_name, compat