diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 09d0c8d5ca2e..bcd88bae739a 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -148,6 +148,26 @@ async fn test_count() {
         .await;
 }
 
+#[tokio::test(flavor = "multi_thread")]
+async fn test_median() {
+    let data_gen_config = baseline_config();
+
+    // Queries like SELECT median(a), median(distinct a) FROM fuzz_table GROUP BY b
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("median")
+        .with_distinct_aggregate_function("median")
+        // median only works on numeric columns
+        .with_aggregate_arguments(data_gen_config.numeric_columns())
+        .set_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
 /// Return a standard set of columns for testing data generation
 ///
 /// Includes numeric and string types
diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs
index 70f192c32ae1..defbbe737a9d 100644
--- a/datafusion/functions-aggregate/src/median.rs
+++ b/datafusion/functions-aggregate/src/median.rs
@@ -20,7 +20,11 @@ use std::fmt::{Debug, Formatter};
 use std::mem::{size_of, size_of_val};
 use std::sync::Arc;
 
-use arrow::array::{downcast_integer, ArrowNumericType};
+use arrow::array::{
+    downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
+    PrimitiveBuilder,
+};
+use arrow::buffer::{OffsetBuffer, ScalarBuffer};
 use arrow::{
     array::{ArrayRef, AsArray},
     datatypes::{
@@ -33,12 +37,17 @@ use arrow::array::Array;
 use arrow::array::ArrowNativeTypeOp;
 use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};
 
-use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue};
+use datafusion_common::{
+    internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
+};
 use datafusion_expr::function::StateFieldsArgs;
 use datafusion_expr::{
     function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
     Documentation, Signature, Volatility,
 };
+use datafusion_expr::{EmitTo, GroupsAccumulator};
+use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
+use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
 use datafusion_functions_aggregate_common::utils::Hashable;
 use datafusion_macros::user_doc;
 
@@ -165,6 +174,52 @@ impl AggregateUDFImpl for Median {
         }
     }
 
+    /// Reports that the groups accumulator can be used for every `median`
+    /// except `median(DISTINCT ...)`.
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        !args.is_distinct
+    }
+
+    /// Builds a [`MedianGroupsAccumulator`] typed by the data type of the single
+    /// argument expression.
+    ///
+    /// Returns an internal error unless exactly one argument is given, and a
+    /// `NotImplemented` error for a data type none of the arms below cover.
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        let num_args = args.exprs.len();
+        if num_args != 1 {
+            return internal_err!(
+                "median should only have 1 arg, but found num args:{}",
+                args.exprs.len()
+            );
+        }
+
+        let dt = args.exprs[0].data_type(args.schema)?;
+
+        macro_rules! helper {
+            ($t:ty, $dt:expr) => {
+                Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
+            };
+        }
+
+        downcast_integer! {
+            dt => (helper, dt),
+            DataType::Float16 => helper!(Float16Type, dt),
+            DataType::Float32 => helper!(Float32Type, dt),
+            DataType::Float64 => helper!(Float64Type, dt),
+            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
+            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "MedianGroupsAccumulator not supported for {} with {}",
+                args.name,
+                dt,
+            ))),
+        }
+    }
+
     fn aliases(&self) -> &[String] {
         &[]
     }
@@ -230,6 +285,240 @@ impl Accumulator for MedianAccumulator {
     }
 }
 
+/// The median groups accumulator accumulates the raw input values
+///
+/// For calculating the accurate medians of groups, we need to store all values
+/// of groups before final evaluation.
+/// So values in each group will be stored in a `Vec<T::Native>`, and the total group values
+/// will be actually organized as a `Vec<Vec<T::Native>>`.
+///
+#[derive(Debug)]
+struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
+    data_type: DataType,
+    group_values: Vec<Vec<T::Native>>,
+}
+
+impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
+    /// Creates an accumulator holding no groups; `data_type` is attached to the
+    /// arrays built by this accumulator.
+    pub fn new(data_type: DataType) -> Self {
+        Self {
+            data_type,
+            group_values: Vec::new(),
+        }
+    }
+}
+
+impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
+    /// Appends every row of `values[0]` that `accumulate` hands to the callback
+    /// (given `opt_filter`) to the value list of its group in `group_indices`.
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = values[0].as_primitive::<T>();
+
+        // Push the `not nulls + not filtered` row into its group
+        self.group_values.resize(total_num_groups, Vec::new());
+        accumulate(
+            group_indices,
+            values,
+            opt_filter,
+            |group_index, new_value| {
+                self.group_values[group_index].push(new_value);
+            },
+        );
+
+        Ok(())
+    }
+
+    /// Merges partial states: `values[0]` is a `ListArray` with one list per row,
+    /// and the elements of each non-null list are appended to the row's group.
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
+        _opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "one argument to merge_batch");
+
+        // The merged values should be organized as a `ListArray` which is nullable
+        // (input with nulls usually generated from `convert_to_state`), but `inner array` of
+        // `ListArray` is `non-nullable`.
+        //
+        // Following is the possible and impossible input `values`:
+        //
+        // # Possible values
+        // ```text
+        // group 0: [1, 2, 3]
+        // group 1: null (list array is nullable)
+        // group 2: [6, 7, 8]
+        // ...
+        // group n: [...]
+        // ```
+        //
+        // # Impossible values
+        // ```text
+        // group x: [1, 2, null] (values in list array is non-nullable)
+        // ```
+        //
+        let input_group_values = values[0].as_list::<i32>();
+
+        // Ensure group values big enough
+        self.group_values.resize(total_num_groups, Vec::new());
+
+        // Extend values to related groups
+        // TODO: avoid using iterator of the `ListArray`, this will lead to
+        // many calls of `slice` of its `inner array`, and `slice` is not
+        // so efficient(due to the calculation of `null_count` for each `slice`).
+        group_indices
+            .iter()
+            .zip(input_group_values.iter())
+            .for_each(|(&group_index, values_opt)| {
+                if let Some(values) = values_opt {
+                    let values = values.as_primitive::<T>();
+                    self.group_values[group_index].extend(values.values().iter());
+                }
+            });
+
+        Ok(())
+    }
+
+    /// Emits the groups selected by `emit_to` as a single `ListArray` holding
+    /// one list of raw values per emitted group.
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        // Emit values
+        let emit_group_values = emit_to.take_needed(&mut self.group_values);
+
+        // Build offsets
+        let mut offsets = Vec::with_capacity(emit_group_values.len() + 1);
+        offsets.push(0);
+        let mut cur_len = 0_i32;
+        for group_value in &emit_group_values {
+            // Checked arithmetic, so that an offset which does not fit in `i32`
+            // is reported as an error instead of being truncated or wrapped.
+            cur_len = i32::try_from(group_value.len())
+                .ok()
+                .and_then(|len| cur_len.checked_add(len))
+                .ok_or_else(|| {
+                    internal_datafusion_err!(
+                        "offset overflows i32 in state of group median"
+                    )
+                })?;
+            offsets.push(cur_len);
+        }
+        // TODO: maybe we can use `OffsetBuffer::new_unchecked` like what in `convert_to_state`,
+        // but safety should be considered more carefully here(and I am not sure if it can get
+        // performance improvement when we introduce checks to keep the safety...).
+        //
+        // Can see more details in:
+        // https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
+        //
+        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
+
+        // Build inner array
+        let flatten_group_values =
+            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
+        let group_values_array =
+            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
+                .with_data_type(self.data_type.clone());
+
+        // Build the result list array
+        let result_list_array = ListArray::new(
+            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
+            offsets,
+            Arc::new(group_values_array),
+            None,
+        );
+
+        Ok(vec![Arc::new(result_list_array)])
+    }
+
+    /// Emits the groups selected by `emit_to` and returns one median per group,
+    /// as produced by `calculate_median`.
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        // Emit values
+        let emit_group_values = emit_to.take_needed(&mut self.group_values);
+
+        // Calculate median for each group
+        let mut evaluate_result_builder =
+            PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
+        for values in emit_group_values {
+            let median = calculate_median::<T>(values);
+            evaluate_result_builder.append_option(median);
+        }
+
+        Ok(Arc::new(evaluate_result_builder.finish()))
+    }
+
+    /// Converts raw input rows to the state format without accumulating: each
+    /// row becomes its own list entry (see the comments in the body).
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        assert_eq!(values.len(), 1, "one argument to convert_to_state");
+
+        let input_array = values[0].as_primitive::<T>();
+
+        // Directly convert the input array to states, each row will be
+        // seen as a respective group.
+        // For detail, the `input_array` will be converted to a `ListArray`.
+        // And if row is `not null + not filtered`, it will be converted to a list
+        // with only one element; otherwise, this row in `ListArray` will be set
+        // to null.
+
+        // Reuse values buffer in `input_array` to build `values` in `ListArray`
+        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
+            .with_data_type(self.data_type.clone());
+
+        // `offsets` in `ListArray`, each row as a list element
+        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
+            internal_datafusion_err!(
+                "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
+            )
+        })?;
+        let offsets = (0..=offset_end).collect::<Vec<_>>();
+        // Safety: all checks in `OffsetBuffer::new` are ensured to pass
+        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
+
+        // `nulls` for converted `ListArray`
+        let nulls = filtered_null_mask(opt_filter, input_array);
+
+        let converted_list_array = ListArray::new(
+            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
+            offsets,
+            Arc::new(values),
+            nulls,
+        );
+
+        Ok(vec![Arc::new(converted_list_array)])
+    }
+
+    /// `convert_to_state` is implemented, so report it as available.
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
+
+    /// Estimates the bytes allocated for the buffered values: the capacity of
+    /// each group's `Vec` plus the capacity of the outer `Vec`.
+    fn size(&self) -> usize {
+        self.group_values
+            .iter()
+            .map(|values| values.capacity() * size_of::<T::Native>())
+            .sum::<usize>()
+            // account for size of self.group_values too
+            + self.group_values.capacity() * size_of::<Vec<T::Native>>()
+    }
+}
+
 /// The distinct median accumulator accumulates the raw input values
 /// as `ScalarValue`s
 ///
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index bd3b40089519..4838911649bd 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -67,6 +67,62 @@ statement ok
 CREATE TABLE test (c1 BIGINT,c2 BIGINT) as values
 (0,null), (1,1), (null,1), (3,2), (3,2)
 
+statement ok
+CREATE TABLE group_median_table_non_nullable (
+    col_group STRING NOT NULL,
+    col_i8 TINYINT NOT NULL,
+    col_i16 SMALLINT NOT NULL,
+    col_i32 INT NOT NULL,
+    col_i64 BIGINT NOT NULL,
+    col_u8 TINYINT UNSIGNED NOT NULL,
+    col_u16 SMALLINT UNSIGNED NOT NULL,
+    col_u32 INT UNSIGNED NOT NULL,
+    col_u64 BIGINT UNSIGNED NOT NULL,
+    col_f32 FLOAT NOT NULL,
+    col_f64 DOUBLE NOT NULL,
+    col_f64_nan DOUBLE NOT NULL,
+    col_decimal128 DECIMAL(10, 4) NOT NULL,
+    col_decimal256 NUMERIC(10, 4) NOT NULL
+) as VALUES
+( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1, 0.0001, 0.0001 ),
+( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, 1.1, 0.0002, 0.0002 ),
+( 'group0', 100, 100, 100, arrow_cast(100,'Int64'), 100, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64'), 0.0003, 0.0003 ),
+( 'group0', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64'), 0.0004, 0.0004 ),
+( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1, 0.0001, 0.0001 ),
+( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64'), 0.0002, 0.0002 ),
+( 'group1', 100, 100, 100, arrow_cast(100,'Int64'), 101, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64'), 0.0003, 0.0003 ),
+( 'group1', 125, 32766, 2147483646, arrow_cast(9223372036854775806,'Int64'), 100, 101, 4294967294, arrow_cast(100,'UInt64'), 3.2, 5.5, arrow_cast('NAN','Float64'), 0.0004, 0.0004 ),
+( 'group1', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64'), 0.0005, 0.0005 )
+
+statement ok
+CREATE TABLE group_median_table_nullable (
+    col_group STRING NOT NULL,
+    col_i8 TINYINT,
+    col_i16 SMALLINT,
+    col_i32 INT,
+    col_i64 BIGINT,
+    col_u8 TINYINT UNSIGNED,
+    col_u16 SMALLINT UNSIGNED,
+    col_u32 INT UNSIGNED,
+    col_u64 BIGINT UNSIGNED,
+    col_f32 FLOAT,
+    col_f64 DOUBLE,
+    col_f64_nan DOUBLE,
+    col_decimal128 DECIMAL(10, 4),
+    col_decimal256 NUMERIC(10, 4)
+) as VALUES
+( 'group0', NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL ),
+( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1, 0.0001, 0.0001 ),
+( 'group0', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, 1.1, 0.0002, 0.0002 ),
+( 'group0', 100, 100, 100, arrow_cast(100,'Int64'), 100, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64'), 0.0003, 0.0003 ),
+( 'group0', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64'), 0.0004, 0.0004 ),
+( 'group1', NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL ),
+( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1, 0.0001, 0.0001 ),
+( 'group1', -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64'), 0.0002, 0.0002 ),
+( 'group1', 100, 100, 100, arrow_cast(100,'Int64'), 101, 100, 100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64'), 0.0003, 0.0003 ),
+( 'group1', 125, 32766, 2147483646, arrow_cast(9223372036854775806,'Int64'), 100, 101, 4294967294, arrow_cast(100,'UInt64'), 3.2, 5.5, arrow_cast('NAN','Float64'), 0.0004, 0.0004 ),
+( 'group1', 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64'), 0.0005, 0.0005 )
+
 #######
 # Error tests
 #######
@@ -6203,3 +6259,208 @@ physical_plan
 14)--------------PlaceholderRowExec
 15)------------ProjectionExec: expr=[1 as id, 2 as foo]
 16)--------------PlaceholderRowExec
+
+#######
+# Group median test
+#######
+
+# group median i8 non-nullable
+query TI rowsort
+SELECT col_group, median(col_i8) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 -14
+group1 100
+
+# group median i16 non-nullable
+query TI rowsort
+SELECT col_group, median(col_i16) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 -16334
+group1 100
+
+# group median i32 non-nullable
+query TI rowsort
+SELECT col_group, median(col_i32) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 -1073741774
+group1 100
+
+# group median i64 non-nullable
+query TI rowsort
+SELECT col_group, median(col_i64) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 -4611686018427387854
+group1 100
+
+# group median u8 non-nullable
+query TI rowsort
+SELECT col_group, median(col_u8) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 50
+group1 100
+
+# group median u16 non-nullable
+query TI rowsort
+SELECT col_group, median(col_u16) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 50
+group1 100
+
+# group median u32 non-nullable
+query TI rowsort
+SELECT col_group, median(col_u32) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 50
+group1 100
+
+# group median u64 non-nullable
+query TI rowsort
+SELECT col_group, median(col_u64) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 50
+group1 100
+
+# group median f32 non-nullable
+query TR rowsort
+SELECT col_group, median(col_f32) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 2.75
+group1 3.2
+
+# group median f64 non-nullable
+query TR rowsort
+SELECT col_group, median(col_f64) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 2.75
+group1 3.3
+
+# group median f64_nan non-nullable
+query TR rowsort
+SELECT col_group, median(col_f64_nan) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 NaN
+group1 NaN
+
+# group median decimal128 non-nullable
+query TR rowsort
+SELECT col_group, median(col_decimal128) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 0.0002
+group1 0.0003
+
+# group median decimal256 non-nullable
+query TR rowsort
+SELECT col_group, median(col_decimal256) FROM group_median_table_non_nullable GROUP BY col_group
+----
+group0 0.0002
+group1 0.0003
+
+# group median i8 nullable
+query TI rowsort
+SELECT col_group, median(col_i8) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 -14
+group1 100
+
+# group median i16 nullable
+query TI rowsort
+SELECT col_group, median(col_i16) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 -16334
+group1 100
+
+# group median i32 nullable
+query TI rowsort
+SELECT col_group, median(col_i32) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 -1073741774
+group1 100
+
+# group median i64 nullable
+query TI rowsort
+SELECT col_group, median(col_i64) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 -4611686018427387854
+group1 100
+
+# group median u8 nullable
+query TI rowsort
+SELECT col_group, median(col_u8) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 50
+group1 100
+
+# group median u16 nullable
+query TI rowsort
+SELECT col_group, median(col_u16) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 50
+group1 100
+
+# group median u32 nullable
+query TI rowsort
+SELECT col_group, median(col_u32) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 50
+group1 100
+
+# group median u64 nullable
+query TI rowsort
+SELECT col_group, median(col_u64) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 50
+group1 100
+
+# group median f32 nullable
+query TR rowsort
+SELECT col_group, median(col_f32) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 2.75
+group1 3.2
+
+# group median f64 nullable
+query TR rowsort
+SELECT col_group, median(col_f64) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 2.75
+group1 3.3
+
+# group median f64_nan nullable
+query TR rowsort
+SELECT col_group, median(col_f64_nan) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 NaN
+group1 NaN
+
+# group median decimal128 nullable
+query TR rowsort
+SELECT col_group, median(col_decimal128) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 0.0002
+group1 0.0003
+
+# group median decimal256 nullable
+query TR rowsort
+SELECT col_group, median(col_decimal256) FROM group_median_table_nullable GROUP BY col_group
+----
+group0 0.0002
+group1 0.0003
+
+# median with all nulls
+statement ok
+create table group_median_all_nulls(
+    a STRING NOT NULL,
+    b INT
+) AS VALUES
+( 'group0', NULL),
+( 'group0', NULL),
+( 'group0', NULL),
+( 'group1', NULL),
+( 'group1', NULL),
+( 'group1', NULL)
+
+query TIT rowsort
+SELECT a, median(b), arrow_typeof(median(b)) FROM group_median_all_nulls GROUP BY a
+----
+group0 NULL Int32
+group1 NULL Int32