Skip to content

Commit

Permalink
fix(python): only raise on actual parameter collision if both "dtypes…
Browse files Browse the repository at this point in the history
…" and "schema_overrides" specified for `read_excel`
  • Loading branch information
alexander-beedie committed Sep 17, 2023
1 parent b08bc7d commit 98bd178
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
4 changes: 4 additions & 0 deletions py-polars/polars/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class NoRowsReturnedError(RowsError):
"""Exception raised when no rows are returned, but at least one row is expected."""


class ParameterCollisionError(RuntimeError):
"""Exception raised when the same parameter occurs multiple times."""


class PolarsInefficientMapWarning(Warning):
"""Warning raised when a potentially slow `apply` operation is performed."""

Expand Down
16 changes: 13 additions & 3 deletions py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import polars._reexport as pl
from polars import functions as F
from polars.datatypes import Date, Datetime
from polars.exceptions import NoDataError
from polars.exceptions import NoDataError, ParameterCollisionError
from polars.io.csv.functions import read_csv
from polars.utils.various import normalize_filepath

Expand Down Expand Up @@ -506,13 +506,23 @@ def _csv_buffer_to_frame(
)
return pl.DataFrame()

if read_csv_options is None:
read_csv_options = {}
if schema_overrides:
if (csv_dtypes := read_csv_options.get("dtypes", {})) and set(
csv_dtypes
).intersection(schema_overrides):
raise ParameterCollisionError(
"Cannot specify columns in both `schema_overrides` and `read_csv_options['dtypes']`"
)
read_csv_options["dtypes"] = {**csv_dtypes, **schema_overrides}

# otherwise rewind the buffer and parse as csv
csv.seek(0)
df = read_csv(
csv,
separator=separator,
dtypes=schema_overrides,
**(read_csv_options or {}),
**read_csv_options,
)
return _drop_unnamed_null_columns(df)

Expand Down
52 changes: 51 additions & 1 deletion py-polars/tests/unit/io/test_spreadsheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import polars as pl
import polars.selectors as cs
from polars.exceptions import NoDataError
from polars.exceptions import NoDataError, ParameterCollisionError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -183,6 +183,56 @@ def test_write_excel_bytes(engine: Literal["xlsx2csv", "openpyxl"]) -> None:
assert_frame_equal(df, df_read)


def test_schema_overrides_11161(excel_file_path: Path) -> None:
df1 = pl.read_excel(
excel_file_path,
sheet_name="test4",
schema_overrides={"cardinality": pl.UInt16},
).drop_nulls()
assert df1.schema == {
"cardinality": pl.UInt16,
"rows_by_key": pl.Float64,
"iter_groups": pl.Float64,
}

df2 = pl.read_excel(
excel_file_path,
sheet_name="test4",
read_csv_options={"dtypes": {"cardinality": pl.UInt16}},
).drop_nulls()
assert df2.schema == {
"cardinality": pl.UInt16,
"rows_by_key": pl.Float64,
"iter_groups": pl.Float64,
}

df3 = pl.read_excel(
excel_file_path,
sheet_name="test4",
schema_overrides={"cardinality": pl.UInt16},
read_csv_options={
"dtypes": {
"rows_by_key": pl.Float32,
"iter_groups": pl.Float32,
},
},
).drop_nulls()
assert df3.schema == {
"cardinality": pl.UInt16,
"rows_by_key": pl.Float32,
"iter_groups": pl.Float32,
}

with pytest.raises(ParameterCollisionError):
# cannot specify 'cardinality' in both schema_overrides and read_csv_options
pl.read_excel(
excel_file_path,
sheet_name="test4",
schema_overrides={"cardinality": pl.UInt16},
read_csv_options={"dtypes": {"cardinality": pl.Int32}},
)


def test_unsupported_engine() -> None:
with pytest.raises(NotImplementedError):
pl.read_excel(None, engine="foo") # type: ignore[call-overload]
Expand Down

0 comments on commit 98bd178

Please sign in to comment.