Skip to content

Commit

Permalink
Change to new database format for float32 support (#313)
Browse files: browse the repository at this point in the history
  • Loading branch information
WillAyd authored Aug 16, 2024
1 parent cd805e3 commit 676be4c
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 42 deletions.
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)

if(WIN32)
set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-windows-x86_64-release-main.0.0.18825.rd90b6d69.zip")
set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-windows-x86_64-release-main.0.0.19691.r2d7e5bc8.zip")
elseif(APPLE)
set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-macos-x86_64-release-main.0.0.18825.rd90b6d69.zip")
set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-macos-x86_64-release-main.0.0.19691.r2d7e5bc8.zip")
else()
set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-linux-x86_64-release-main.0.0.18825.rd90b6d69.zip")
set(TABLEAU_DOWNLOAD_URL "https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-linux-x86_64-release-main.0.0.19691.r2d7e5bc8.zip")
endif()

include(FetchContent)
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies:
- isort
- mypy
- nanobind
- narwhals
- pandas
- pandas-stubs
- pip
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ build = "cp39-*64 cp310-*64 cp311-*64 cp312-*64"
skip = "*musllinux*"

test-command = "pytest {project}/tests"
test-requires = ["pytest", "pandas>=2.0.0", "polars", "numpy"]
test-requires = ["pytest", "pandas>=2.0.0", "polars<1.3.0", "narwhals", "numpy"]

[tool.ruff]
line-length = 88
Expand Down
9 changes: 6 additions & 3 deletions src/pantab/reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class OidReadHelper : public ReadHelper {
}
};

class FloatReadHelper : public ReadHelper {
template <typename T> class FloatReadHelper : public ReadHelper {
using ReadHelper::ReadHelper;

auto Read(const hyperapi::Value &value) -> void override {
Expand All @@ -59,7 +59,7 @@ class FloatReadHelper : public ReadHelper {
}
return;
}
if (ArrowArrayAppendDouble(array_, value.get<double>())) {
if (ArrowArrayAppendDouble(array_, value.get<T>())) {
throw std::runtime_error("ArrowAppendDouble failed");
};
}
Expand Down Expand Up @@ -258,8 +258,10 @@ static auto MakeReadHelper(const ArrowSchemaView *schema_view,
return std::unique_ptr<ReadHelper>(new IntegralReadHelper<int64_t>(array));
case NANOARROW_TYPE_UINT32:
return std::unique_ptr<ReadHelper>(new OidReadHelper(array));
case NANOARROW_TYPE_FLOAT:
return std::unique_ptr<ReadHelper>(new FloatReadHelper<float>(array));
case NANOARROW_TYPE_DOUBLE:
return std::unique_ptr<ReadHelper>(new FloatReadHelper(array));
return std::unique_ptr<ReadHelper>(new FloatReadHelper<double>(array));
case NANOARROW_TYPE_LARGE_BINARY:
return std::unique_ptr<ReadHelper>(new BytesReadHelper(array));
case NANOARROW_TYPE_LARGE_STRING:
Expand Down Expand Up @@ -291,6 +293,7 @@ static auto GetArrowTypeFromHyper(const hyperapi::SqlType &sqltype)
case hyperapi::TypeTag::Int : return NANOARROW_TYPE_INT32;
case hyperapi::TypeTag::BigInt : return NANOARROW_TYPE_INT64;
case hyperapi::TypeTag::Oid : return NANOARROW_TYPE_UINT32;
case hyperapi::TypeTag::Float : return NANOARROW_TYPE_FLOAT;
case hyperapi::TypeTag::Double : return NANOARROW_TYPE_DOUBLE;
case hyperapi::TypeTag::Geography : case hyperapi::TypeTag::
Bytes : return NANOARROW_TYPE_LARGE_BINARY;
Expand Down
3 changes: 2 additions & 1 deletion src/pantab/writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ static auto GetHyperTypeFromArrowSchema(struct ArrowSchema *schema,
case NANOARROW_TYPE_UINT32:
return hyperapi::SqlType::oid();
case NANOARROW_TYPE_FLOAT:
return hyperapi::SqlType::real();
case NANOARROW_TYPE_DOUBLE:
return hyperapi::SqlType::doublePrecision();
case NANOARROW_TYPE_BOOL:
Expand Down Expand Up @@ -541,7 +542,7 @@ void write_to_hyper(
}

const std::unordered_map<std::string, std::string> params = {
{"log_config", ""}};
{"log_config", ""}, {"default_database_version", "4"}};
const hyperapi::HyperProcess hyper{
hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", std::move(params)};

Expand Down
42 changes: 12 additions & 30 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import pathlib

import narwhals as nw
import numpy as np
import pandas as pd
import pandas.testing as tm
Expand Down Expand Up @@ -246,7 +247,7 @@ def basic_dataframe():
"int16_limits": np.int16,
"int32_limits": np.int32,
"int64_limits": np.int64,
"float32_limits": np.float64,
"float32_limits": np.float32,
"float64_limits": np.float64,
"oid": "UInt32",
"non-ascii": "string",
Expand Down Expand Up @@ -322,9 +323,9 @@ def roundtripped_pyarrow():
("Int16", pa.int16()),
("Int32", pa.int32()),
("Int64", pa.int64()),
("float32", pa.float64()),
("float32", pa.float32()),
("float64", pa.float64()),
("Float32", pa.float64()),
("Float32", pa.float32()),
("Float64", pa.float64()),
("bool", pa.bool_()),
("boolean", pa.bool_()),
Expand All @@ -336,7 +337,7 @@ def roundtripped_pyarrow():
("int16_limits", pa.int16()),
("int32_limits", pa.int32()),
("int64_limits", pa.int64()),
("float32_limits", pa.float64()),
("float32_limits", pa.float32()),
("float64_limits", pa.float64()),
("oid", pa.uint32()),
("non-ascii", pa.large_string()),
Expand All @@ -363,9 +364,9 @@ def roundtripped_pandas():
"Int16": "int16[pyarrow]",
"Int32": "int32[pyarrow]",
"Int64": "int64[pyarrow]",
"float32": "double[pyarrow]",
"float32": "float[pyarrow]",
"float64": "double[pyarrow]",
"Float32": "double[pyarrow]",
"Float32": "float[pyarrow]",
"Float64": "double[pyarrow]",
"bool": "boolean[pyarrow]",
"boolean": "boolean[pyarrow]",
Expand All @@ -375,7 +376,7 @@ def roundtripped_pandas():
"int16_limits": "int16[pyarrow]",
"int32_limits": "int32[pyarrow]",
"int64_limits": "int64[pyarrow]",
"float32_limits": "double[pyarrow]",
"float32_limits": "float[pyarrow]",
"float64_limits": "double[pyarrow]",
"oid": "uint32[pyarrow]",
"non-ascii": "large_string[pyarrow]",
Expand Down Expand Up @@ -478,15 +479,9 @@ def empty_like(frame):
raise NotImplementedError("empty_like not implemented for type")

@staticmethod
@nw.narwhalify
def drop_columns(frame, columns):
if isinstance(frame, pd.DataFrame):
return frame.drop(columns=columns, errors="ignore")
elif isinstance(frame, pa.Table):
return frame.drop_columns(columns)
elif isinstance(frame, pl.DataFrame):
return frame.drop(columns, strict=False)
else:
raise NotImplementedError("drop_columns not implemented for type")
return frame.drop(columns)

@staticmethod
def select_columns(frame, columns):
Expand All @@ -500,22 +495,9 @@ def select_columns(frame, columns):
raise NotImplementedError("select_columns not implemented for type")

@staticmethod
@nw.narwhalify
def cast_column_to_type(frame, column, type_):
if isinstance(frame, pd.DataFrame):
frame[column] = frame[column].astype(type_)
return frame
elif isinstance(frame, pa.Table):
schema = pa.schema([pa.field(column, type_)])
return frame.cast(schema)
elif isinstance(frame, pl.DataFrame):
# hacky :-(
if type_ == "int64":
frame = frame.cast({column: pl.Int64()})
elif type_ == "float":
frame = frame.cast({column: pl.Float64()})
return frame
else:
raise NotImplementedError("cast_column_to_type not implemented for type")
return frame.with_columns(nw.col(column).cast(type_))

@staticmethod
def add_non_writeable_column(frame):
Expand Down
12 changes: 9 additions & 3 deletions tests/test_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

def test_basic(frame, roundtripped, tmp_hyper, table_name, table_mode, compat):
return_type, expected = roundtripped
if not (isinstance(frame, pa.Table) and return_type == "pyarrow"):
if isinstance(frame, pa.Table) and return_type != "pyarrow":
frame = compat.drop_columns(frame, ["interval"])

if return_type == "pyarrow" and not isinstance(frame, pa.Table):
expected = compat.drop_columns(expected, ["interval"])

# Write twice; depending on mode this should either overwrite or duplicate entries
Expand Down Expand Up @@ -40,8 +42,10 @@ def test_multiple_tables(
frame, roundtripped, tmp_hyper, table_name, table_mode, compat
):
return_type, expected = roundtripped
if not (isinstance(frame, pa.Table) and return_type == "pyarrow"):
if isinstance(frame, pa.Table) and return_type != "pyarrow":
frame = compat.drop_columns(frame, ["interval"])

if return_type == "pyarrow" and not isinstance(frame, pa.Table):
expected = compat.drop_columns(expected, ["interval"])

# Write twice; depending on mode this should either overwrite or duplicate entries
Expand Down Expand Up @@ -78,8 +82,10 @@ def test_empty_roundtrip(
frame, roundtripped, tmp_hyper, table_name, table_mode, compat
):
return_type, expected = roundtripped
if not (isinstance(frame, pa.Table) and return_type == "pyarrow"):
if isinstance(frame, pa.Table) and return_type != "pyarrow":
frame = compat.drop_columns(frame, ["interval"])

if return_type == "pyarrow" and not isinstance(frame, pa.Table):
expected = compat.drop_columns(expected, ["interval"])

# object case is by definition vague, so let's punt that for now
Expand Down
4 changes: 3 additions & 1 deletion tests/test_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import re

import narwhals as nw
import pandas as pd
import pyarrow as pa
import pytest
Expand All @@ -24,7 +25,8 @@ def test_bad_table_mode_raises(frame, tmp_hyper):


@pytest.mark.parametrize(
"new_dtype,hyper_type_name", [("int64", "BIGINT"), ("float", "DOUBLE PRECISION")]
"new_dtype,hyper_type_name",
[(nw.Int64, "BIGINT"), (nw.Float64, "DOUBLE PRECISION")],
)
def test_append_mode_raises_column_dtype_mismatch(
new_dtype, hyper_type_name, frame, tmp_hyper, table_name, compat
Expand Down

0 comments on commit 676be4c

Please sign in to comment.