Skip to content

Commit

Permalink
specialization to bypass repeated Arc deref
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp committed Sep 20, 2023
1 parent 5e4397b commit 6b098d7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
23 changes: 9 additions & 14 deletions crates/polars-core/src/chunked_array/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,12 @@ unsafe fn gather_idx_array_unchecked<A: StaticArray>(
) -> A {
let it = indices.iter().copied();
if targets.len() == 1 {
let arr = targets.iter().next().unwrap();
let target = targets.first().unwrap();
if has_nulls {
it.map(|i| arr.get_unchecked(i as usize))
it.map(|i| target.get_unchecked(i as usize))
.collect_arr_trusted_with_dtype(dtype)
} else {
it.map(|i| arr.value_unchecked(i as usize))
.collect_arr_trusted_with_dtype(dtype)
target.gather_unchecked_trusted(indices.iter().map(|i| *i as usize), dtype)
}
} else {
let cumlens = cumulative_lengths(targets);
Expand Down Expand Up @@ -152,38 +151,34 @@ impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T> {
let targets: Vec<_> = ca.downcast_iter().collect();

let chunks = indices.downcast_iter().map(|idx_arr| {
let dtype = ca.dtype().clone();
if idx_arr.null_count() == 0 {
gather_idx_array_unchecked(
ca.dtype().clone(),
&targets,
targets_have_nulls,
idx_arr.values(),
)
gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values())
} else if targets.len() == 1 {
let target = targets.first().unwrap();
if targets_have_nulls {
idx_arr
.iter()
.map(|i| target.get_unchecked(*i? as usize))
.collect_arr_trusted_with_dtype(ca.dtype().clone())
.collect_arr_trusted_with_dtype(dtype)
} else {
idx_arr
.iter()
.map(|i| Some(target.value_unchecked(*i? as usize)))
.collect_arr_trusted_with_dtype(ca.dtype().clone())
.collect_arr_trusted_with_dtype(dtype)
}
} else {
let cumlens = cumulative_lengths(&targets);
if targets_have_nulls {
idx_arr
.iter()
.map(|i| target_get_unchecked(&targets, &cumlens, *i?))
.collect_arr_trusted_with_dtype(ca.dtype().clone())
.collect_arr_trusted_with_dtype(dtype)
} else {
idx_arr
.iter()
.map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
.collect_arr_trusted_with_dtype(ca.dtype().clone())
.collect_arr_trusted_with_dtype(dtype)
}
}
});
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-core/src/datatypes/static_array.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use arrow::bitmap::utils::{BitmapIter, ZipValidity};
use arrow::bitmap::Bitmap;
use polars_arrow::trusted_len::TrustedLenPush;

#[cfg(feature = "object")]
use crate::chunked_array::object::{ObjectArray, ObjectValueIter};
Expand Down Expand Up @@ -54,6 +55,11 @@ pub trait StaticArray:
/// # Safety
/// It is the callers responsibility that the `idx < self.len()`.
unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_>;

#[inline]
unsafe fn gather_unchecked_trusted<I: Iterator<Item=usize> + TrustedLen>(&self, it: I, dtype: DataType) -> Self {
it.map(|i| self.value_unchecked(i)).collect_arr_with_dtype(dtype)
}

fn iter(&self) -> ZipValidity<Self::ValueT<'_>, Self::ValueIterT<'_>, BitmapIter>;
fn values_iter(&self) -> Self::ValueIterT<'_>;
Expand All @@ -77,6 +83,13 @@ impl<T: NumericNative> StaticArray for PrimitiveArray<T> {
self.values_iter().copied()
}

#[inline]
unsafe fn gather_unchecked_trusted<I: Iterator<Item=usize> + TrustedLen>(&self, it: I, _dtype: DataType) -> Self {
let arr: &[T] = self.values();
let v = Vec::from_trusted_len_iter(it.map(|i| *arr.get_unchecked(i)));
PrimitiveArray::from_vec(v)
}

fn iter(&self) -> ZipValidity<Self::ValueT<'_>, Self::ValueIterT<'_>, BitmapIter> {
ZipValidity::new_with_validity(self.values().iter().copied(), self.validity())
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/datatypes/static_array_collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,12 @@ macro_rules! impl_trusted_collect_vec_validity {
}

impl<T: NumericNative> ArrayFromIter<T> for PrimitiveArray<T> {
#[inline]
fn arr_from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
PrimitiveArray::from_vec(iter.into_iter().collect())
}

#[inline]
fn arr_from_iter_trusted<I>(iter: I) -> Self
where
I: IntoIterator<Item = T>,
Expand All @@ -310,11 +312,13 @@ impl<T: NumericNative> ArrayFromIter<T> for PrimitiveArray<T> {
PrimitiveArray::from_vec(Vec::from_trusted_len_iter(iter))
}

#[inline]
fn try_arr_from_iter<E, I: IntoIterator<Item = Result<T, E>>>(iter: I) -> Result<Self, E> {
let v: Result<Vec<T>, E> = iter.into_iter().collect();
Ok(PrimitiveArray::from_vec(v?))
}

#[inline]
fn try_arr_from_iter_trusted<E, I>(iter: I) -> Result<Self, E>
where
I: IntoIterator<Item = Result<T, E>>,
Expand Down

0 comments on commit 6b098d7

Please sign in to comment.