From 36833440ac1b86b25ea670b7a18f5b65b0c232f1 Mon Sep 17 00:00:00 2001
From: Stijn de Gooijer
Date: Fri, 4 Aug 2023 10:49:56 +0200
Subject: [PATCH] fix(python): Fix interchange protocol allowing copy even when `allow_copy` was set to False (#10262)

---
 py-polars/polars/convert.py              | 47 ++++++++++++++++++++----
 py-polars/polars/dataframe/frame.py      | 20 +++++++---
 py-polars/polars/datatypes/convert.py    |  7 ++--
 py-polars/tests/unit/test_interchange.py | 22 ++++++++++-
 4 files changed, 77 insertions(+), 19 deletions(-)

diff --git a/py-polars/polars/convert.py b/py-polars/polars/convert.py
index 47aca8f8caef..2c72dc94a7e1 100644
--- a/py-polars/polars/convert.py
+++ b/py-polars/polars/convert.py
@@ -746,20 +746,29 @@ def from_dataframe(df: Any, *, allow_copy: bool = True) -> DataFrame:
     Details on the dataframe interchange protocol:
     https://data-apis.org/dataframe-protocol/latest/index.html
 
-    Zero-copy conversions currently cannot be guaranteed and will throw a
-    ``RuntimeError``.
-
     Using a dedicated function like :func:`from_pandas` or :func:`from_arrow` is a
     more efficient method of conversion.
 
+    Polars currently relies on pyarrow's implementation of the dataframe interchange
+    protocol. Therefore, pyarrow>=11.0.0 is required for this function to work.
+
+    Because Polars can not currently guarantee zero-copy conversion from Arrow for
+    categorical columns, ``allow_copy=False`` will not work if the dataframe contains
+    categorical data.
+
     """
     if isinstance(df, pl.DataFrame):
         return df
     if not hasattr(df, "__dataframe__"):
         raise TypeError(
-            f"`df` of type {type(df)} does not support the dataframe interchange"
-            " protocol."
+            f"`df` of type {type(df)} does not support the dataframe interchange protocol."
         )
+
+    pa_table = _df_to_pyarrow_table(df, allow_copy=allow_copy)
+    return from_arrow(pa_table, rechunk=allow_copy)  # type: ignore[return-value]
+
+
+def _df_to_pyarrow_table(df: Any, *, allow_copy: bool = False) -> pa.Table:
     if not _PYARROW_AVAILABLE or parse_version(pa.__version__) < parse_version("11"):
         raise ImportError(
             "pyarrow>=11.0.0 is required for converting a dataframe interchange object"
@@ -768,5 +777,29 @@ def from_dataframe(df: Any, *, allow_copy: bool = True) -> DataFrame:
 
     import pyarrow.interchange  # noqa: F401
 
-    pa_table = pa.interchange.from_dataframe(df, allow_copy=allow_copy)
-    return from_arrow(pa_table, rechunk=allow_copy)  # type: ignore[return-value]
+    if not allow_copy:
+        return _df_to_pyarrow_table_zero_copy(df)
+
+    return pa.interchange.from_dataframe(df, allow_copy=True)
+
+
+def _df_to_pyarrow_table_zero_copy(df: Any) -> pa.Table:
+    dfi = df.__dataframe__(allow_copy=False)
+    if _dfi_contains_categorical_data(dfi):
+        raise TypeError(
+            "Polars can not currently guarantee zero-copy conversion from Arrow for"
+            " categorical columns. Set `allow_copy=True` or cast categorical columns to"
+            " string first."
+        )
+
+    if isinstance(df, pa.Table):
+        return df
+    elif isinstance(df, pa.RecordBatch):
+        return pa.Table.from_batches([df])
+    else:
+        return pa.interchange.from_dataframe(dfi, allow_copy=False)
+
+
+def _dfi_contains_categorical_data(dfi: Any) -> bool:
+    CATEGORICAL_DTYPE = 23
+    return any(c.dtype[0] == CATEGORICAL_DTYPE for c in dfi.get_columns())
diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py
index 9812138a7ff9..ab9a21278cb6 100644
--- a/py-polars/polars/dataframe/frame.py
+++ b/py-polars/polars/dataframe/frame.py
@@ -47,6 +47,7 @@
     Time,
     Utf8,
     py_type_to_dtype,
+    unpack_dtypes,
 )
 from polars.dependencies import (
     _PYARROW_AVAILABLE,
@@ -1217,9 +1218,16 @@ def __dataframe__(
         Details on the dataframe interchange protocol:
         https://data-apis.org/dataframe-protocol/latest/index.html
 
-        `nan_as_null` currently has no effect; once support for nullable extension
+        ``nan_as_null`` currently has no effect; once support for nullable extension
         dtypes is added, this value should be propagated to columns.
 
+        Polars currently relies on pyarrow's implementation of the dataframe interchange
+        protocol. Therefore, pyarrow>=11.0.0 is required for this method to work.
+
+        Because Polars can not currently guarantee zero-copy conversion to Arrow for
+        categorical columns, ``allow_copy=False`` will not work if the dataframe
+        contains categorical data.
+
         """
         if not _PYARROW_AVAILABLE or parse_version(pa.__version__) < parse_version(
             "11"
@@ -1228,11 +1236,11 @@ def __dataframe__(
                 "pyarrow>=11.0.0 is required for converting a Polars dataframe to a"
                 " dataframe interchange object."
             )
-        if not allow_copy and Categorical in self.schema.values():
-            raise NotImplementedError(
-                "Polars does not offer zero-copy conversion to Arrow for categorical"
-                " columns. Set `allow_copy=True` or cast categorical columns to"
-                " string first."
+        if not allow_copy and Categorical in unpack_dtypes(*self.dtypes):
+            raise TypeError(
+                "Polars can not currently guarantee zero-copy conversion to Arrow for"
+                " categorical columns. Set `allow_copy=True` or cast categorical"
+                " columns to string first."
             )
 
         return self.to_arrow().__dataframe__(nan_as_null, allow_copy)
diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py
index 9b4f1a2a985c..b754de07ae7b 100644
--- a/py-polars/polars/datatypes/convert.py
+++ b/py-polars/polars/datatypes/convert.py
@@ -171,10 +171,9 @@ def unpack_dtypes(
 
     Parameters
     ----------
-    *dtypes : PolarsDataType | Collection[PolarsDataType] | None
-        one or more polars dtypes.
-
-    include_compound : bool, default True
+    *dtypes
+        One or more Polars dtypes.
+    include_compound
         * if True, any parent/compound dtypes (List, Struct) are included in the result.
         * if False, only the child/scalar dtypes are returned from these types.
 
diff --git a/py-polars/tests/unit/test_interchange.py b/py-polars/tests/unit/test_interchange.py
index 06bfb754825b..54acf3bc64b5 100644
--- a/py-polars/tests/unit/test_interchange.py
+++ b/py-polars/tests/unit/test_interchange.py
@@ -45,17 +45,35 @@ def test_interchange_categorical() -> None:
     assert dfi.get_column_by_name("a").dtype[0] == 23  # 23 signifies categorical dtype
 
     # If copy not allowed, throws an error
-    with pytest.raises(NotImplementedError, match="categorical"):
+    with pytest.raises(TypeError, match="categorical"):
         df.__dataframe__(allow_copy=False)
 
 
-def test_from_dataframe() -> None:
+def test_interchange_nested_categorical() -> None:
+    df = pl.DataFrame(
+        {"a": [1, 2], "b": ["a", "b"], "c": [["q"], ["x"]]},
+        schema_overrides={"c": pl.List(pl.Categorical)},
+    )
+
+    with pytest.raises(TypeError, match="categorical"):
+        df.__dataframe__(allow_copy=False)
+
+
+def test_from_dataframe_polars() -> None:
     df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]})
     dfi = df.__dataframe__()
     result = pl.from_dataframe(dfi)
     assert_frame_equal(result, df)
 
 
+def test_from_dataframe_categorical_zero_copy() -> None:
+    df = pl.DataFrame({"a": ["foo", "bar"]}, schema={"a": pl.Categorical})
+    df_pa = df.to_arrow()
+
+    with pytest.raises(TypeError):
+        pl.from_dataframe(df_pa, allow_copy=False)
+
+
 def test_from_dataframe_pandas() -> None:
     data = {"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]}
 