From 806c94d0fc61e30271d885978390d4452a1708cc Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 24 Oct 2023 16:40:42 +0200 Subject: [PATCH] perf: fix regression non-null asof join --- .../polars-ops/src/frame/join/asof/default.rs | 32 +++++++++++++++---- .../polars-ops/src/frame/join/asof/groups.rs | 26 +++++++++------ crates/polars-ops/src/frame/join/asof/mod.rs | 3 ++ 3 files changed, 46 insertions(+), 15 deletions(-) diff --git a/crates/polars-ops/src/frame/join/asof/default.rs b/crates/polars-ops/src/frame/join/asof/default.rs index 377b2847c8b3..e0b2c432def8 100644 --- a/crates/polars-ops/src/frame/join/asof/default.rs +++ b/crates/polars-ops/src/frame/join/asof/default.rs @@ -22,16 +22,36 @@ where let mut mask = vec![0; (left.len() + 7) / 8]; let mut state = S::default(); - for (i, opt_val_l) in left.iter().enumerate() { - if let Some(val_l) = opt_val_l { - if let Some(r_idx) = - state.next(&val_l, |j| right.get(j as usize), right.len() as IdxSize) - { - let val_r = unsafe { right.get_unchecked(r_idx as usize).unwrap_unchecked() }; + if left.null_count() == 0 && right.null_count() == 0 { + for (i, val_l) in left.values_iter().enumerate() { + if let Some(r_idx) = state.next( + &val_l, + // SAFETY: next() only calls with indices < right.len(). + |j| Some(unsafe { right.value_unchecked(j as usize) }), + right.len() as IdxSize, + ) { + // SAFETY: r_idx is non-null and valid. + let val_r = unsafe { right.value_unchecked(r_idx as usize) }; out[i] = r_idx; mask[i / 8] |= (filter(val_l, val_r) as u8) << (i % 8); } } + } else { + for (i, opt_val_l) in left.iter().enumerate() { + if let Some(val_l) = opt_val_l { + if let Some(r_idx) = state.next( + &val_l, + // SAFETY: next() only calls with indices < right.len(). + |j| unsafe { right.get_unchecked(j as usize) }, + right.len() as IdxSize, + ) { + // SAFETY: r_idx is non-null and valid. + let val_r = unsafe { right.value_unchecked(r_idx as usize) }; + out[i] = r_idx; + mask[i / 8] |= (filter(val_l, val_r) as u8) << (i % 8); + } + } + } } let bitmap = Bitmap::try_new(mask, out.len()).unwrap(); diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index bf5dd6dd85bc..b8b00b2d00ab 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -41,15 +41,23 @@ where let id = right_grp_idxs.first()?; let grp_state = group_states.entry(*id).or_default(); - let r_grp_idx = grp_state.next( - &left_val, - |i| right_val_arr.get(right_grp_idxs[i as usize] as usize), - right_grp_idxs.len() as IdxSize, - )?; - - let r_idx = right_grp_idxs[r_grp_idx as usize]; - let right_val = right_val_arr.get(r_idx as usize).unwrap(); - filter(left_val, right_val).then_some(r_idx) + unsafe { + let r_grp_idx = grp_state.next( + &left_val, + |i| { + // SAFETY: the group indices are valid, and next() only calls with + // i < right_grp_idxs.len(). + right_val_arr.get_unchecked(*right_grp_idxs.get_unchecked(i as usize) as usize) + }, + right_grp_idxs.len() as IdxSize, + )?; + + // SAFETY: r_grp_idx is valid, as is r_idx (which must be non-null) if + // we get here. + let r_idx = *right_grp_idxs.get_unchecked(r_grp_idx as usize); + let right_val = right_val_arr.value_unchecked(r_idx as usize); + filter(left_val, right_val).then_some(r_idx) + } } fn asof_join_by_numeric( diff --git a/crates/polars-ops/src/frame/join/asof/mod.rs b/crates/polars-ops/src/frame/join/asof/mod.rs index c777344ec3ec..f1339d0cc0d2 100644 --- a/crates/polars-ops/src/frame/join/asof/mod.rs +++ b/crates/polars-ops/src/frame/join/asof/mod.rs @@ -33,6 +33,7 @@ struct AsofJoinForwardState { } impl AsofJoinState for AsofJoinForwardState { + #[inline] fn next Option>( &mut self, left_val: &T, @@ -59,6 +60,7 @@ struct AsofJoinBackwardState { } impl AsofJoinState for AsofJoinBackwardState { + #[inline] fn next Option>( &mut self, left_val: &T, @@ -87,6 +89,7 @@ struct AsofJoinNearestState { } impl AsofJoinState for AsofJoinNearestState { + #[inline] fn next Option>( &mut self, left_val: &T,