Skip to content

Commit

Permalink
add test cases for empty arrays/streams
Browse files: browse the repository at this point in the history
  • Loading branch information
kylebarron committed Jul 21, 2024
1 parent b576372 commit 477ed23
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
8 changes: 7 additions & 1 deletion py-polars/src/series/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,13 @@ pub(crate) fn import_stream_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Py
produced_arrays.push(array.unwrap());
}

let s = Series::try_from((stream.field(), produced_arrays)).unwrap();
// Series::try_from fails for an empty vec of chunks
let s = if produced_arrays.is_empty() {
let polars_dt = DataType::from_arrow(stream.field().data_type(), false);
Series::new_empty(&stream.field().name, &polars_dt)
} else {
Series::try_from((stream.field(), produced_arrays)).unwrap()
};
Ok(PySeries::new(s))
}
#[pymethods]
Expand Down
27 changes: 23 additions & 4 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,7 +1668,12 @@ def test_pycapsule_interface(df: pl.DataFrame) -> None:
# Array via C data interface
pyarrow_array = pyarrow_table["bools"].chunk(0)
round_trip_series = pl.Series(PyCapsuleArrayHolder(pyarrow_array))
df["bools"].equals(round_trip_series, check_dtypes=True, check_names=False)
assert df["bools"].equals(round_trip_series, check_dtypes=True, check_names=False)

# empty Array via C data interface
empty_pyarrow_array = pa.array([], type=pyarrow_array.type)
round_trip_series = pl.Series(PyCapsuleArrayHolder(empty_pyarrow_array))
assert df["bools"].dtype == round_trip_series.dtype

# RecordBatch via C array interface
pyarrow_record_batch = pyarrow_table.to_batches()[0]
Expand All @@ -1678,15 +1683,29 @@ def test_pycapsule_interface(df: pl.DataFrame) -> None:
# ChunkedArray via C stream interface
pyarrow_chunked_array = pyarrow_table["bools"]
round_trip_series = pl.Series(PyCapsuleStreamHolder(pyarrow_chunked_array))
df["bools"].equals(round_trip_series, check_dtypes=True, check_names=False)
assert df["bools"].equals(round_trip_series, check_dtypes=True, check_names=False)

# empty ChunkedArray via C stream interface
empty_chunked_array = pa.chunked_array([], type=pyarrow_chunked_array.type)
round_trip_series = pl.Series(PyCapsuleStreamHolder(empty_chunked_array))
assert df["bools"].dtype == round_trip_series.dtype

# Table via C stream interface
round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(pyarrow_table))
assert df.equals(round_trip_df)

# empty Table via C stream interface
# empty_df = df[:0].to_arrow()
# round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(empty_df))
empty_df = df[:0].to_arrow()
round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(empty_df))
orig_schema = df.schema
round_trip_schema = round_trip_df.schema

# The "enum" column's dtype is not preserved because its categories are lost
# via the C data interface
orig_schema.pop("enum")
round_trip_schema.pop("enum")

assert orig_schema == round_trip_schema

# RecordBatchReader via C stream interface
pyarrow_reader = pa.RecordBatchReader.from_batches(
Expand Down

0 comments on commit 477ed23

Please sign in to comment.