Skip to content

Commit

Permalink
perf: fix regression non-null asof join
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp committed Oct 24, 2023
1 parent 5a32aab commit 806c94d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 15 deletions.
32 changes: 26 additions & 6 deletions crates/polars-ops/src/frame/join/asof/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,36 @@ where
let mut mask = vec![0; (left.len() + 7) / 8];
let mut state = S::default();

for (i, opt_val_l) in left.iter().enumerate() {
if let Some(val_l) = opt_val_l {
if let Some(r_idx) =
state.next(&val_l, |j| right.get(j as usize), right.len() as IdxSize)
{
let val_r = unsafe { right.get_unchecked(r_idx as usize).unwrap_unchecked() };
if left.null_count() == 0 && right.null_count() == 0 {
for (i, val_l) in left.values_iter().enumerate() {
if let Some(r_idx) = state.next(
&val_l,
// SAFETY: next() only calls with indices < right.len().
|j| Some(unsafe { right.value_unchecked(j as usize) }),
right.len() as IdxSize,
) {
// SAFETY: r_idx is non-null and valid.
let val_r = unsafe { right.value_unchecked(r_idx as usize) };
out[i] = r_idx;
mask[i / 8] |= (filter(val_l, val_r) as u8) << (i % 8);
}
}
} else {
for (i, opt_val_l) in left.iter().enumerate() {
if let Some(val_l) = opt_val_l {
if let Some(r_idx) = state.next(
&val_l,
// SAFETY: next() only calls with indices < right.len().
|j| unsafe { right.get_unchecked(j as usize) },
right.len() as IdxSize,
) {
// SAFETY: r_idx is non-null and valid.
let val_r = unsafe { right.value_unchecked(r_idx as usize) };
out[i] = r_idx;
mask[i / 8] |= (filter(val_l, val_r) as u8) << (i % 8);
}
}
}
}

let bitmap = Bitmap::try_new(mask, out.len()).unwrap();
Expand Down
26 changes: 17 additions & 9 deletions crates/polars-ops/src/frame/join/asof/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,23 @@ where
let id = right_grp_idxs.first()?;
let grp_state = group_states.entry(*id).or_default();

let r_grp_idx = grp_state.next(
&left_val,
|i| right_val_arr.get(right_grp_idxs[i as usize] as usize),
right_grp_idxs.len() as IdxSize,
)?;

let r_idx = right_grp_idxs[r_grp_idx as usize];
let right_val = right_val_arr.get(r_idx as usize).unwrap();
filter(left_val, right_val).then_some(r_idx)
unsafe {
let r_grp_idx = grp_state.next(
&left_val,
|i| {
// SAFETY: the group indices are valid, and next() only calls with
// i < right_grp_idxs.len().
right_val_arr.get_unchecked(*right_grp_idxs.get_unchecked(i as usize) as usize)
},
right_grp_idxs.len() as IdxSize,
)?;

// SAFETY: r_grp_idx is valid, as is r_idx (which must be non-null) if
// we get here.
let r_idx = *right_grp_idxs.get_unchecked(r_grp_idx as usize);
let right_val = right_val_arr.value_unchecked(r_idx as usize);
filter(left_val, right_val).then_some(r_idx)
}
}

fn asof_join_by_numeric<T, S, A, F>(
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-ops/src/frame/join/asof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct AsofJoinForwardState {
}

impl<T: PartialOrd> AsofJoinState<T> for AsofJoinForwardState {
#[inline]
fn next<F: FnMut(IdxSize) -> Option<T>>(
&mut self,
left_val: &T,
Expand All @@ -59,6 +60,7 @@ struct AsofJoinBackwardState {
}

impl<T: PartialOrd> AsofJoinState<T> for AsofJoinBackwardState {
#[inline]
fn next<F: FnMut(IdxSize) -> Option<T>>(
&mut self,
left_val: &T,
Expand Down Expand Up @@ -87,6 +89,7 @@ struct AsofJoinNearestState {
}

impl<T: NumericNative> AsofJoinState<T> for AsofJoinNearestState {
#[inline]
fn next<F: FnMut(IdxSize) -> Option<T>>(
&mut self,
left_val: &T,
Expand Down

0 comments on commit 806c94d

Please sign in to comment.