# Origin: hunk appended to py-polars/tests/unit/constructors/test_constructors.py,
# directly after `test_array_construction`.


class PyCapsuleStreamHolder:
    """
    Hold the Arrow C Stream pycapsule.

    A class that exposes _only_ the Arrow C Stream interface via Arrow PyCapsules.
    This ensures that the consumer is seeing _only_ the `__arrow_c_stream__` dunder,
    and that nothing else (e.g. the dataframe or array interface) is actually being
    used.

    Parameters
    ----------
    arrow_obj
        Object whose `__arrow_c_stream__` method is forwarded to.
    """

    # NOTE(review): presumably annotated `Any` (not `object`) so the forwarded
    # dunder call below type-checks - confirm against the project's mypy config.
    arrow_obj: Any

    def __init__(self, arrow_obj: object) -> None:
        self.arrow_obj = arrow_obj

    def __arrow_c_stream__(self, requested_schema: object = None) -> object:
        """Forward the C Stream export to the wrapped object, unchanged."""
        return self.arrow_obj.__arrow_c_stream__(requested_schema)


class PyCapsuleArrayHolder:
    """
    Hold the Arrow C Array pycapsule.

    A class that exposes _only_ the Arrow C Array interface via Arrow PyCapsules.
    This ensures that the consumer is seeing _only_ the `__arrow_c_array__` dunder,
    and that nothing else (e.g. the dataframe or array interface) is actually being
    used.

    Parameters
    ----------
    arrow_obj
        Object whose `__arrow_c_array__` method is forwarded to.
    """

    # NOTE(review): presumably annotated `Any` (not `object`) so the forwarded
    # dunder call below type-checks - confirm against the project's mypy config.
    arrow_obj: Any

    def __init__(self, arrow_obj: object) -> None:
        self.arrow_obj = arrow_obj

    def __arrow_c_array__(self, requested_schema: object = None) -> object:
        """Forward the C Array export to the wrapped object, unchanged."""
        return self.arrow_obj.__arrow_c_array__(requested_schema)


def test_pycapsule_interface(df: pl.DataFrame) -> None:
    """
    Round-trip pyarrow objects into Polars through the PyCapsule dunders only.

    Every pyarrow object is wrapped in a holder that exposes a single Arrow
    PyCapsule dunder, so each constructor call below can only succeed by going
    through that dunder.

    Parameters
    ----------
    df
        Frame supplied by the test suite (presumably a pytest fixture - confirm in
        conftest); it must contain a "bools" column.
    """
    # Collapse to a single chunk so that `.chunk(0)` / `to_batches()[0]` below
    # cover the whole frame rather than only its first chunk.
    df = df.rechunk()
    pyarrow_table = df.to_arrow()
    expected_bools = df["bools"]

    # Array via C data interface
    pyarrow_array = pyarrow_table["bools"].chunk(0)
    round_trip_series = pl.Series(PyCapsuleArrayHolder(pyarrow_array))
    # The result of `equals` must be asserted; a bare call checks nothing.
    assert expected_bools.equals(
        round_trip_series, check_dtypes=True, check_names=False
    )

    # RecordBatch via C array interface
    pyarrow_record_batch = pyarrow_table.to_batches()[0]
    round_trip_df = pl.DataFrame(PyCapsuleArrayHolder(pyarrow_record_batch))
    assert df.equals(round_trip_df)

    # ChunkedArray via C stream interface
    pyarrow_chunked_array = pyarrow_table["bools"]
    round_trip_series = pl.Series(PyCapsuleStreamHolder(pyarrow_chunked_array))
    assert expected_bools.equals(
        round_trip_series, check_dtypes=True, check_names=False
    )

    # Table via C stream interface
    round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(pyarrow_table))
    assert df.equals(round_trip_df)

    # TODO: cover an empty Table via the C stream interface. This case was left
    # disabled in the original test (reason not recorded here):
    #     empty_df = df[:0].to_arrow()
    #     round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(empty_df))

    # RecordBatchReader via C stream interface
    pyarrow_reader = pa.RecordBatchReader.from_batches(
        pyarrow_table.schema, pyarrow_table.to_batches()
    )
    round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(pyarrow_reader))
    assert df.equals(round_trip_df)