diff --git a/py-polars/src/dataframe/export.rs b/py-polars/src/dataframe/export.rs
index 592a52061c6a..f595ac92f02c 100644
--- a/py-polars/src/dataframe/export.rs
+++ b/py-polars/src/dataframe/export.rs
@@ -135,10 +135,11 @@ impl PyDataFrame {
     #[allow(unused_variables)]
     #[pyo3(signature = (requested_schema=None))]
     fn __arrow_c_stream__<'py>(
-        &'py self,
+        &'py mut self,
         py: Python<'py>,
         requested_schema: Option<PyObject>,
     ) -> PyResult<Bound<'py, PyCapsule>> {
-        dataframe_to_stream(&self.df, py)
+        self.df.align_chunks();
+        dataframe_to_stream(self.df.clone(), py)
     }
 }
diff --git a/py-polars/src/interop/arrow/to_py.rs b/py-polars/src/interop/arrow/to_py.rs
index cfa7f6c2398f..5048e399e988 100644
--- a/py-polars/src/interop/arrow/to_py.rs
+++ b/py-polars/src/interop/arrow/to_py.rs
@@ -2,11 +2,12 @@ use std::ffi::CString;
 
 use arrow::ffi;
 use arrow::record_batch::RecordBatch;
-use polars::datatypes::{CompatLevel, DataType, Field};
+use polars::datatypes::CompatLevel;
 use polars::frame::DataFrame;
 use polars::prelude::{ArrayRef, ArrowField};
 use polars::series::Series;
 use polars_core::utils::arrow;
+use polars_error::PolarsResult;
 use pyo3::ffi::Py_uintptr_t;
 use pyo3::prelude::*;
 use pyo3::types::PyCapsule;
@@ -69,25 +70,60 @@ pub(crate) fn series_to_stream<'py>(
     PyCapsule::new_bound(py, stream, Some(stream_capsule_name))
 }
 
-pub(crate) fn dataframe_to_stream<'py>(
-    df: &'py DataFrame,
-    py: Python<'py>,
-) -> PyResult<Bound<'py, PyCapsule>> {
-    let schema_fields = df.schema().iter_fields().collect::<Vec<_>>();
-
-    let struct_field =
-        Field::new("", DataType::Struct(schema_fields)).to_arrow(CompatLevel::oldest());
-    let struct_data_type = struct_field.data_type().clone();
-
-    let iter = df
-        .iter_chunks(CompatLevel::oldest(), false)
-        .into_iter()
-        .map(|chunk| {
-            let arrays = chunk.into_arrays();
-            let x = arrow::array::StructArray::new(struct_data_type.clone(), arrays, None);
-            Ok(Box::new(x) as Box<dyn arrow::array::Array>)
-        });
-    let stream = ffi::export_iterator(Box::new(iter), struct_field);
+pub(crate) fn dataframe_to_stream(df: DataFrame, py: Python) -> PyResult<Bound<PyCapsule>> {
+    let iter = Box::new(DataFrameStreamIterator::new(df));
+    let field = iter.field();
+    let stream = ffi::export_iterator(iter, field);
     let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
     PyCapsule::new_bound(py, stream, Some(stream_capsule_name))
 }
+
+pub struct DataFrameStreamIterator {
+    columns: Vec<Series>,
+    data_type: arrow::datatypes::ArrowDataType,
+    idx: usize,
+    n_chunks: usize,
+}
+
+impl DataFrameStreamIterator {
+    fn new(df: polars::frame::DataFrame) -> Self {
+        let schema = df.schema().to_arrow(CompatLevel::oldest());
+        let data_type = arrow::datatypes::ArrowDataType::Struct(schema.fields);
+
+        Self {
+            columns: df.get_columns().to_vec(),
+            data_type,
+            idx: 0,
+            n_chunks: df.n_chunks(),
+        }
+    }
+
+    fn field(&self) -> ArrowField {
+        ArrowField::new("", self.data_type.clone(), false)
+    }
+}
+
+impl Iterator for DataFrameStreamIterator {
+    type Item = PolarsResult<Box<dyn arrow::array::Array>>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.idx >= self.n_chunks {
+            None
+        } else {
+            // create a batch of the columns with the same chunk no.
+            let batch_cols = self
+                .columns
+                .iter()
+                .map(|s| s.to_arrow(self.idx, CompatLevel::oldest()))
+                .collect();
+            self.idx += 1;
+
+            let array = arrow::array::StructArray::new(
+                self.data_type.clone(),
+                batch_cols,
+                std::option::Option::None,
+            );
+            Some(std::result::Result::Ok(Box::new(array)))
+        }
+    }
+}
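
Why the `align_chunks()` call in the first hunk matters: the new stream iterator walks chunk indices `0..n_chunks` and calls `s.to_arrow(idx, ..)` on every column, which is only safe when all columns share the same chunk layout. A minimal sketch of the mismatch it guards against (not part of the diff; it assumes the polars Rust API at this revision, where `Series::append` keeps the appended data as a separate chunk):

    use polars::prelude::*;

    fn main() -> PolarsResult<()> {
        let a = Series::new("a", &[1i32, 2, 3]);

        // Build "b" from two pieces; append keeps them as separate chunks.
        let mut b = Series::new("b", &[10i32]);
        b.append(&Series::new("b", &[20i32, 30]))?;

        let mut df = DataFrame::new(vec![a, b])?;
        // "a" has 1 chunk while "b" has 2, so indexing chunk 1 of "a"
        // would be out of bounds. align_chunks() rechunks mismatched
        // columns into one shared layout before the iterator is built.
        df.align_chunks();
        Ok(())
    }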
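
The iterator itself snapshots the columns and the Arrow struct dtype up front, then each `next()` packs chunk `idx` of every column into one `StructArray` until `n_chunks` batches have been produced. A hypothetical in-crate test (a sketch only: `DataFrameStreamIterator::new` is private, so this assumes it lives alongside the iterator, and it uses `vstack_mut` to obtain a two-chunk frame):

    use polars::prelude::*;

    #[test]
    fn streams_one_batch_per_chunk() -> PolarsResult<()> {
        let mut df = DataFrame::new(vec![Series::new("a", &[1i32, 2])])?;
        let tail = df.clone();
        df.vstack_mut(&tail)?; // appends chunks without rechunking

        let n = df.n_chunks(); // 2
        let iter = DataFrameStreamIterator::new(df);

        let mut seen = 0;
        for batch in iter {
            let _array = batch?; // Box<dyn arrow::array::Array>, one chunk
            seen += 1;
        }
        assert_eq!(seen, n);
        Ok(())
    }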