From c35304e7e0d9dfdbd9983bd4d8553b9f2d91adeb Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Thu, 7 Nov 2024 09:21:30 +1100 Subject: [PATCH] refactor(rust): Refactor compute kernels in polars-arrow to avoid using gather --- crates/polars-arrow/src/compute/cast/mod.rs | 55 +++--- .../src/legacy/kernels/fixed_size_list.rs | 158 +++++++++++------ .../polars-arrow/src/legacy/kernels/list.rs | 167 +++++------------- 3 files changed, 179 insertions(+), 201 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index f34d9ebba2a5..e3abcb22dd11 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -16,8 +16,8 @@ pub use binview_to::utf8view_to_utf8; pub use boolean_to::*; pub use decimal_to::*; use dictionary_to::*; +use growable::make_growable; use polars_error::{polars_bail, polars_ensure, polars_err, PolarsResult}; -use polars_utils::IdxSize; pub use primitive_to::*; pub use utf8_to::*; @@ -170,9 +170,13 @@ fn cast_fixed_size_list_to_list( fn cast_list_to_fixed_size_list( list: &ListArray, inner: &Field, - size: usize, + size: usize, // width options: CastOptionsImpl, -) -> PolarsResult { +) -> PolarsResult +where + ListArray: crate::array::StaticArray + + ArrayFromIter>>, +{ let null_cnt = list.null_count(); let new_values = if null_cnt == 0 { let start_offset = list.offsets().first().to_usize(); @@ -190,7 +194,8 @@ fn cast_list_to_fixed_size_list( .sliced(start_offset, list.offsets().range().to_usize()); cast(sliced_values.as_ref(), inner.dtype(), options)? } else { - let offsets = list.offsets().as_slice(); + let offsets = list.offsets(); + // Check the lengths of each list are equal to the fixed size. // SAFETY: we know the index is in bound. let mut expected_offset = unsafe { *offsets.get_unchecked(0) } + O::from_as_usize(size); @@ -206,27 +211,33 @@ fn cast_list_to_fixed_size_list( } } - // Build take indices for the values. This is used to fill in the null slots. - let mut indices = - MutablePrimitiveArray::::with_capacity(list.values().len() + null_cnt * size); - for i in 0..list.len() { - if list.is_null(i) { - indices.extend_constant(size, None) - } else { - // SAFETY: we know the index is in bound. - let current_offset = unsafe { *offsets.get_unchecked(i) }; - for j in 0..size { - indices.push(Some( - (current_offset + O::from_as_usize(j)).to_usize() as IdxSize - )); + let list_validity = list.validity().unwrap(); + let mut growable = make_growable(&[list.values().as_ref()], true, list.len()); + + if cfg!(debug_assertions) { + let msg = "fn cast_list_to_fixed_size_list < nullable >"; + dbg!(msg); + } + + for (outer_idx, x) in offsets.windows(2).enumerate() { + let [i, j] = x else { unreachable!() }; + let i = i.to_usize(); + let j = j.to_usize(); + + unsafe { + let outer_is_valid = list_validity.get_bit_unchecked(outer_idx); + + if outer_is_valid { + growable.extend(0, i, j - i); + } else { + growable.extend_validity(size) } - } + }; } - let take_values = unsafe { - crate::compute::take::take_unchecked(list.values().as_ref(), &indices.freeze()) - }; - cast(take_values.as_ref(), inner.dtype(), options)? + let values = growable.as_box(); + + cast(values.as_ref(), inner.dtype(), options)? }; FixedSizeListArray::try_new( diff --git a/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs b/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs index 641442f1a9e5..98345ae13474 100644 --- a/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs +++ b/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs @@ -1,59 +1,47 @@ -use polars_error::{polars_bail, PolarsResult}; -use polars_utils::index::NullCount; -use polars_utils::IdxSize; - -use crate::array::{ArrayRef, FixedSizeListArray, PrimitiveArray}; -use crate::compute::take::take_unchecked; -use crate::legacy::prelude::*; -use crate::legacy::utils::CustomIterTools; - -fn sub_fixed_size_list_get_indexes_literal(width: usize, len: usize, index: i64) -> IdxArr { - (0..len) - .map(|i| { - if index >= width as i64 { - return None; - } - - index - .negative_to_usize(width) - .map(|idx| (idx + i * width) as IdxSize) - }) - .collect_trusted() -} - -fn sub_fixed_size_list_get_indexes(width: usize, index: &PrimitiveArray) -> IdxArr { - index - .iter() - .enumerate() - .map(|(i, idx)| { - if let Some(idx) = idx { - if *idx >= width as i64 { - return None; - } +use polars_error::{polars_bail, PolarsError, PolarsResult}; - idx.negative_to_usize(width) - .map(|idx| (idx + i * width) as IdxSize) - } else { - None - } - }) - .collect_trusted() -} +use crate::array::growable::make_growable; +use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray}; +use crate::bitmap::BitmapBuilder; +use crate::compute::utils::combine_validities_and; +use crate::datatypes::ArrowDataType; pub fn sub_fixed_size_list_get_literal( arr: &FixedSizeListArray, index: i64, null_on_oob: bool, ) -> PolarsResult { - let take_by = sub_fixed_size_list_get_indexes_literal(arr.size(), arr.len(), index); - if !null_on_oob && take_by.null_count() > 0 { - polars_bail!(ComputeError: "get index is out of bounds"); + if cfg!(debug_assertions) { + let msg = "fn sub_fixed_size_list_get_literal"; + dbg!(msg); + } + + let ArrowDataType::FixedSizeList(_, width) = arr.dtype() else { + unreachable!(); + }; + + let width = *width; + + let index = usize::try_from(index).unwrap(); + + if !null_on_oob && index >= width { + polars_bail!( + ComputeError: + "get index {} is out of bounds for array(width={})", + index, + width + ); } let values = arr.values(); - // SAFETY: - // the indices we generate are in bounds - unsafe { Ok(take_unchecked(&**values, &take_by)) } + + let mut growable = make_growable(&[values.as_ref()], values.validity().is_some(), arr.len()); + + for i in 0..arr.len() { + unsafe { growable.extend(0, i * width + index, 1) } + } + + Ok(growable.as_box()) } pub fn sub_fixed_size_list_get( @@ -61,13 +49,79 @@ pub fn sub_fixed_size_list_get( index: &PrimitiveArray, null_on_oob: bool, ) -> PolarsResult { - let take_by = sub_fixed_size_list_get_indexes(arr.size(), index); - if !null_on_oob && take_by.null_count() > 0 { - polars_bail!(ComputeError: "get index is out of bounds"); + if cfg!(debug_assertions) { + let msg = "fn sub_fixed_size_list_get"; + dbg!(msg); + } + + fn idx_oob_err(index: i64, width: usize) -> PolarsError { + PolarsError::ComputeError( + format!( + "get index {} is out of bounds for array(width={})", + index, width + ) + .into(), + ) } + let ArrowDataType::FixedSizeList(_, width) = arr.dtype() else { + unreachable!(); + }; + + let width = *width; + + if arr.is_empty() { + if !null_on_oob { + if let Some(i) = index.non_null_values_iter().max() { + if usize::try_from(i).unwrap() >= width { + return Err(idx_oob_err(i, width)); + } + } + } + + let values = arr.values(); + assert!(values.is_empty()); + return Ok(values.clone()); + } + + if !null_on_oob && width == 0 { + if let Some(i) = index.non_null_values_iter().next() { + return Err(idx_oob_err(i, width)); + } + } + + // Array is non-empty and has non-zero width at this point let values = arr.values(); - // SAFETY: - // the indices we generate are in bounds - unsafe { Ok(take_unchecked(&**values, &take_by)) } + + let mut growable = make_growable(&[values.as_ref()], values.validity().is_some(), arr.len()); + let mut output_validity = BitmapBuilder::with_capacity(arr.len()); + let opt_index_validity = index.validity(); + let mut exceeded_width_idx = 0; + + for i in 0..arr.len() { + let idx = usize::try_from(index.value(i)).unwrap(); + let idx_is_oob = idx >= width; + let idx_is_valid = opt_index_validity.map_or(true, |x| unsafe { x.get_bit_unchecked(i) }); + + if idx_is_oob && idx_is_valid && exceeded_width_idx < width { + exceeded_width_idx = idx; + } + + let idx = if idx_is_oob { 0 } else { idx }; + + unsafe { + growable.extend(0, i * width + idx, 1); + let output_is_valid = idx_is_valid & !idx_is_oob; + output_validity.push_unchecked(output_is_valid); + } + } + + if !null_on_oob && exceeded_width_idx >= width { + return Err(idx_oob_err(exceeded_width_idx as i64, width)); + } + + let output = growable.as_box(); + let output_validity = combine_validities_and(Some(&output_validity.freeze()), arr.validity()); + + Ok(output.with_validity(output_validity)) } diff --git a/crates/polars-arrow/src/legacy/kernels/list.rs b/crates/polars-arrow/src/legacy/kernels/list.rs index 4f3f332dac28..64bcbeb12d5c 100644 --- a/crates/polars-arrow/src/legacy/kernels/list.rs +++ b/crates/polars-arrow/src/legacy/kernels/list.rs @@ -1,104 +1,52 @@ -use polars_utils::IdxSize; - +use crate::array::growable::make_growable; use crate::array::{Array, ArrayRef, ListArray}; -use crate::compute::take::take_unchecked; +use crate::bitmap::BitmapBuilder; +use crate::compute::utils::combine_validities_and; use crate::legacy::prelude::*; use crate::legacy::trusted_len::TrustedLenPush; -use crate::legacy::utils::CustomIterTools; use crate::offset::{Offsets, OffsetsBuffer}; -/// Get the indices that would result in a get operation on the lists values. -/// for example, consider this list: -/// ```text -/// [[1, 2, 3], -/// [4, 5], -/// [6]] -/// -/// This contains the following values array: -/// [1, 2, 3, 4, 5, 6] -/// -/// get index 0 -/// would lead to the following indexes: -/// [0, 3, 5]. -/// if we use those in a take operation on the values array we get: -/// [1, 4, 6] -/// -/// -/// get index -1 -/// would lead to the following indexes: -/// [2, 4, 5]. -/// if we use those in a take operation on the values array we get: -/// [3, 5, 6] -/// -/// ``` -fn sublist_get_indexes(arr: &ListArray, index: i64) -> IdxArr { - let offsets = arr.offsets().as_slice(); - let mut iter = offsets.iter(); - - // the indices can be sliced, so we should not start at 0. - let mut cum_offset = (*offsets.first().unwrap_or(&0)) as IdxSize; - - if let Some(mut previous) = iter.next().copied() { - if arr.null_count() == 0 { - iter.map(|&offset| { - let len = offset - previous; - previous = offset; - // make sure that empty lists don't get accessed - // and out of bounds return null - if len == 0 { - return None; - } - if index >= len { - cum_offset += len as IdxSize; - return None; - } +pub fn sublist_get(arr: &ListArray, index: i64) -> ArrayRef { + if cfg!(debug_assertions) { + let msg = "fn sublist_get"; + dbg!(msg); + } - let out = index - .negative_to_usize(len as usize) - .map(|idx| idx as IdxSize + cum_offset); - cum_offset += len as IdxSize; - out - }) - .collect_trusted() - } else { - // we can ensure that validity is not none as we have null value. - let validity = arr.validity().unwrap(); - iter.enumerate() - .map(|(i, &offset)| { - let len = offset - previous; - previous = offset; - // make sure that empty and null lists don't get accessed and return null. - // SAFETY, we are within bounds - if len == 0 || !unsafe { validity.get_bit_unchecked(i) } { - cum_offset += len as IdxSize; - return None; - } - - // make sure that out of bounds return null - if index >= len { - cum_offset += len as IdxSize; - return None; - } - - let out = index - .negative_to_usize(len as usize) - .map(|idx| idx as IdxSize + cum_offset); - cum_offset += len as IdxSize; - out - }) - .collect_trusted() + let values = arr.values(); + + let mut growable = make_growable(&[values.as_ref()], values.validity().is_some(), arr.len()); + let mut result_validity = BitmapBuilder::with_capacity(arr.len()); + let opt_outer_validity = arr.validity(); + let index = usize::try_from(index).unwrap(); + + for (outer_idx, x) in arr.offsets().windows(2).enumerate() { + let [i, j] = x else { unreachable!() }; + let i = usize::try_from(*i).unwrap(); + let j = usize::try_from(*j).unwrap(); + + let (offset, len) = (i, j - i); + + let idx_is_oob = index >= len; + let outer_is_valid = + opt_outer_validity.map_or(true, |x| unsafe { x.get_bit_unchecked(outer_idx) }); + + unsafe { + if idx_is_oob { + growable.extend_validity(1); + } else { + growable.extend(0, offset + index, 1); + } + + result_validity.push_unchecked(!idx_is_oob & outer_is_valid); } - } else { - IdxArr::from_slice([]) } -} -pub fn sublist_get(arr: &ListArray, index: i64) -> ArrayRef { - let take_by = sublist_get_indexes(arr, index); - let values = arr.values(); - // SAFETY: - // the indices we generate are in bounds - unsafe { take_unchecked(&**values, &take_by) } + let values = growable.as_box(); + + values.with_validity(combine_validities_and( + Some(&result_validity.freeze()), + values.validity(), + )) } /// Check if an index is out of bounds for at least one sublist. @@ -158,41 +106,6 @@ mod test { ListArray::::new(dtype, offsets, Box::new(values), None) } - #[test] - fn test_sublist_get_indexes() { - let arr = get_array(); - let out = sublist_get_indexes(&arr, 0); - assert_eq!(out.values().as_slice(), &[0, 3, 5]); - let out = sublist_get_indexes(&arr, -1); - assert_eq!(out.values().as_slice(), &[2, 4, 5]); - let out = sublist_get_indexes(&arr, 3); - assert_eq!(out.null_count(), 3); - - let values = Int32Array::from_iter([ - Some(1), - Some(1), - Some(3), - Some(4), - Some(5), - Some(6), - Some(7), - Some(8), - Some(9), - None, - Some(11), - ]); - let offsets = OffsetsBuffer::try_from(vec![0i64, 1, 2, 3, 6, 9, 11]).unwrap(); - - let dtype = ListArray::::default_datatype(ArrowDataType::Int32); - let arr = ListArray::::new(dtype, offsets, Box::new(values), None); - - let out = sublist_get_indexes(&arr, 1); - assert_eq!( - out.into_iter().collect::>(), - &[None, None, None, Some(4), Some(7), Some(10)] - ); - } - #[test] fn test_sublist_get() { let arr = get_array();