Skip to content

Commit

Permalink
Support __arrow_c_stream__ in DataFrame constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Jul 18, 2024
1 parent ebba58d commit a0e10be
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
3 changes: 3 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ def __init__(
nan_to_null=nan_to_null,
)

elif hasattr(data, "__arrow_c_stream__"):
self._df = PyDataFrame.from_arrow_c_stream(data)

elif _check_for_pyarrow(data) and isinstance(data, pa.Table):
self._df = arrow_to_pydf(
data, schema=schema, schema_overrides=schema_overrides, strict=strict
Expand Down
93 changes: 93 additions & 0 deletions py-polars/src/dataframe/import.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use polars::export::arrow::ffi::{ArrowArrayStream, ArrowArrayStreamReader};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyType};

use super::*;

/// Validate PyCapsule has provided name
fn validate_pycapsule_name(capsule: &Bound<PyCapsule>, expected_name: &str) -> PyResult<()> {
let capsule_name = capsule.name()?;
if let Some(capsule_name) = capsule_name {
let capsule_name = capsule_name.to_str()?;
if capsule_name != expected_name {
return Err(PyValueError::new_err(format!(
"Expected name '{}' in PyCapsule, instead got '{}'",
expected_name, capsule_name
)));
}
} else {
return Err(PyValueError::new_err(
"Expected schema PyCapsule to have name set.",
));
}

Ok(())
}

/// Import `__arrow_c_stream__` across Python boundary.
fn call_arrow_c_stream<'py>(ob: &'py Bound<PyAny>) -> PyResult<Bound<'py, PyCapsule>> {
if !ob.hasattr("__arrow_c_stream__")? {
return Err(PyValueError::new_err(
"Expected an object with dunder __arrow_c_stream__",
));
}

let capsule = ob.getattr("__arrow_c_stream__")?.call0()?.downcast_into()?;
Ok(capsule)
}

pub(crate) fn import_stream_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<PyDataFrame> {
validate_pycapsule_name(capsule, "arrow_array_stream")?;

// Takes ownership of the pointed to ArrowArrayStream
// This acts to move the data out of the capsule pointer, setting the release callback to NULL
let stream_ptr =
Box::new(unsafe { std::ptr::replace(capsule.pointer() as _, ArrowArrayStream::empty()) });

let mut stream = unsafe {
ArrowArrayStreamReader::try_new(stream_ptr)
.map_err(|err| PyValueError::new_err(err.to_string()))?
};

// For now we'll assume that these are struct arrays to represent record batches
let mut produced_arrays = vec![];
while let Some(array) = unsafe { stream.next() } {
let arr = array.map_err(|err| PyValueError::new_err(err.to_string()))?;
let struct_arr = match arr.data_type() {
ArrowDataType::Struct(_) => arr.as_any().downcast_ref::<StructArray>().unwrap().clone(),
_ => return Err(PyValueError::new_err("Expected struct data type")),
};
produced_arrays.push(struct_arr);
}

let stream_field = stream.field();
// For now we'll assume that these are struct arrays to represent record batches
let struct_fields = match stream_field.data_type() {
ArrowDataType::Struct(struct_fields) => struct_fields,
_ => return Err(PyValueError::new_err("Expected struct data type")),
};

let mut columns: Vec<Series> = vec![];
for (col_idx, column_field) in struct_fields.iter().enumerate() {
let column_chunks = produced_arrays
.iter()
.map(|arr| arr.values()[col_idx].clone())
.collect::<Vec<_>>();
// TODO: remove unwrap
columns.push(Series::try_from((column_field, column_chunks)).unwrap());
}

// TODO: remove unwrap
Ok(PyDataFrame::new(
polars::frame::DataFrame::new(columns).unwrap(),
))
}
#[pymethods]
impl PyDataFrame {
#[classmethod]
pub fn from_arrow_c_stream(_cls: &Bound<PyType>, ob: &Bound<'_, PyAny>) -> PyResult<Self> {
let capsule = call_arrow_c_stream(ob)?;
import_stream_pycapsule(&capsule)
}
}
1 change: 1 addition & 0 deletions py-polars/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod construction;
mod export;
mod general;
mod import;
mod io;
mod serde;

Expand Down

0 comments on commit a0e10be

Please sign in to comment.