From 98bd1780526654657be6316fcb821c167ddec8e1 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 17 Sep 2023 16:24:46 +0400 Subject: [PATCH] fix(python): only raise on actual parameter collision if both "dtypes" and "schema_overrides" specified for `read_excel` --- py-polars/polars/exceptions.py | 4 ++ py-polars/polars/io/spreadsheet/functions.py | 16 ++++-- py-polars/tests/unit/io/test_spreadsheet.py | 52 +++++++++++++++++++- 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/py-polars/polars/exceptions.py b/py-polars/polars/exceptions.py index 716d67247581..963c5379cbb5 100644 --- a/py-polars/polars/exceptions.py +++ b/py-polars/polars/exceptions.py @@ -82,6 +82,10 @@ class NoRowsReturnedError(RowsError): """Exception raised when no rows are returned, but at least one row is expected.""" +class ParameterCollisionError(RuntimeError): + """Exception raised when the same parameter occurs multiple times.""" + + class PolarsInefficientMapWarning(Warning): """Warning raised when a potentially slow `apply` operation is performed.""" diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 9103902bc828..ecbf01a4d9c2 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -8,7 +8,7 @@ import polars._reexport as pl from polars import functions as F from polars.datatypes import Date, Datetime -from polars.exceptions import NoDataError +from polars.exceptions import NoDataError, ParameterCollisionError from polars.io.csv.functions import read_csv from polars.utils.various import normalize_filepath @@ -506,13 +506,23 @@ def _csv_buffer_to_frame( ) return pl.DataFrame() + if read_csv_options is None: + read_csv_options = {} + if schema_overrides: + if (csv_dtypes := read_csv_options.get("dtypes", {})) and set( + csv_dtypes + ).intersection(schema_overrides): + raise ParameterCollisionError( + "Cannot specify columns in both `schema_overrides` and `read_csv_options['dtypes']`" + ) + read_csv_options["dtypes"] = {**csv_dtypes, **schema_overrides} + # otherwise rewind the buffer and parse as csv csv.seek(0) df = read_csv( csv, separator=separator, - dtypes=schema_overrides, - **(read_csv_options or {}), + **read_csv_options, ) return _drop_unnamed_null_columns(df) diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index dd8bb5a62199..9f872f678f73 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -9,7 +9,7 @@ import polars as pl import polars.selectors as cs -from polars.exceptions import NoDataError +from polars.exceptions import NoDataError, ParameterCollisionError from polars.testing import assert_frame_equal if TYPE_CHECKING: @@ -183,6 +183,56 @@ def test_write_excel_bytes(engine: Literal["xlsx2csv", "openpyxl"]) -> None: assert_frame_equal(df, df_read) +def test_schema_overrides_11161(excel_file_path: Path) -> None: + df1 = pl.read_excel( + excel_file_path, + sheet_name="test4", + schema_overrides={"cardinality": pl.UInt16}, + ).drop_nulls() + assert df1.schema == { + "cardinality": pl.UInt16, + "rows_by_key": pl.Float64, + "iter_groups": pl.Float64, + } + + df2 = pl.read_excel( + excel_file_path, + sheet_name="test4", + read_csv_options={"dtypes": {"cardinality": pl.UInt16}}, + ).drop_nulls() + assert df2.schema == { + "cardinality": pl.UInt16, + "rows_by_key": pl.Float64, + "iter_groups": pl.Float64, + } + + df3 = pl.read_excel( + excel_file_path, + sheet_name="test4", + schema_overrides={"cardinality": pl.UInt16}, + read_csv_options={ + "dtypes": { + "rows_by_key": pl.Float32, + "iter_groups": pl.Float32, + }, + }, + ).drop_nulls() + assert df3.schema == { + "cardinality": pl.UInt16, + "rows_by_key": pl.Float32, + "iter_groups": pl.Float32, + } + + with pytest.raises(ParameterCollisionError): + # cannot specify 'cardinality' in both schema_overrides and read_csv_options + pl.read_excel( + excel_file_path, + sheet_name="test4", + schema_overrides={"cardinality": pl.UInt16}, + read_csv_options={"dtypes": {"cardinality": pl.Int32}}, + ) + + def test_unsupported_engine() -> None: with pytest.raises(NotImplementedError): pl.read_excel(None, engine="foo") # type: ignore[call-overload]