Skip to content

Commit

Permalink
add constructor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Jul 21, 2024
1 parent 7446467 commit b576372
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,3 +1624,73 @@ def test_array_construction() -> None:
df = pl.from_dicts(rows, schema=schema)
assert df.schema == schema
assert df.rows() == [("a", [1, 2, 3]), ("b", [2, 3, 4])]


class PyCapsuleStreamHolder:
    """
    Wrap an object and re-export only its Arrow C Stream PyCapsule.

    The wrapper defines `__arrow_c_stream__` and no other interchange dunder, so a
    consumer that accepts it must be going through the Arrow C Stream PyCapsule
    protocol rather than, e.g., the dataframe or array interface.
    """

    # The wrapped object; `__arrow_c_stream__` calls are delegated to it.
    arrow_obj: Any

    def __init__(self, arrow_obj: object) -> None:
        """Store the object whose stream capsule will be forwarded."""
        self.arrow_obj = arrow_obj

    def __arrow_c_stream__(self, requested_schema: object = None) -> object:
        """Delegate to the wrapped object's `__arrow_c_stream__` and return its result."""
        export_stream = self.arrow_obj.__arrow_c_stream__
        return export_stream(requested_schema)


class PyCapsuleArrayHolder:
    """
    Wrap an object and re-export only its Arrow C Array PyCapsule.

    The wrapper defines `__arrow_c_array__` and no other interchange dunder, so a
    consumer that accepts it must be going through the Arrow C Array PyCapsule
    protocol rather than, e.g., the dataframe or array interface.
    """

    # The wrapped object; `__arrow_c_array__` calls are delegated to it.
    arrow_obj: Any

    def __init__(self, arrow_obj: object) -> None:
        """Store the object whose array capsule will be forwarded."""
        self.arrow_obj = arrow_obj

    def __arrow_c_array__(self, requested_schema: object = None) -> object:
        """Delegate to the wrapped object's `__arrow_c_array__` and return its result."""
        export_array = self.arrow_obj.__arrow_c_array__
        return export_array(requested_schema)


def test_pycapsule_interface(df: pl.DataFrame) -> None:
    """
    Round-trip the `df` fixture through the Arrow PyCapsule interface.

    Each pyarrow object is wrapped in a holder exposing only `__arrow_c_array__`
    or `__arrow_c_stream__`, then passed to the `pl.Series` / `pl.DataFrame`
    constructors; the result is compared with the original data.
    """
    # `chunk(0)` and `to_batches()[0]` below are compared against the whole frame,
    # which only holds when the data lives in a single chunk.
    df = df.rechunk()
    pyarrow_table = df.to_arrow()

    # Array via C data interface
    pyarrow_array = pyarrow_table["bools"].chunk(0)
    round_trip_series = pl.Series(PyCapsuleArrayHolder(pyarrow_array))
    # `equals` returns a bool; without `assert` the comparison checks nothing.
    assert df["bools"].equals(round_trip_series, check_dtypes=True, check_names=False)

    # RecordBatch via C array interface
    pyarrow_record_batch = pyarrow_table.to_batches()[0]
    round_trip_df = pl.DataFrame(PyCapsuleArrayHolder(pyarrow_record_batch))
    assert df.equals(round_trip_df)

    # ChunkedArray via C stream interface
    pyarrow_chunked_array = pyarrow_table["bools"]
    round_trip_series = pl.Series(PyCapsuleStreamHolder(pyarrow_chunked_array))
    assert df["bools"].equals(round_trip_series, check_dtypes=True, check_names=False)

    # Table via C stream interface
    round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(pyarrow_table))
    assert df.equals(round_trip_df)

    # TODO: cover an empty Table via the C stream interface (left disabled in the
    # original commit; confirm it round-trips before enabling):
    # empty_df = df[:0].to_arrow()
    # round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(empty_df))

    # RecordBatchReader via C stream interface
    pyarrow_reader = pa.RecordBatchReader.from_batches(
        pyarrow_table.schema, pyarrow_table.to_batches()
    )
    round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(pyarrow_reader))
    assert df.equals(round_trip_df)

0 comments on commit b576372

Please sign in to comment.