From ab6229d549a7d32a09396876124f940842526f18 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 21 Oct 2023 00:42:52 +0400 Subject: [PATCH] fix(python): address issue with inadvertently shared options dict in `read_excel` --- py-polars/polars/io/spreadsheet/functions.py | 13 +++++++------ py-polars/tests/unit/io/test_spreadsheet.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 198b69404290..e835a94ad193 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -28,7 +28,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -43,7 +43,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -58,7 +58,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> NoReturn: ... @@ -75,7 +75,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -90,7 +90,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -105,7 +105,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -548,6 +548,7 @@ def _csv_buffer_to_frame( raise ParameterCollisionError( "cannot specify columns in both `schema_overrides` and `read_csv_options['dtypes']`" ) + read_csv_options = read_csv_options.copy() read_csv_options["dtypes"] = {**csv_dtypes, **schema_overrides} # otherwise rewind the buffer and parse as csv diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index a777e8af318b..5101c842c833 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from collections import OrderedDict from datetime import date, datetime from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Literal @@ -275,6 +276,22 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N read_csv_options={"dtypes": {"cardinality": pl.Int32}}, ) + # read multiple sheets in conjunction with 'schema_overrides' + # (note: reading the same sheet twice simulates the issue in #11850) + overrides = OrderedDict( + [ + ("cardinality", pl.UInt32), + ("rows_by_key", pl.Float32), + ("iter_groups", pl.Float64), + ] + ) + df = pl.read_excel( # type: ignore[call-overload] + path_xlsx, + sheet_name=["test4", "test4"], + schema_overrides=overrides, + ) + assert df["test4"].schema == overrides + def test_unsupported_engine() -> None: with pytest.raises(NotImplementedError):