-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve speed of median
by implementing special GroupsAccumulator
#13681
base: main
Are you sure you want to change the base?
Changes from all commits
cff822e
7f10006
6f172ef
cacc693
11e6753
955036f
17bd90b
28d8716
1244df4
c812350
fdc9b33
e2f384f
4b8a4ad
5603bc0
7e6a73a
1c7b57a
5eb7711
5a52e7c
6f56a63
e963d50
5fd9d8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -20,7 +20,11 @@ use std::fmt::{Debug, Formatter}; | |||||||||||||||||||||||||||||||||||||||||||||||||||
use std::mem::{size_of, size_of_val}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use std::sync::Arc; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
use arrow::array::{downcast_integer, ArrowNumericType}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use arrow::array::{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
PrimitiveBuilder, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use arrow::buffer::{OffsetBuffer, ScalarBuffer}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use arrow::{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
array::{ArrayRef, AsArray}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
datatypes::{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -33,12 +37,17 @@ use arrow::array::Array; | |||||||||||||||||||||||||||||||||||||||||||||||||||
use arrow::array::ArrowNativeTypeOp; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_common::{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_expr::function::StateFieldsArgs; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_expr::{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Documentation, Signature, Volatility, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_expr::{EmitTo, GroupsAccumulator}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_functions_aggregate_common::utils::Hashable; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
use datafusion_macros::user_doc; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -165,6 +174,45 @@ impl AggregateUDFImpl for Median { | |||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
!args.is_distinct | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn create_groups_accumulator( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
&self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
args: AccumulatorArgs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> Result<Box<dyn GroupsAccumulator>> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let num_args = args.exprs.len(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if num_args != 1 { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return internal_err!( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"median should only have 1 arg, but found num args:{}", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
args.exprs.len() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
let dt = args.exprs[0].data_type(args.schema)?; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
macro_rules! helper { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
($t:ty, $dt:expr) => { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
downcast_integer! { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dt => (helper, dt), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
DataType::Float16 => helper!(Float16Type, dt), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
DataType::Float32 => helper!(Float32Type, dt), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
DataType::Float64 => helper!(Float64Type, dt), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
_ => Err(DataFusionError::NotImplemented(format!( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"MedianGroupsAccumulator not supported for {} with {}", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
args.name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
dt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
))), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn aliases(&self) -> &[String] { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
&[] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -230,6 +278,216 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> { | |||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
/// The median groups accumulator accumulates the raw input values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
/// For calculating the accurate medians of groups, we need to store all values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
/// of groups before final evaluation. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
/// So values in each group will be stored in a `Vec<T>`, and the total group values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
/// will be actually organized as a `Vec<Vec<T>>`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given it is important to track the median values for each group separately I don't really see a way around Vec/Vec -- I think it is the simplest version and will have pretty reasonable performance There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I tried not to use |
||||||||||||||||||||||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
#[derive(Debug)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
struct MedianGroupsAccumulator<T: ArrowNumericType + Send> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type: DataType, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
group_values: Vec<Vec<T::Native>>, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wonder -- using P.S. asking just because when I was doing +- same for count distinct (PR), the performance for GroupsAccumulator with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think among other things, the intermediate state management (creating ListArrays directly rather than from ScalarValue) probably helps a lot: There is also an extra allocation per group when using the groups accumulator adapter thingie That being said, it is a fair question how much better the existing MedianAccumulator could be if it built the ListArrays as does this PR directly 🤔 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @korowa I think what mentioned by @alamb is a important point about the improvement. Following are some other points for me:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was some improvements, but overall results for clickbench q9 (I was mostly looking at this query) were like x2.63 for GroupsAccumulator, and x2.30 for the regular Accumulator -- so it would be like 13-15% overall difference, which is not as massive as this PR results. However, maybe things has changed in GroupsAccumulator implementation, and now even plain UPD: and, yes, maybe producing state, as pointed out by @alamb above, was (at least partially) the cause of non-significant improvement -- in count distinct it was implemented via There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seem really worth seeking the reason more deeply. |
||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
pub fn new(data_type: DataType) -> Self { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Self { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
data_type, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
group_values: Vec::new(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
fn update_batch( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
&mut self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
values: &[ArrayRef], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
group_indices: &[usize], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
opt_filter: Option<&BooleanArray>, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
total_num_groups: usize, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> Result<()> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
assert_eq!(values.len(), 1, "single argument to update_batch"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let values = values[0].as_primitive::<T>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Push the `not nulls + not filtered` row into its group | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.group_values.resize(total_num_groups, Vec::new()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
accumulate( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
group_indices, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
values, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
opt_filter, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|group_index, new_value| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.group_values[group_index].push(new_value); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Ok(()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn merge_batch( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
&mut self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
values: &[ArrayRef], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
group_indices: &[usize], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Since aggregate filter should be applied in partial stage, in final stage there should be no filter | ||||||||||||||||||||||||||||||||||||||||||||||||||||
_opt_filter: Option<&BooleanArray>, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
total_num_groups: usize, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> Result<()> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
assert_eq!(values.len(), 1, "one argument to merge_batch"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// The merged values should be organized like as a `ListArray` which is nullable | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// (input with nulls usually generated from `convert_to_state`), but `inner array` of | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// `ListArray` is `non-nullable`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Following is the possible and impossible input `values`: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// # Possible values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// ```text | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// group 0: [1, 2, 3] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// group 1: null (list array is nullable) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// group 2: [6, 7, 8] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// ... | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// group n: [...] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// ``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// # Impossible values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// ```text | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// group x: [1, 2, null] (values in list array is non-nullable) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// ``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let input_group_values = values[0].as_list::<i32>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Ensure group values big enough | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.group_values.resize(total_num_groups, Vec::new()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Extend values to related groups | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: avoid using iterator of the `ListArray`, this will lead to | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// many calls of `slice` of its ``inner array`, and `slice` is not | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// so efficient(due to the calculation of `null_count` for each `slice`). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's safe to directly use the value without checking null, null values should be ignored during accumulation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🤔 The And batch like:
will be converted to a list like:
I think we can implement a simple version for correctness firstly. |
||||||||||||||||||||||||||||||||||||||||||||||||||||
group_indices | ||||||||||||||||||||||||||||||||||||||||||||||||||||
.iter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
.zip(input_group_values.iter()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
.for_each(|(&group_index, values_opt)| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if let Some(values) = values_opt { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let values = values.as_primitive::<T>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.group_values[group_index].extend(values.values().iter()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Ok(()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Emit values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let emit_group_values = emit_to.take_needed(&mut self.group_values); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Build offsets | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let mut offsets = Vec::with_capacity(self.group_values.len() + 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
offsets.push(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let mut cur_len = 0_i32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
for group_value in &emit_group_values { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
cur_len += group_value.len() as i32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
offsets.push(cur_len); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: maybe we can use `OffsetBuffer::new_unchecked` like what in `convert_to_state`, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// but safety should be considered more carefully here(and I am not sure if it can get | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// performance improvement when we introduce checks to keep the safety...). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Can see more details in: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// https://github.com/apache/datafusion/pull/13681#discussion_r1931209791 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Build inner array | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let flatten_group_values = | ||||||||||||||||||||||||||||||||||||||||||||||||||||
emit_group_values.into_iter().flatten().collect::<Vec<_>>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let group_values_array = | ||||||||||||||||||||||||||||||||||||||||||||||||||||
PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
.with_data_type(self.data_type.clone()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Build the result list array | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let result_list_array = ListArray::new( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Arc::new(Field::new_list_field(self.data_type.clone(), true)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
offsets, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Arc::new(group_values_array), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Ok(vec![Arc::new(result_list_array)]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Emit values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let emit_group_values = emit_to.take_needed(&mut self.group_values); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Calculate median for each group | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let mut evaluate_result_builder = | ||||||||||||||||||||||||||||||||||||||||||||||||||||
PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
for values in emit_group_values { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let median = calculate_median::<T>(values); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
evaluate_result_builder.append_option(median); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Ok(Arc::new(evaluate_result_builder.finish())) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn convert_to_state( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
&self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
values: &[ArrayRef], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
opt_filter: Option<&BooleanArray>, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> Result<Vec<ArrayRef>> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
assert_eq!(values.len(), 1, "one argument to merge_batch"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
let input_array = values[0].as_primitive::<T>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Directly convert the input array to states, each row will be | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// seen as a respective group. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// For detail, the `input_array` will be converted to a `ListArray`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// And if row is `not null + not filtered`, it will be converted to a list | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// with only one element; otherwise, this row in `ListArray` will be set | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// to null. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Reuse values buffer in `input_array` to build `values` in `ListArray` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let values = PrimitiveArray::<T>::new(input_array.values().clone(), None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
.with_data_type(self.data_type.clone()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// `offsets` in `ListArray`, each row as a list element | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let offset_end = i32::try_from(input_array.len()).map_err(|e| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
internal_datafusion_err!( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"cast array_len to i32 failed in convert_to_state of group median, err:{e:?}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
})?; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let offsets = (0..=offset_end).collect::<Vec<_>>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Safety: all checks in `OffsetBuffer::new` are ensured to pass | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// `nulls` for converted `ListArray` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
let nulls = filtered_null_mask(opt_filter, input_array); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
let converted_list_array = ListArray::new( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Arc::new(Field::new_list_field(self.data_type.clone(), true)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
offsets, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Arc::new(values), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
nulls, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Ok(vec![Arc::new(converted_list_array)]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn supports_convert_to_state(&self) -> bool { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
true | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
fn size(&self) -> usize { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.group_values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
.iter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
.map(|values| values.capacity() * size_of::<T>()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
.sum::<usize>() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Rachelint marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// account for size of self.grou_values too | ||||||||||||||||||||||||||||||||||||||||||||||||||||
+ self.group_values.capacity() * size_of::<Vec<T>>() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
/// The distinct median accumulator accumulates the raw input values | ||||||||||||||||||||||||||||||||||||||||||||||||||||
/// as `ScalarValue`s | ||||||||||||||||||||||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️