Skip to content

Commit

Permalink
fix(python): Fix interchange protocol allowing copy even when `allow_…
Browse files Browse the repository at this point in the history
…copy` was set to False (#10262)
  • Loading branch information
stinodego authored Aug 4, 2023
1 parent 1bbe101 commit 3683344
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 19 deletions.
47 changes: 40 additions & 7 deletions py-polars/polars/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,20 +746,29 @@ def from_dataframe(df: Any, *, allow_copy: bool = True) -> DataFrame:
Details on the dataframe interchange protocol:
https://data-apis.org/dataframe-protocol/latest/index.html
Zero-copy conversions currently cannot be guaranteed and will throw a
``RuntimeError``.
Using a dedicated function like :func:`from_pandas` or :func:`from_arrow` is a more
efficient method of conversion.
Polars currently relies on pyarrow's implementation of the dataframe interchange
protocol. Therefore, pyarrow>=11.0.0 is required for this function to work.
Because Polars can not currently guarantee zero-copy conversion from Arrow for
categorical columns, ``allow_copy=False`` will not work if the dataframe contains
categorical data.
"""
if isinstance(df, pl.DataFrame):
return df
if not hasattr(df, "__dataframe__"):
raise TypeError(
f"`df` of type {type(df)} does not support the dataframe interchange"
" protocol."
f"`df` of type {type(df)} does not support the dataframe interchange protocol."
)

pa_table = _df_to_pyarrow_table(df, allow_copy=allow_copy)
return from_arrow(pa_table, rechunk=allow_copy) # type: ignore[return-value]


def _df_to_pyarrow_table(df: Any, *, allow_copy: bool = False) -> pa.Table:
if not _PYARROW_AVAILABLE or parse_version(pa.__version__) < parse_version("11"):
raise ImportError(
"pyarrow>=11.0.0 is required for converting a dataframe interchange object"
Expand All @@ -768,5 +777,29 @@ def from_dataframe(df: Any, *, allow_copy: bool = True) -> DataFrame:

import pyarrow.interchange # noqa: F401

pa_table = pa.interchange.from_dataframe(df, allow_copy=allow_copy)
return from_arrow(pa_table, rechunk=allow_copy) # type: ignore[return-value]
if not allow_copy:
return _df_to_pyarrow_table_zero_copy(df)

return pa.interchange.from_dataframe(df, allow_copy=True)


def _df_to_pyarrow_table_zero_copy(df: Any) -> pa.Table:
    """
    Convert an interchange-capable object to a pyarrow Table without copying.

    Parameters
    ----------
    df
        Object exposing ``__dataframe__``. It is called with
        ``allow_copy=False`` to obtain the interchange object.

    Returns
    -------
    pa.Table
        ``df`` itself when it already is a ``pa.Table``, a table wrapping ``df``
        when it is a ``pa.RecordBatch``, and otherwise the result of
        ``pa.interchange.from_dataframe`` with ``allow_copy=False``.

    Raises
    ------
    TypeError
        If the interchange object contains categorical columns.

    """
    dfi = df.__dataframe__(allow_copy=False)

    # Reject categorical data before any conversion is attempted.
    if _dfi_contains_categorical_data(dfi):
        raise TypeError(
            "Polars can not currently guarantee zero-copy conversion from Arrow for"
            " categorical columns. Set `allow_copy=True` or cast categorical"
            " columns to string first."
        )

    # pyarrow inputs bypass the interchange conversion entirely.
    # NOTE(review): presumably because they are already Arrow data - confirm.
    if isinstance(df, pa.Table):
        return df
    if isinstance(df, pa.RecordBatch):
        return pa.Table.from_batches([df])
    return pa.interchange.from_dataframe(dfi, allow_copy=False)


def _dfi_contains_categorical_data(dfi: Any) -> bool:
    """
    Report whether any column of an interchange object is categorical.

    The first element of each column's ``dtype`` is compared against the
    kind code 23, which signifies the categorical dtype.
    """
    categorical_kind = 23
    for column in dfi.get_columns():
        if column.dtype[0] == categorical_kind:
            return True
    return False
20 changes: 14 additions & 6 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
Time,
Utf8,
py_type_to_dtype,
unpack_dtypes,
)
from polars.dependencies import (
_PYARROW_AVAILABLE,
Expand Down Expand Up @@ -1217,9 +1218,16 @@ def __dataframe__(
Details on the dataframe interchange protocol:
https://data-apis.org/dataframe-protocol/latest/index.html
`nan_as_null` currently has no effect; once support for nullable extension
``nan_as_null`` currently has no effect; once support for nullable extension
dtypes is added, this value should be propagated to columns.
Polars currently relies on pyarrow's implementation of the dataframe interchange
protocol. Therefore, pyarrow>=11.0.0 is required for this method to work.
Because Polars can not currently guarantee zero-copy conversion to Arrow for
categorical columns, ``allow_copy=False`` will not work if the dataframe
contains categorical data.
"""
if not _PYARROW_AVAILABLE or parse_version(pa.__version__) < parse_version(
"11"
Expand All @@ -1228,11 +1236,11 @@ def __dataframe__(
"pyarrow>=11.0.0 is required for converting a Polars dataframe to a"
" dataframe interchange object."
)
if not allow_copy and Categorical in self.schema.values():
raise NotImplementedError(
"Polars does not offer zero-copy conversion to Arrow for categorical"
" columns. Set `allow_copy=True` or cast categorical columns to"
" string first."
if not allow_copy and Categorical in unpack_dtypes(*self.dtypes):
raise TypeError(
"Polars can not currently guarantee zero-copy conversion to Arrow for"
" categorical columns. Set `allow_copy=True` or cast categorical"
" columns to string first."
)
return self.to_arrow().__dataframe__(nan_as_null, allow_copy)

Expand Down
7 changes: 3 additions & 4 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,9 @@ def unpack_dtypes(
Parameters
----------
*dtypes : PolarsDataType | Collection[PolarsDataType] | None
one or more polars dtypes.
include_compound : bool, default True
*dtypes
One or more Polars dtypes.
include_compound
* if True, any parent/compound dtypes (List, Struct) are included in the result.
* if False, only the child/scalar dtypes are returned from these types.
Expand Down
22 changes: 20 additions & 2 deletions py-polars/tests/unit/test_interchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,35 @@ def test_interchange_categorical() -> None:
assert dfi.get_column_by_name("a").dtype[0] == 23 # 23 signifies categorical dtype

# If copy not allowed, throws an error
with pytest.raises(NotImplementedError, match="categorical"):
with pytest.raises(TypeError, match="categorical"):
df.__dataframe__(allow_copy=False)


def test_from_dataframe() -> None:
def test_interchange_nested_categorical() -> None:
    """Zero-copy export must fail when a categorical is nested inside a list."""
    data = {"a": [1, 2], "b": ["a", "b"], "c": [["q"], ["x"]]}
    nested_categorical = pl.List(pl.Categorical)
    df = pl.DataFrame(data, schema_overrides={"c": nested_categorical})

    with pytest.raises(TypeError, match="categorical"):
        df.__dataframe__(allow_copy=False)


def test_from_dataframe_polars() -> None:
    """Round-trip a Polars frame through its own interchange object."""
    expected = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]})
    interchange_object = expected.__dataframe__()
    assert_frame_equal(pl.from_dataframe(interchange_object), expected)


def test_from_dataframe_categorical_zero_copy() -> None:
    """Importing an Arrow table with categorical data must fail without copy."""
    categorical_frame = pl.DataFrame(
        {"a": ["foo", "bar"]}, schema={"a": pl.Categorical}
    )
    arrow_table = categorical_frame.to_arrow()

    with pytest.raises(TypeError):
        pl.from_dataframe(arrow_table, allow_copy=False)


def test_from_dataframe_pandas() -> None:
data = {"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]}

Expand Down

0 comments on commit 3683344

Please sign in to comment.