Skip to content

Commit

Permalink
refactor(rust): Refactor compute kernels in polars-arrow to avoid usi…
Browse files Browse the repository at this point in the history
…ng gather
  • Loading branch information
nameexhaustion committed Nov 6, 2024
1 parent 18a4204 commit c35304e
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 201 deletions.
55 changes: 33 additions & 22 deletions crates/polars-arrow/src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ pub use binview_to::utf8view_to_utf8;
pub use boolean_to::*;
pub use decimal_to::*;
use dictionary_to::*;
use growable::make_growable;
use polars_error::{polars_bail, polars_ensure, polars_err, PolarsResult};
use polars_utils::IdxSize;
pub use primitive_to::*;
pub use utf8_to::*;

Expand Down Expand Up @@ -170,9 +170,13 @@ fn cast_fixed_size_list_to_list<O: Offset>(
fn cast_list_to_fixed_size_list<O: Offset>(
list: &ListArray<O>,
inner: &Field,
size: usize,
size: usize, // width
options: CastOptionsImpl,
) -> PolarsResult<FixedSizeListArray> {
) -> PolarsResult<FixedSizeListArray>
where
ListArray<O>: crate::array::StaticArray
+ ArrayFromIter<std::option::Option<Box<dyn crate::array::Array>>>,
{
let null_cnt = list.null_count();
let new_values = if null_cnt == 0 {
let start_offset = list.offsets().first().to_usize();
Expand All @@ -190,7 +194,8 @@ fn cast_list_to_fixed_size_list<O: Offset>(
.sliced(start_offset, list.offsets().range().to_usize());
cast(sliced_values.as_ref(), inner.dtype(), options)?
} else {
let offsets = list.offsets().as_slice();
let offsets = list.offsets();

// Check the lengths of each list are equal to the fixed size.
// SAFETY: we know the index is in bound.
let mut expected_offset = unsafe { *offsets.get_unchecked(0) } + O::from_as_usize(size);
Expand All @@ -206,27 +211,33 @@ fn cast_list_to_fixed_size_list<O: Offset>(
}
}

// Build take indices for the values. This is used to fill in the null slots.
let mut indices =
MutablePrimitiveArray::<IdxSize>::with_capacity(list.values().len() + null_cnt * size);
for i in 0..list.len() {
if list.is_null(i) {
indices.extend_constant(size, None)
} else {
// SAFETY: we know the index is in bound.
let current_offset = unsafe { *offsets.get_unchecked(i) };
for j in 0..size {
indices.push(Some(
(current_offset + O::from_as_usize(j)).to_usize() as IdxSize
));
let list_validity = list.validity().unwrap();
let mut growable = make_growable(&[list.values().as_ref()], true, list.len());

if cfg!(debug_assertions) {
let msg = "fn cast_list_to_fixed_size_list < nullable >";
dbg!(msg);
}

for (outer_idx, x) in offsets.windows(2).enumerate() {
let [i, j] = x else { unreachable!() };
let i = i.to_usize();
let j = j.to_usize();

unsafe {
let outer_is_valid = list_validity.get_bit_unchecked(outer_idx);

if outer_is_valid {
growable.extend(0, i, j - i);
} else {
growable.extend_validity(size)
}
}
};
}
let take_values = unsafe {
crate::compute::take::take_unchecked(list.values().as_ref(), &indices.freeze())
};

cast(take_values.as_ref(), inner.dtype(), options)?
let values = growable.as_box();

cast(values.as_ref(), inner.dtype(), options)?
};

FixedSizeListArray::try_new(
Expand Down
158 changes: 106 additions & 52 deletions crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,127 @@
use polars_error::{polars_bail, PolarsResult};
use polars_utils::index::NullCount;
use polars_utils::IdxSize;

use crate::array::{ArrayRef, FixedSizeListArray, PrimitiveArray};
use crate::compute::take::take_unchecked;
use crate::legacy::prelude::*;
use crate::legacy::utils::CustomIterTools;

fn sub_fixed_size_list_get_indexes_literal(width: usize, len: usize, index: i64) -> IdxArr {
(0..len)
.map(|i| {
if index >= width as i64 {
return None;
}

index
.negative_to_usize(width)
.map(|idx| (idx + i * width) as IdxSize)
})
.collect_trusted()
}

fn sub_fixed_size_list_get_indexes(width: usize, index: &PrimitiveArray<i64>) -> IdxArr {
index
.iter()
.enumerate()
.map(|(i, idx)| {
if let Some(idx) = idx {
if *idx >= width as i64 {
return None;
}
use polars_error::{polars_bail, PolarsError, PolarsResult};

idx.negative_to_usize(width)
.map(|idx| (idx + i * width) as IdxSize)
} else {
None
}
})
.collect_trusted()
}
use crate::array::growable::make_growable;
use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray};
use crate::bitmap::BitmapBuilder;
use crate::compute::utils::combine_validities_and;
use crate::datatypes::ArrowDataType;

pub fn sub_fixed_size_list_get_literal(
arr: &FixedSizeListArray,
index: i64,
null_on_oob: bool,
) -> PolarsResult<ArrayRef> {
let take_by = sub_fixed_size_list_get_indexes_literal(arr.size(), arr.len(), index);
if !null_on_oob && take_by.null_count() > 0 {
polars_bail!(ComputeError: "get index is out of bounds");
if cfg!(debug_assertions) {
let msg = "fn sub_fixed_size_list_get_literal";
dbg!(msg);
}

let ArrowDataType::FixedSizeList(_, width) = arr.dtype() else {
unreachable!();
};

let width = *width;

let index = usize::try_from(index).unwrap();

if !null_on_oob && index >= width {
polars_bail!(
ComputeError:
"get index {} is out of bounds for array(width={})",
index,
width
);
}

let values = arr.values();
// SAFETY:
// the indices we generate are in bounds
unsafe { Ok(take_unchecked(&**values, &take_by)) }

let mut growable = make_growable(&[values.as_ref()], values.validity().is_some(), arr.len());

for i in 0..arr.len() {
unsafe { growable.extend(0, i * width + index, 1) }
}

Ok(growable.as_box())
}

pub fn sub_fixed_size_list_get(
arr: &FixedSizeListArray,
index: &PrimitiveArray<i64>,
null_on_oob: bool,
) -> PolarsResult<ArrayRef> {
let take_by = sub_fixed_size_list_get_indexes(arr.size(), index);
if !null_on_oob && take_by.null_count() > 0 {
polars_bail!(ComputeError: "get index is out of bounds");
if cfg!(debug_assertions) {
let msg = "fn sub_fixed_size_list_get";
dbg!(msg);
}

fn idx_oob_err(index: i64, width: usize) -> PolarsError {
PolarsError::ComputeError(
format!(
"get index {} is out of bounds for array(width={})",
index, width
)
.into(),
)
}

let ArrowDataType::FixedSizeList(_, width) = arr.dtype() else {
unreachable!();
};

let width = *width;

if arr.is_empty() {
if !null_on_oob {
if let Some(i) = index.non_null_values_iter().max() {
if usize::try_from(i).unwrap() >= width {
return Err(idx_oob_err(i, width));
}
}
}

let values = arr.values();
assert!(values.is_empty());
return Ok(values.clone());
}

if !null_on_oob && width == 0 {
if let Some(i) = index.non_null_values_iter().next() {
return Err(idx_oob_err(i, width));
}
}

// Array is non-empty and has non-zero width at this point
let values = arr.values();
// SAFETY:
// the indices we generate are in bounds
unsafe { Ok(take_unchecked(&**values, &take_by)) }

let mut growable = make_growable(&[values.as_ref()], values.validity().is_some(), arr.len());
let mut output_validity = BitmapBuilder::with_capacity(arr.len());
let opt_index_validity = index.validity();
let mut exceeded_width_idx = 0;

for i in 0..arr.len() {
let idx = usize::try_from(index.value(i)).unwrap();
let idx_is_oob = idx >= width;
let idx_is_valid = opt_index_validity.map_or(true, |x| unsafe { x.get_bit_unchecked(i) });

if idx_is_oob && idx_is_valid && exceeded_width_idx < width {
exceeded_width_idx = idx;
}

let idx = if idx_is_oob { 0 } else { idx };

unsafe {
growable.extend(0, i * width + idx, 1);
let output_is_valid = idx_is_valid & !idx_is_oob;
output_validity.push_unchecked(output_is_valid);
}
}

if !null_on_oob && exceeded_width_idx >= width {
return Err(idx_oob_err(exceeded_width_idx as i64, width));
}

let output = growable.as_box();
let output_validity = combine_validities_and(Some(&output_validity.freeze()), arr.validity());

Ok(output.with_validity(output_validity))
}
Loading

0 comments on commit c35304e

Please sign in to comment.