From 98fbec602572fe8d2c1eff98f3dd5682c5ed1993 Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Fri, 3 Nov 2023 10:05:20 -0400 Subject: [PATCH 01/21] Expose metrics from FileSinkExec impl of ExecutionPlan --- datafusion/physical-plan/src/insert.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 81cdfd753fe6..296b49526449 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -278,6 +278,11 @@ impl ExecutionPlan for FileSinkExec { stream, ))) } + + /// Returns the metrics of the underlying [DataSink] + fn metrics(&self) -> Option { + self.sink.metrics() + } } /// Create a output record batch with a count From cb3341f8b9dfc737d985ac97b3230c73ddc72aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 30 Nov 2023 17:53:11 +0000 Subject: [PATCH 02/21] Add pool_size method to MemoryPool (#218) * Add pool_size method to MemoryPool * Fix * Fmt --- datafusion/execution/src/memory_pool/mod.rs | 7 ++++++- datafusion/execution/src/memory_pool/pool.rs | 15 ++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 55555014f2ef..15880e14e06d 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -75,6 +75,9 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// Return the total amount of memory reserved fn reserved(&self) -> usize; + + /// Return the configured pool size (if any) + fn pool_size(&self) -> Option; } /// A memory consumer that can be tracked by [`MemoryReservation`] in @@ -286,7 +289,9 @@ mod tests { #[test] fn test_memory_pool_underflow() { - let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let pool: Arc = Arc::new(GreedyMemoryPool::new(50)) as _; + assert_eq!(pool.pool_size(), Some(50)); + let mut a1 = MemoryConsumer::new("a1").register(&pool); assert_eq!(pool.reserved(), 0); diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index 4a491630fe20..bd9f818a7aad 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -44,6 +44,10 @@ impl MemoryPool for UnboundedMemoryPool { fn reserved(&self) -> usize { self.used.load(Ordering::Relaxed) } + + fn pool_size(&self) -> Option { + None + } } /// A [`MemoryPool`] that implements a greedy first-come first-serve limit. 
@@ -96,6 +100,10 @@ impl MemoryPool for GreedyMemoryPool { fn reserved(&self) -> usize { self.used.load(Ordering::Relaxed) } + + fn pool_size(&self) -> Option { + Some(self.pool_size) + } } /// A [`MemoryPool`] that prevents spillable reservations from using more than @@ -229,6 +237,10 @@ impl MemoryPool for FairSpillPool { let state = self.state.lock(); state.spillable + state.unspillable } + + fn pool_size(&self) -> Option { + Some(self.pool_size) + } } fn insufficient_capacity_err( @@ -246,7 +258,8 @@ mod tests { #[test] fn test_fair() { - let pool = Arc::new(FairSpillPool::new(100)) as _; + let pool: Arc = Arc::new(FairSpillPool::new(100)) as _; + assert_eq!(pool.pool_size(), Some(100)); let mut r1 = MemoryConsumer::new("unspillable").register(&pool); // Can grow beyond capacity of pool From 425247eed3e9f4017274d7d1ca215c76aebea1fb Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Thu, 22 Feb 2024 06:03:10 -0500 Subject: [PATCH 03/21] Handle hashing list of struct arrays --- datafusion/common/src/hash_utils.rs | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 9198461e00bf..b751bb7171c9 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -240,6 +240,33 @@ where Ok(()) } +fn hash_struct_array( + array: &StructArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let row_len = array.len(); + + let valid_row_indices: Vec = if let Some(nulls) = nulls { + nulls.valid_indices().collect() + } else { + (0..row_len).collect() + }; + + // Create hashes for each row that combines the hashes over all the column at that row. + let mut values_hashes = vec![0u64; row_len]; + create_hashes(array.columns(), random_state, &mut values_hashes)?; + + for i in valid_row_indices { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + + Ok(()) +} + + /// Test version of `create_hashes` that produces the same value for /// all hashes (to test collisions) /// @@ -327,6 +354,10 @@ pub fn create_hashes<'a>( array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() } + DataType::Struct(_) => { + let array = as_struct_array(array); + hash_struct_array(array, random_state, hashes_buffer)?; + } DataType::List(_) => { let array = as_list_array(array); hash_list_array(array, random_state, hashes_buffer)?; From cef9a847c4992cabc81a5fa17959cadca843c4c3 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 4 Jan 2024 01:50:46 +0800 Subject: [PATCH 04/21] Minor: Introduce utils::hash for StructArray (#8552) * hash struct Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * row-wise hash Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * create hashes once Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/common/src/hash_utils.rs | 90 +++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 4 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index b751bb7171c9..69d495f477ae 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -27,7 +27,8 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array}; use arrow_buffer::i256; use crate::cast::{ - as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, + 
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, + as_primitive_array, as_string_array, as_struct_array, }; use crate::error::{DataFusionError, Result, _internal_err}; @@ -207,6 +208,35 @@ fn hash_dictionary( Ok(()) } +fn hash_struct_array( + array: &StructArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let num_columns = array.num_columns(); + + // Skip null columns + let valid_indices: Vec = if let Some(nulls) = nulls { + nulls.valid_indices().collect() + } else { + (0..num_columns).collect() + }; + + // Create hashes for each row that combines the hashes over all the column at that row. + // array.len() is the number of rows. + let mut values_hashes = vec![0u64; array.len()]; + create_hashes(array.columns(), random_state, &mut values_hashes)?; + + // Skip the null columns, nulls should get hash value 0. + for i in valid_indices { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + + Ok(()) +} + fn hash_list_array( array: &GenericListArray, random_state: &RandomState, @@ -355,15 +385,15 @@ pub fn create_hashes<'a>( _ => unreachable!() } DataType::Struct(_) => { - let array = as_struct_array(array); + let array = as_struct_array(array)?; hash_struct_array(array, random_state, hashes_buffer)?; } DataType::List(_) => { - let array = as_list_array(array); + let array = as_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } DataType::LargeList(_) => { - let array = as_large_list_array(array); + let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } _ => { @@ -546,6 +576,58 @@ mod tests { assert_eq!(hashes[2], hashes[3]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays() { + use arrow_buffer::Buffer; + + let boolarr = Arc::new(BooleanArray::from(vec![ + false, false, true, true, true, true, + ])); + let i32arr = Arc::new(Int32Array::from(vec![10, 10, 20, 20, 30, 31])); + + let struct_array = StructArray::from(( + vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ], + Buffer::from(&[0b001011]), + )); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + assert!(struct_array.is_null(2)); + assert!(struct_array.is_valid(3)); + assert!(struct_array.is_null(4)); + assert!(struct_array.is_null(5)); + + let array = Arc::new(struct_array) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + // same value but the third row ( hashes[2] ) is null + assert_ne!(hashes[2], hashes[3]); + // different values but both are null + assert_eq!(hashes[4], hashes[5]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] From f69abae51f979cdba66f00d8e6b5ba0afc176ec6 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 6 Jan 2024 22:34:57 +0800 Subject: [PATCH 05/21] Minor: Fix incorrect 
indices for hashing struct (#8775) * fix bug Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * add rowsort Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/common/src/hash_utils.rs | 46 ++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 69d495f477ae..e18e8ac45be0 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -214,22 +214,19 @@ fn hash_struct_array( hashes_buffer: &mut [u64], ) -> Result<()> { let nulls = array.nulls(); - let num_columns = array.num_columns(); + let row_len = array.len(); - // Skip null columns - let valid_indices: Vec = if let Some(nulls) = nulls { + let valid_row_indices: Vec = if let Some(nulls) = nulls { nulls.valid_indices().collect() } else { - (0..num_columns).collect() + (0..row_len).collect() }; // Create hashes for each row that combines the hashes over all the column at that row. - // array.len() is the number of rows. - let mut values_hashes = vec![0u64; array.len()]; + let mut values_hashes = vec![0u64; row_len]; create_hashes(array.columns(), random_state, &mut values_hashes)?; - // Skip the null columns, nulls should get hash value 0. - for i in valid_indices { + for i in valid_row_indices { let hash = &mut hashes_buffer[i]; *hash = combine_hashes(*hash, values_hashes[i]); } @@ -628,6 +625,39 @@ mod tests { assert_eq!(hashes[4], hashes[5]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays_more_column_than_row() { + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-1", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-2", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ( + Arc::new(Field::new("i32-3", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef, + ), + ]); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + + let array = Arc::new(struct_array) as ArrayRef; + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] From 70797a4717cb580e7dde509330b5285428018f58 Mon Sep 17 00:00:00 2001 From: Dan Harris Date: Mon, 26 Feb 2024 12:06:01 -0500 Subject: [PATCH 06/21] Cherry-pick fixes for hashing Struct values --- datafusion/common/src/hash_utils.rs | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index e18e8ac45be0..6a77cd41c201 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -267,32 +267,6 @@ where Ok(()) } -fn hash_struct_array( - array: &StructArray, - random_state: &RandomState, - hashes_buffer: &mut [u64], -) -> Result<()> { - let nulls = array.nulls(); - let row_len = array.len(); - - let valid_row_indices: Vec = if let Some(nulls) = nulls { - nulls.valid_indices().collect() - } else { - 
(0..row_len).collect() - }; - - // Create hashes for each row that combines the hashes over all the column at that row. - let mut values_hashes = vec![0u64; row_len]; - create_hashes(array.columns(), random_state, &mut values_hashes)?; - - for i in valid_row_indices { - let hash = &mut hashes_buffer[i]; - *hash = combine_hashes(*hash, values_hashes[i]); - } - - Ok(()) -} - /// Test version of `create_hashes` that produces the same value for /// all hashes (to test collisions) From fce8bf81dca8e81b64f1f3b7b4419c632dc74360 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Wed, 13 Mar 2024 09:02:43 +0100 Subject: [PATCH 07/21] Fix ApproxPercentileAccumulator on zero values (#9582) * Fix ApproxPercentileAccumulator * Imports * Use return type --- .../physical-expr/src/aggregate/approx_percentile_cont.rs | 6 +++--- datafusion/sqllogictest/test_files/aggregate.slt | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index aa4749f64ae9..e7997b316fd2 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -28,8 +28,8 @@ use arrow::{ datatypes::{DataType, Field}, }; use datafusion_common::{ - downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError, - Result, ScalarValue, + downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::Accumulator; use std::{any::Any, iter, sync::Arc}; @@ -394,7 +394,7 @@ impl Accumulator for ApproxPercentileAccumulator { fn evaluate(&self) -> Result { if self.digest.count() == 0.0 { - return exec_err!("aggregate function needs at least one non-null element"); + return ScalarValue::try_from(self.return_type.clone()); } let q = self.digest.estimate_quantile(self.percentile); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 78575c9dffc5..9e27b42b0e2b 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2197,9 +2197,10 @@ select median(a) from (select 1 as a where 1=0); ---- NULL -query error DataFusion error: Execution error: aggregate function needs at least one non-null element +query I select approx_median(a) from (select 1 as a where 1=0); - +---- +NULL # aggregate_decimal_sum query RT From d88e4146eafdf5a0a7d444fdea4cee179e0517aa Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 19 Mar 2024 12:14:13 -0600 Subject: [PATCH 08/21] Support Union types in `ScalarValue` (#9683) --- datafusion/common/src/scalar.rs | 104 ++++++- datafusion/physical-plan/src/filter.rs | 35 +++ datafusion/proto/proto/datafusion.proto | 15 + datafusion/proto/src/generated/pbjson.rs | 272 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 26 +- .../proto/src/logical_plan/from_proto.rs | 35 +++ datafusion/proto/src/logical_plan/to_proto.rs | 28 ++ 7 files changed, 503 insertions(+), 12 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index d730fbf89b72..4359280caa50 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -48,6 +48,8 @@ use arrow::{ use arrow_array::cast::as_list_array; use arrow_array::types::ArrowTimestampType; use arrow_array::{ArrowNativeTypeOp, Scalar}; +use arrow_buffer::Buffer; +use 
arrow_schema::{UnionFields, UnionMode}; /// A dynamically typed, nullable single value, (the single-valued counter-part /// to arrow's [`Array`]) @@ -187,6 +189,11 @@ pub enum ScalarValue { DurationNanosecond(Option), /// struct of nested ScalarValue Struct(Option>, Fields), + /// A nested datatype that can represent slots of differing types. Components: + /// `.0`: a tuple of union `type_id` and the single value held by this Scalar + /// `.1`: the list of fields, zero-to-one of which will by set in `.0` + /// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came + Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), } @@ -287,6 +294,10 @@ impl PartialEq for ScalarValue { (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, + (Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => { + val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2) + } + (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, (Null, Null) => true, @@ -448,6 +459,14 @@ impl PartialOrd for ScalarValue { } } (Struct(_, _), _) => None, + (Union(v1, t1, m1), Union(v2, t2, m2)) => { + if t1.eq(t2) && m1.eq(m2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Union(_, _, _), _) => None, (Dictionary(k1, v1), Dictionary(k2, v2)) => { // Don't compare if the key types don't match (it is effectively a different datatype) if k1 == k2 { @@ -546,6 +565,11 @@ impl std::hash::Hash for ScalarValue { v.hash(state); t.hash(state); } + Union(v, t, m) => { + v.hash(state); + t.hash(state); + m.hash(state); + } Dictionary(k, v) => { k.hash(state); v.hash(state); @@ -968,6 +992,7 @@ impl ScalarValue { DataType::Duration(TimeUnit::Nanosecond) } ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()), + ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } @@ -1167,6 +1192,7 @@ impl ScalarValue { ScalarValue::DurationMicrosecond(v) => v.is_none(), ScalarValue::DurationNanosecond(v) => v.is_none(), ScalarValue::Struct(v, _) => v.is_none(), + ScalarValue::Union(v, _, _) => v.is_none(), ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -1992,6 +2018,39 @@ impl ScalarValue { new_null_array(&dt, size) } }, + ScalarValue::Union(value, fields, _mode) => match value { + Some((v_id, value)) => { + let mut field_type_ids = Vec::::with_capacity(fields.len()); + let mut child_arrays = + Vec::<(Field, ArrayRef)>::with_capacity(fields.len()); + for (f_id, field) in fields.iter() { + let ar = if f_id == *v_id { + value.to_array_of_size(size)? 
+ } else { + let dt = field.data_type(); + new_null_array(dt, size) + }; + let field = (**field).clone(); + child_arrays.push((field, ar)); + field_type_ids.push(f_id); + } + let type_ids = repeat(*v_id).take(size).collect::>(); + let type_ids = Buffer::from_slice_ref(type_ids); + let value_offsets: Option = None; + let ar = UnionArray::try_new( + field_type_ids.as_slice(), + type_ids, + value_offsets, + child_arrays, + ) + .map_err(|e| DataFusionError::ArrowError(e))?; + Arc::new(ar) + } + None => { + let dt = self.data_type(); + new_null_array(&dt, size) + } + }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { @@ -2492,6 +2551,9 @@ impl ScalarValue { ScalarValue::Struct(_, _) => { return _not_impl_err!("Struct is not supported yet") } + ScalarValue::Union(_, _, _) => { + return _not_impl_err!("Union is not supported yet") + } ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { DataType::Int8 => get_dict_value::(array, index)?, @@ -2560,22 +2622,31 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), - ScalarValue::Struct(vals, fields) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), + ScalarValue::Struct(vals, fields) => { + vals.as_ref() + .map(|vals| { + vals.iter() + .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .sum::() + + (std::mem::size_of::() * vals.capacity()) + }) + .unwrap_or_default() + // `fields` is boxed, so it is NOT already included in `self` + + std::mem::size_of_val(fields) + + (std::mem::size_of::() * fields.len()) + + fields.iter().map(|field| field.size() - std::mem::size_of_val(field)).sum::() + } + ScalarValue::Union(vals, fields, _mode) => { vals.as_ref() - .map(|vals| { - vals.iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) - .sum::() - + (std::mem::size_of::() * vals.capacity()) - }) + .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) .unwrap_or_default() // `fields` is boxed, so it is NOT already included in `self` + std::mem::size_of_val(fields) + (std::mem::size_of::() * fields.len()) - + fields.iter().map(|field| field.size() - std::mem::size_of_val(field)).sum::() + + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` @@ -2873,6 +2944,9 @@ impl TryFrom<&DataType> for ScalarValue { 1, )), DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), + DataType::Union(fields, mode) => { + ScalarValue::Union(None, fields.clone(), *mode) + } DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( @@ -2971,6 +3045,10 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "{}:{}", id, val)?, + None => write!(f, "NULL")?, + }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; @@ -3069,6 +3147,10 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Struct(NULL)"), } } + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "Union {}:{}", id, val), + None => write!(f, "Union(NULL)"), + }, ScalarValue::Dictionary(k, v) 
=> write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 56a1b4e17821..785ebb736409 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -408,7 +408,9 @@ mod tests { use crate::test::exec::StatisticsExec; use crate::ExecutionPlan; + use crate::empty::EmptyExec; use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{UnionFields, UnionMode}; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_expr::Operator; @@ -1057,4 +1059,37 @@ mod tests { assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); Ok(()) } + + #[test] + fn test_equivalence_properties_union_type() -> Result<()> { + let union_type = DataType::Union( + UnionFields::new( + vec![0, 1], + vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), + ], + ), + UnionMode::Sparse, + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", union_type, true), + ])); + + let exec = FilterExec::try_new( + binary( + binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), &schema)?, + Operator::And, + binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?, + &schema, + )?, + Arc::new(EmptyExec::new(schema.clone())), + )?; + + exec.statistics().unwrap(); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd8053c817e7..0af599128e66 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -961,6 +961,20 @@ message StructValue { repeated Field fields = 3; } +message UnionField { + int32 field_id = 1; + Field field = 2; +} + +message UnionValue { + // Note that a null union value must have one or more fields, so we + // encode a null UnionValue as one with value_id == 128 + int32 value_id = 1; + ScalarValue value = 2; + repeated UnionField fields = 3; + UnionMode mode = 4; +} + message ScalarFixedSizeBinary{ bytes values = 1; int32 length = 2; @@ -1015,6 +1029,7 @@ message ScalarValue{ IntervalMonthDayNanoValue interval_month_day_nano = 31; StructValue struct_value = 32; ScalarFixedSizeBinary fixed_size_binary_value = 34; + UnionValue union_value = 42; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 88310be0318a..019b57471b10 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22297,6 +22297,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::FixedSizeBinaryValue(v) => { struct_ser.serialize_field("fixedSizeBinaryValue", v)?; } + scalar_value::Value::UnionValue(v) => { + struct_ser.serialize_field("unionValue", v)?; + } } } struct_ser.end() @@ -22381,6 +22384,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "structValue", "fixed_size_binary_value", "fixedSizeBinaryValue", + "union_value", + "unionValue", ]; #[allow(clippy::enum_variant_names)] @@ -22421,6 +22426,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { IntervalMonthDayNano, StructValue, FixedSizeBinaryValue, + UnionValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22478,6 +22484,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), "structValue" | "struct_value" => 
Ok(GeneratedField::StructValue), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), + "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22727,6 +22734,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) +; + } + GeneratedField::UnionValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("unionValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) ; } } @@ -25150,6 +25164,117 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { deserializer.deserialize_struct("datafusion.UnionExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_id != 0 { + len += 1; + } + if self.field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionField", len)?; + if self.field_id != 0 { + struct_ser.serialize_field("fieldId", &self.field_id)?; + } + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_id", + "fieldId", + "field", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldId, + Field, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldId" | "field_id" => Ok(GeneratedField::FieldId), + "field" => Ok(GeneratedField::Field), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_id__ = None; + let mut field__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::FieldId => { + if field_id__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldId")); + } + field_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); + } + field__ = map_.next_value()?; + } + } + } + Ok(UnionField { + field_id: field_id__.unwrap_or_default(), + field: field__, + }) + } + } + deserializer.deserialize_struct("datafusion.UnionField", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UnionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -25312,6 +25437,153 @@ impl<'de> serde::Deserialize<'de> for UnionNode { deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value_id != 0 { + len += 1; + } + if self.value.is_some() { + len += 1; + } + if !self.fields.is_empty() { + len += 1; + } + if self.mode != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionValue", len)?; + if self.value_id != 0 { + struct_ser.serialize_field("valueId", &self.value_id)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + if !self.fields.is_empty() { + struct_ser.serialize_field("fields", &self.fields)?; + } + if self.mode != 0 { + let v = UnionMode::try_from(self.mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; + struct_ser.serialize_field("mode", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value_id", + "valueId", + "value", + "fields", + "mode", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ValueId, + Value, + Fields, + Mode, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "valueId" | "value_id" => Ok(GeneratedField::ValueId), + "value" => Ok(GeneratedField::Value), + "fields" => Ok(GeneratedField::Fields), + "mode" => Ok(GeneratedField::Mode), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value_id__ = None; + let mut value__ = None; + let mut fields__ = None; + let mut mode__ = None; + while let Some(k) = 
map_.next_key()? { + match k { + GeneratedField::ValueId => { + if value_id__.is_some() { + return Err(serde::de::Error::duplicate_field("valueId")); + } + value_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + GeneratedField::Fields => { + if fields__.is_some() { + return Err(serde::de::Error::duplicate_field("fields")); + } + fields__ = Some(map_.next_value()?); + } + GeneratedField::Mode => { + if mode__.is_some() { + return Err(serde::de::Error::duplicate_field("mode")); + } + mode__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(UnionValue { + value_id: value_id__.unwrap_or_default(), + value: value__, + fields: fields__.unwrap_or_default(), + mode: mode__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.UnionValue", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UniqueConstraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 3dfd3938615f..1e5cd69d5449 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1198,6 +1198,28 @@ pub struct StructValue { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionField { + #[prost(int32, tag = "1")] + pub field_id: i32, + #[prost(message, optional, tag = "2")] + pub field: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionValue { + /// Note that a null union value must have one or more fields, so we + /// encode a null UnionValue as one with value_id == 128 + #[prost(int32, tag = "1")] + pub value_id: i32, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub fields: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "4")] + pub mode: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] pub values: ::prost::alloc::vec::Vec, @@ -1209,7 +1231,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34, 42" )] pub value: ::core::option::Option, } @@ -1293,6 +1315,8 @@ pub mod scalar_value { StructValue(super::StructValue), #[prost(message, tag = "34")] FixedSizeBinaryValue(super::ScalarFixedSizeBinary), + #[prost(message, tag = "42")] + UnionValue(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 193e0947d6d9..5a2bf8d6ecdb 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -846,6 +846,41 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Self::Struct(values, fields) } + Value::UnionValue(val) => { + let 
mode = match val.mode { + 0 => UnionMode::Sparse, + 1 => UnionMode::Dense, + id => Err(Error::unknown("UnionMode", id))?, + }; + let ids = val + .fields + .iter() + .map(|f| f.field_id as i8) + .collect::>(); + let fields = val + .fields + .iter() + .map(|f| f.field.clone()) + .collect::>>(); + let fields = fields.ok_or_else(|| Error::required("UnionField"))?; + let fields = fields + .iter() + .map(Field::try_from) + .collect::, _>>()?; + let fields = UnionFields::new(ids, fields); + let v_id = val.value_id as i8; + let val = match &val.value { + None => None, + Some(val) => { + let val: ScalarValue = val + .as_ref() + .try_into() + .map_err(|_| Error::General("Invalid Scalar".to_string()))?; + Some((v_id, Box::new(val))) + } + }; + Self::Union(val, fields, mode) + } Value::FixedSizeBinaryValue(v) => { Self::FixedSizeBinary(v.length, Some(v.clone().values)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2997d147424d..3bd2ccf9082f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -30,6 +30,7 @@ use crate::protobuf::{ }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, + UnionField, UnionValue, }; use arrow::{ datatypes::{ @@ -1446,6 +1447,33 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } + ScalarValue::Union(val, df_fields, mode) => { + let mut fields = Vec::::with_capacity(df_fields.len()); + for (id, field) in df_fields.iter() { + let field_id = id as i32; + let field = Some(field.as_ref().try_into()?); + let field = UnionField { field_id, field }; + fields.push(field); + } + let mode = match mode { + UnionMode::Sparse => 0, + UnionMode::Dense => 1, + }; + let value = match val { + None => None, + Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)), + }; + let val = UnionValue { + value_id: val.as_ref().map(|(id, _v)| *id as i32).unwrap_or(0), + value, + fields, + mode, + }; + let val = Value::UnionValue(Box::new(val)); + let val = protobuf::ScalarValue { value: Some(val) }; + Ok(val) + } + ScalarValue::Dictionary(index_type, val) => { let value: protobuf::ScalarValue = val.as_ref().try_into()?; Ok(protobuf::ScalarValue { From b2225c4326470497fe1b85148745cf5a8815cbae Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Thu, 4 Apr 2024 18:33:37 +0200 Subject: [PATCH 09/21] Proof of concept for GroupsAccumulator for ArrayAgg --- .../physical-expr/src/aggregate/array_agg.rs | 195 +++++++++++++++++- .../groups_accumulator/accumulate.rs | 77 ++++++- 2 files changed, 265 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 91d5c867d312..68d359018557 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -19,17 +19,20 @@ use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, 
Result}; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; +use arrow_array::cast::AsArray; +use arrow_array::types::{Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; +use crate::aggregate::groups_accumulator::accumulate::{accumulate_array, accumulate_array_elements, NullState}; /// ARRAY_AGG aggregate expression #[derive(Debug)] @@ -96,6 +99,29 @@ impl AggregateExpr for ArrayAgg { fn name(&self) -> &str { &self.name } + + fn groups_accumulator_supported(&self) -> bool { + self.input_data_type.is_primitive() + } + + fn create_groups_accumulator(&self) -> Result> { + match self.input_data_type { + DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::UInt16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::UInt32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::UInt64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Float32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Float64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + _ => Err(DataFusionError::Internal(format!( + "ArrayAggGroupsAccumulator not supported for data type {:?}", + self.input_data_type + ))) + } + } } impl PartialEq for ArrayAgg { @@ -187,11 +213,137 @@ impl Accumulator for ArrayAggAccumulator { } } +struct ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send, +{ + values: Vec::Native>>>>, + null_state: NullState, +} + +impl ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send, +{ + pub fn new() -> Self { + Self { + values: vec![], + null_state: NullState::new(), + } + } +} + +impl GroupsAccumulator for ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send + Sync, +{ + + // TODO: + // 1. Implement support for null state + // 2. Implement support for low level ListArray creation api with offsets and nulls + // 3. Implement support for variable size types such as Utf8 + // 4. Implement support for accumulating Lists of any level of nesting + // 5. 
Use this group accumulator in array_agg_distinct.rs + + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + + for _ in self.values.len()..total_num_groups { + self.values.push(None); + } + + accumulate_array_elements( + group_indices, + values, + opt_filter, + |group_index, new_value| { + if let Some(array) = &mut self.values[group_index] { + array.push(Some(new_value)); + } else { + self.values[group_index] = Some(vec![Some(new_value)]); + } + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to merge_batch"); + let values = values[0].as_list(); + + for _ in self.values.len()..total_num_groups { + self.values.push(None); + } + + accumulate_array( + group_indices, + values, + opt_filter, + |group_index, new_value: &arrow_array::PrimitiveArray| { + if let Some(value) = &mut self.values[group_index] { + new_value.iter().for_each(|v| { + value.push(v); + }); + } else { + self.values[group_index] = Some(new_value.iter().collect()); + } + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let array = emit_to.take_needed(&mut self.values); + // let nulls = self.null_state.build(emit_to); + + // assert_eq!(array.len(), nulls.len()); + + Ok(Arc::new(ListArray::from_iter_primitive::(array))) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + + // TODO: do we need null state? + // let nulls = self.null_state.build(emit_to); + // let nulls = Some(nulls); + + let values = emit_to.take_needed(&mut self.values); + let values = ListArray::from_iter_primitive::(values); + + Ok(vec![Arc::new(values) as ArrayRef]) + } + + fn size(&self) -> usize { + self.values.capacity() + + self.values + .iter() + .map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity()) + .sum::() * std::mem::size_of::() + + self.null_state.size() + } +} + #[cfg(test)] mod tests { use super::*; use crate::expressions::col; - use crate::expressions::tests::aggregate; + use crate::expressions::tests::{aggregate, aggregate_new}; + use crate::{generic_test_op, generic_test_op_new}; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; @@ -226,6 +378,34 @@ mod tests { }}; } + macro_rules! 
test_op_new { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + generic_test_op_new!( + $ARRAY, + $DATATYPE, + $OP, + $EXPECTED, + $EXPECTED.data_type().clone() + ) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate_new(&batch, agg)?; + assert_eq!($EXPECTED, &actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + #[test] fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); @@ -237,9 +417,12 @@ mod tests { Some(4), Some(5), ])]); - let list = ScalarValue::List(Arc::new(list)); + let expected = ScalarValue::List(Arc::new(list.clone())); + + test_op!(a.clone(), DataType::Int32, ArrayAgg, expected, DataType::Int32); - test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) + let expected: ArrayRef = Arc::new(list); + test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32) } #[test] diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 596265a737da..3f2425ca6dc0 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -20,7 +20,8 @@ //! [`GroupsAccumulator`]: crate::GroupsAccumulator use arrow::datatypes::ArrowPrimitiveType; -use arrow_array::{Array, BooleanArray, PrimitiveArray}; +use arrow_array::{Array, BooleanArray, ListArray, PrimitiveArray}; +use arrow_array::cast::AsArray; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use crate::EmitTo; @@ -437,6 +438,80 @@ pub fn accumulate_indices( } } +pub fn accumulate_array_elements( + group_indices: &[usize], + values: &PrimitiveArray, + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + F: FnMut(usize, ::Native) + Send, + T: ArrowPrimitiveType + Send +{ + assert_eq!(values.len(), group_indices.len()); + + match opt_filter { + // no filter, + None => { + let iter = values.iter(); + group_indices.iter().zip(iter).for_each( + |(&group_index, new_value)| { + value_fn(group_index, new_value.unwrap()) + }, + ) + } + // a filter + Some(filter) => { + assert_eq!(filter.len(), group_indices.len()); + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + value_fn(group_index, new_value.unwrap()); + } + }) + } + } +} + +pub fn accumulate_array( + group_indices: &[usize], + values: &ListArray, + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + F: FnMut(usize, &PrimitiveArray) + Send, + T: ArrowPrimitiveType + Send +{ + assert_eq!(values.len(), group_indices.len()); + + match opt_filter { + // no filter, + None => { + let iter = values.iter(); + group_indices.iter().zip(iter).for_each( + |(&group_index, new_value)| { + value_fn(group_index, new_value.unwrap().as_primitive::()) + }, + ) + } + // a filter + Some(filter) => { + assert_eq!(filter.len(), group_indices.len()); + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + value_fn(group_index, 
new_value.unwrap().as_primitive::()); + } + }) + } + } +} + /// Ensures that `builder` contains a `BooleanBufferBuilder with at /// least `total_num_groups`. /// From b637f84b806071fd0d419f2f2ab91bca99563c27 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Thu, 4 Apr 2024 18:38:08 +0200 Subject: [PATCH 10/21] fmt --- .../physical-expr/src/aggregate/array_agg.rs | 79 +++++++++++++------ .../groups_accumulator/accumulate.rs | 24 +++--- 2 files changed, 67 insertions(+), 36 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 68d359018557..da5ab4e14e01 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -17,22 +17,27 @@ //! Defines physical expressions that can evaluated at runtime during query execution +use crate::aggregate::groups_accumulator::accumulate::{ + accumulate_array, accumulate_array_elements, NullState, +}; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, +}; use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; -use datafusion_common::{DataFusionError, Result}; use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; -use arrow_array::cast::AsArray; -use arrow_array::types::{Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; -use crate::aggregate::groups_accumulator::accumulate::{accumulate_array, accumulate_array_elements, NullState}; /// ARRAY_AGG aggregate expression #[derive(Debug)] @@ -107,19 +112,37 @@ impl AggregateExpr for ArrayAgg { fn create_groups_accumulator(&self) -> Result> { match self.input_data_type { DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Float32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Float64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int16 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Int32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Int64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt8 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt16 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Float32 => { 
+ Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Float64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } _ => Err(DataFusionError::Internal(format!( "ArrayAggGroupsAccumulator not supported for data type {:?}", self.input_data_type - ))) + ))), } } } @@ -237,7 +260,6 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator where T: ArrowPrimitiveType + Send + Sync, { - // TODO: // 1. Implement support for null state // 2. Implement support for low level ListArray creation api with offsets and nulls @@ -250,7 +272,7 @@ where values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, - total_num_groups: usize + total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); let values = values[0].as_primitive::(); @@ -280,7 +302,7 @@ where values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, - total_num_groups: usize + total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to merge_batch"); let values = values[0].as_list(); @@ -317,7 +339,6 @@ where } fn state(&mut self, emit_to: EmitTo) -> Result> { - // TODO: do we need null state? // let nulls = self.null_state.build(emit_to); // let nulls = Some(nulls); @@ -329,12 +350,14 @@ where } fn size(&self) -> usize { - self.values.capacity() + - self.values - .iter() - .map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity()) - .sum::() * std::mem::size_of::() + - self.null_state.size() + self.values.capacity() + + self + .values + .iter() + .map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity()) + .sum::() + * std::mem::size_of::() + + self.null_state.size() } } @@ -419,7 +442,13 @@ mod tests { ])]); let expected = ScalarValue::List(Arc::new(list.clone())); - test_op!(a.clone(), DataType::Int32, ArrayAgg, expected, DataType::Int32); + test_op!( + a.clone(), + DataType::Int32, + ArrayAgg, + expected, + DataType::Int32 + ); let expected: ArrayRef = Arc::new(list); test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32) diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 3f2425ca6dc0..01f13d38adc0 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -20,8 +20,8 @@ //! 
[`GroupsAccumulator`]: crate::GroupsAccumulator use arrow::datatypes::ArrowPrimitiveType; -use arrow_array::{Array, BooleanArray, ListArray, PrimitiveArray}; use arrow_array::cast::AsArray; +use arrow_array::{Array, BooleanArray, ListArray, PrimitiveArray}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use crate::EmitTo; @@ -445,7 +445,7 @@ pub fn accumulate_array_elements( mut value_fn: F, ) where F: FnMut(usize, ::Native) + Send, - T: ArrowPrimitiveType + Send + T: ArrowPrimitiveType + Send, { assert_eq!(values.len(), group_indices.len()); @@ -453,11 +453,12 @@ pub fn accumulate_array_elements( // no filter, None => { let iter = values.iter(); - group_indices.iter().zip(iter).for_each( - |(&group_index, new_value)| { + group_indices + .iter() + .zip(iter) + .for_each(|(&group_index, new_value)| { value_fn(group_index, new_value.unwrap()) - }, - ) + }) } // a filter Some(filter) => { @@ -482,7 +483,7 @@ pub fn accumulate_array( mut value_fn: F, ) where F: FnMut(usize, &PrimitiveArray) + Send, - T: ArrowPrimitiveType + Send + T: ArrowPrimitiveType + Send, { assert_eq!(values.len(), group_indices.len()); @@ -490,11 +491,12 @@ pub fn accumulate_array( // no filter, None => { let iter = values.iter(); - group_indices.iter().zip(iter).for_each( - |(&group_index, new_value)| { + group_indices + .iter() + .zip(iter) + .for_each(|(&group_index, new_value)| { value_fn(group_index, new_value.unwrap().as_primitive::()) - }, - ) + }) } // a filter Some(filter) => { From 0217847cb2504d840b0d20b7b65feb9039b9a9af Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Thu, 4 Apr 2024 18:40:21 +0200 Subject: [PATCH 11/21] Fix size --- datafusion/physical-expr/src/aggregate/array_agg.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index da5ab4e14e01..a652e69c7ea5 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -331,6 +331,7 @@ where fn evaluate(&mut self, emit_to: EmitTo) -> Result { let array = emit_to.take_needed(&mut self.values); + // TODO: do we need null state? 
// let nulls = self.null_state.build(emit_to); // assert_eq!(array.len(), nulls.len()); @@ -356,7 +357,7 @@ where .iter() .map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity()) .sum::() - * std::mem::size_of::() + * std::mem::size_of::<::Native>() + self.null_state.size() } } @@ -366,7 +367,6 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::{aggregate, aggregate_new}; - use crate::{generic_test_op, generic_test_op_new}; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; From ca793b16b7219f3c771516bc9137d0c6006c7d11 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Mon, 8 Apr 2024 18:56:52 +0300 Subject: [PATCH 12/21] Use null state for primitive types --- datafusion/common/src/hash_utils.rs | 1 - datafusion/common/src/scalar.rs | 12 +- .../physical-expr/src/aggregate/array_agg.rs | 94 +++++----- .../groups_accumulator/accumulate.rs | 160 +++++++++--------- 4 files changed, 137 insertions(+), 130 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 6a77cd41c201..8dcc00ca1c29 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -267,7 +267,6 @@ where Ok(()) } - /// Test version of `create_hashes` that produces the same value for /// all hashes (to test collisions) /// diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 4359280caa50..21c096bad4a6 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2622,11 +2622,11 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), - ScalarValue::Struct(vals, fields) => { - vals.as_ref() + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), + ScalarValue::Struct(vals, fields) => { + vals.as_ref() .map(|vals| { vals.iter() .map(|sv| sv.size() - std::mem::size_of_val(sv)) @@ -2638,7 +2638,7 @@ impl ScalarValue { + std::mem::size_of_val(fields) + (std::mem::size_of::() * fields.len()) + fields.iter().map(|field| field.size() - std::mem::size_of_val(field)).sum::() - } + } ScalarValue::Union(vals, fields, _mode) => { vals.as_ref() .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index a652e69c7ea5..7df1a6cf21c3 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -17,20 +17,19 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution -use crate::aggregate::groups_accumulator::accumulate::{ - accumulate_array, accumulate_array_elements, NullState, -}; +use crate::aggregate::groups_accumulator::accumulate::NullState; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +use arrow_array::builder::{ListBuilder, PrimitiveBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::{ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray}; +use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray, PrimitiveArray}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; use datafusion_common::ScalarValue; @@ -240,7 +239,7 @@ struct ArrayAggGroupsAccumulator where T: ArrowPrimitiveType + Send, { - values: Vec::Native>>>>, + values: Vec::Native>>>, null_state: NullState, } @@ -256,6 +255,33 @@ where } } +impl ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send, +{ + fn build_list(&mut self, emit_to: EmitTo) -> Result { + let array = emit_to.take_needed(&mut self.values); + let nulls = self.null_state.build(emit_to); + + assert_eq!(array.len(), nulls.len()); + + let mut builder = + ListBuilder::with_capacity(PrimitiveBuilder::::new(), nulls.len()); + for (is_valid, arr) in nulls.iter().zip(array.iter()) { + if is_valid { + for value in arr.iter() { + builder.values().append_option(*value); + } + builder.append(true); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) + } +} + impl GroupsAccumulator for ArrayAggGroupsAccumulator where T: ArrowPrimitiveType + Send + Sync, @@ -277,20 +303,15 @@ where assert_eq!(values.len(), 1, "single argument to update_batch"); let values = values[0].as_primitive::(); - for _ in self.values.len()..total_num_groups { - self.values.push(None); - } + self.values.resize(total_num_groups, vec![]); - accumulate_array_elements( + self.null_state.accumulate( group_indices, values, opt_filter, + total_num_groups, |group_index, new_value| { - if let Some(array) = &mut self.values[group_index] { - array.push(Some(new_value)); - } else { - self.values[group_index] = Some(vec![Some(new_value)]); - } + self.values[group_index].push(Some(new_value)); }, ); @@ -307,22 +328,20 @@ where assert_eq!(values.len(), 1, "single argument to merge_batch"); let values = values[0].as_list(); - for _ in self.values.len()..total_num_groups { - self.values.push(None); - } + self.values.resize(total_num_groups, vec![]); - accumulate_array( + self.null_state.accumulate_array( group_indices, values, opt_filter, - |group_index, new_value: &arrow_array::PrimitiveArray| { - if let Some(value) = &mut self.values[group_index] { - new_value.iter().for_each(|v| { - value.push(v); - }); - } else { - self.values[group_index] = Some(new_value.iter().collect()); - } + total_num_groups, + |group_index, new_value: &PrimitiveArray| { + self.values[group_index].append( + new_value + .into_iter() + .collect::>>() + .as_mut(), + ); }, ); @@ -330,33 +349,16 @@ where } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let array = emit_to.take_needed(&mut self.values); - // TODO: do we need null state? 
- // let nulls = self.null_state.build(emit_to); - - // assert_eq!(array.len(), nulls.len()); - - Ok(Arc::new(ListArray::from_iter_primitive::(array))) + Ok(self.build_list(emit_to)?) } fn state(&mut self, emit_to: EmitTo) -> Result> { - // TODO: do we need null state? - // let nulls = self.null_state.build(emit_to); - // let nulls = Some(nulls); - - let values = emit_to.take_needed(&mut self.values); - let values = ListArray::from_iter_primitive::(values); - - Ok(vec![Arc::new(values) as ArrayRef]) + Ok(vec![self.build_list(emit_to)?]) } fn size(&self) -> usize { self.values.capacity() - + self - .values - .iter() - .map(|arr| arr.as_ref().unwrap_or(&Vec::new()).capacity()) - .sum::() + + self.values.iter().map(|arr| arr.capacity()).sum::() * std::mem::size_of::<::Native>() + self.null_state.size() } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 01f13d38adc0..2179033f6ee2 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -21,7 +21,7 @@ use arrow::datatypes::ArrowPrimitiveType; use arrow_array::cast::AsArray; -use arrow_array::{Array, BooleanArray, ListArray, PrimitiveArray}; +use arrow_array::{Array, BooleanArray, GenericListArray, ListArray, PrimitiveArray}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use crate::EmitTo; @@ -320,6 +320,88 @@ impl NullState { } } + /// Invokes `value_fn(group_index, value)` for each non null, non + /// filtered value in `values`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs, for + /// [`ListArray`]s. + /// + /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for + /// more details on other arguments. 
+ pub fn accumulate_array( + &mut self, + group_indices: &[usize], + values: &ListArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, &PrimitiveArray) + Send, + { + let data: &GenericListArray = values.values().as_list(); + assert_eq!(data.len(), group_indices.len()); + + // ensure the seen_values is big enough (start everything at + // "not seen" valid) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(data.iter()); + for (&group_index, new_value) in iter { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap().as_primitive()); + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + group_indices + .iter() + .zip(data.iter()) + .zip(nulls.iter()) + .for_each(|((&group_index, new_value), is_valid)| { + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap().as_primitive()); + } + }) + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + group_indices + .iter() + .zip(data.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap().as_primitive()); + } + }); + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.as_primitive()); + } + } + }); + } + } + } + /// Creates the a [`NullBuffer`] representing which group_indices /// should have null values (because they never saw any values) /// for the `emit_to` rows. 
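// Illustrative sketch (not part of this patch): the `accumulate_array` helper added
// above is intended to be driven from a grouped accumulator's `merge_batch`, where the
// partial per-group state arrives as one list per input row. The free function below
// shows that call pattern using the imports already present in this file; the concrete
// element type (Int64) and the `groups` buffer are assumptions made for the example only.
fn merge_lists_example(
    null_state: &mut NullState,
    groups: &mut Vec<Vec<Option<i64>>>,
    list_column: &ListArray,
    group_indices: &[usize],
    total_num_groups: usize,
) {
    // One Vec of accumulated values per group.
    groups.resize(total_num_groups, vec![]);
    null_state.accumulate_array(
        group_indices,
        list_column,
        None, // no filter in this example
        total_num_groups,
        |group_index, new_values: &PrimitiveArray<arrow_array::types::Int64Type>| {
            // Append every element of the merged list to that group's buffer.
            groups[group_index].extend(new_values.iter());
        },
    );
}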
@@ -438,82 +520,6 @@ pub fn accumulate_indices( } } -pub fn accumulate_array_elements( - group_indices: &[usize], - values: &PrimitiveArray, - opt_filter: Option<&BooleanArray>, - mut value_fn: F, -) where - F: FnMut(usize, ::Native) + Send, - T: ArrowPrimitiveType + Send, -{ - assert_eq!(values.len(), group_indices.len()); - - match opt_filter { - // no filter, - None => { - let iter = values.iter(); - group_indices - .iter() - .zip(iter) - .for_each(|(&group_index, new_value)| { - value_fn(group_index, new_value.unwrap()) - }) - } - // a filter - Some(filter) => { - assert_eq!(filter.len(), group_indices.len()); - group_indices - .iter() - .zip(values.iter()) - .zip(filter.iter()) - .for_each(|((&group_index, new_value), filter_value)| { - if let Some(true) = filter_value { - value_fn(group_index, new_value.unwrap()); - } - }) - } - } -} - -pub fn accumulate_array( - group_indices: &[usize], - values: &ListArray, - opt_filter: Option<&BooleanArray>, - mut value_fn: F, -) where - F: FnMut(usize, &PrimitiveArray) + Send, - T: ArrowPrimitiveType + Send, -{ - assert_eq!(values.len(), group_indices.len()); - - match opt_filter { - // no filter, - None => { - let iter = values.iter(); - group_indices - .iter() - .zip(iter) - .for_each(|(&group_index, new_value)| { - value_fn(group_index, new_value.unwrap().as_primitive::()) - }) - } - // a filter - Some(filter) => { - assert_eq!(filter.len(), group_indices.len()); - group_indices - .iter() - .zip(values.iter()) - .zip(filter.iter()) - .for_each(|((&group_index, new_value), filter_value)| { - if let Some(true) = filter_value { - value_fn(group_index, new_value.unwrap().as_primitive::()); - } - }) - } - } -} - /// Ensures that `builder` contains a `BooleanBufferBuilder with at /// least `total_num_groups`. 
/// From 5e25fa76bac1f0bd76cda0f826da68960c4b45d9 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Tue, 9 Apr 2024 15:25:45 +0300 Subject: [PATCH 13/21] Small fic --- datafusion/physical-expr/src/aggregate/array_agg.rs | 8 ++++---- .../src/aggregate/groups_accumulator/accumulate.rs | 9 ++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 7df1a6cf21c3..93470667cc99 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -295,19 +295,19 @@ where fn update_batch( &mut self, - values: &[ArrayRef], + new_values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = values[0].as_primitive::(); + assert_eq!(new_values.len(), 1, "single argument to update_batch"); + let new_values = new_values[0].as_primitive::(); self.values.resize(total_num_groups, vec![]); self.null_state.accumulate( group_indices, - values, + new_values, opt_filter, total_num_groups, |group_index, new_value| { diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 2179033f6ee2..3afa759d4da1 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -338,8 +338,7 @@ impl NullState { T: ArrowPrimitiveType + Send, F: FnMut(usize, &PrimitiveArray) + Send, { - let data: &GenericListArray = values.values().as_list(); - assert_eq!(data.len(), group_indices.len()); + assert_eq!(values.len(), group_indices.len()); // ensure the seen_values is big enough (start everything at // "not seen" valid) @@ -349,7 +348,7 @@ impl NullState { match (values.null_count() > 0, opt_filter) { // no nulls, no filter, (false, None) => { - let iter = group_indices.iter().zip(data.iter()); + let iter = group_indices.iter().zip(values.iter()); for (&group_index, new_value) in iter { seen_values.set_bit(group_index, true); value_fn(group_index, new_value.unwrap().as_primitive()); @@ -360,7 +359,7 @@ impl NullState { let nulls = values.nulls().unwrap(); group_indices .iter() - .zip(data.iter()) + .zip(values.iter()) .zip(nulls.iter()) .for_each(|((&group_index, new_value), is_valid)| { if is_valid { @@ -374,7 +373,7 @@ impl NullState { assert_eq!(filter.len(), group_indices.len()); group_indices .iter() - .zip(data.iter()) + .zip(values.iter()) .zip(filter.iter()) .for_each(|((&group_index, new_value), filter_value)| { if let Some(true) = filter_value { From 3885af2375cdc28cc4f1f24e76d23e0c0d3541e6 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Wed, 10 Apr 2024 17:41:10 +0300 Subject: [PATCH 14/21] String support in array_agg --- .../physical-expr/src/aggregate/array_agg.rs | 195 ++++++++++++++---- .../groups_accumulator/accumulate.rs | 97 ++++++++- 2 files changed, 247 insertions(+), 45 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 93470667cc99..5a228d75d8b3 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -23,13 +23,13 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, EmitTo, GroupsAccumulator, 
PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::builder::{ListBuilder, PrimitiveBuilder}; +use arrow_array::builder::{ListBuilder, PrimitiveBuilder, StringBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::{ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray, PrimitiveArray}; +use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray, StringArray}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; use datafusion_common::ScalarValue; @@ -111,33 +111,16 @@ impl AggregateExpr for ArrayAgg { fn create_groups_accumulator(&self) -> Result> { match self.input_data_type { DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int16 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } - DataType::Int32 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } - DataType::Int64 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } - DataType::UInt8 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } - DataType::UInt16 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } - DataType::UInt32 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } - DataType::UInt64 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } - DataType::Float32 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } - DataType::Float64 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) - } + DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::UInt16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::UInt32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::UInt64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Float32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Float64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Utf8 => Ok(Box::new(StringArrayAggGroupsAccumulator::new())), _ => Err(DataFusionError::Internal(format!( "ArrayAggGroupsAccumulator not supported for data type {:?}", self.input_data_type @@ -335,7 +318,8 @@ where values, opt_filter, total_num_groups, - |group_index, new_value: &PrimitiveArray| { + |group_index, new_value: ArrayRef| { + let new_value = new_value.as_primitive::(); self.values[group_index].append( new_value .into_iter() @@ -364,6 +348,124 @@ where } } +struct StringArrayAggGroupsAccumulator { + values: Vec>>, + null_state: NullState, +} + +impl StringArrayAggGroupsAccumulator { + pub fn new() -> Self { + Self { + values: vec![], + null_state: NullState::new(), + } + } +} + +impl StringArrayAggGroupsAccumulator { + fn build_list(&mut self, emit_to: EmitTo) -> Result { + let array = emit_to.take_needed(&mut self.values); + let nulls = self.null_state.build(emit_to); + + assert_eq!(array.len(), nulls.len()); + + let mut builder = + ListBuilder::with_capacity(StringBuilder::new(), nulls.len()); + for (is_valid, arr) in nulls.iter().zip(array.iter()) { + if is_valid { + for value in arr.iter() { + builder.values().append_option(value.as_deref()); + } + builder.append(true); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) 
+ } +} + +impl GroupsAccumulator for StringArrayAggGroupsAccumulator { + fn update_batch( + &mut self, + new_values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(new_values.len(), 1, "single argument to update_batch"); + let new_values = new_values[0].as_string(); + + self.values.resize(total_num_groups, vec![]); + + self.null_state.accumulate_string( + group_indices, + new_values, + opt_filter, + total_num_groups, + |group_index, new_value| { + self.values[group_index].push(Some(new_value.to_string())); + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to merge_batch"); + let values = values[0].as_list(); + + self.values.resize(total_num_groups, Vec::>::new()); + + self.null_state.accumulate_array( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value: ArrayRef| { + let new_value = new_value.as_string::(); + + self.values[group_index].append(new_value + .into_iter() + .map(|s| s.map(|s| s.to_string())) + .collect::>>() + .as_mut()); + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + Ok(self.build_list(emit_to)?) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + Ok(vec![self.build_list(emit_to)?]) + } + + fn size(&self) -> usize { + self.values.capacity() + + self.values.iter().map( + |arr| + arr.iter().map( + |e| + e.as_ref().map(|s| s.len()).unwrap_or(0) + ).sum::() + ).sum::() + + + self.null_state.size() + } +} + + #[cfg(test)] mod tests { use super::*; @@ -398,8 +500,6 @@ mod tests { let expected = ScalarValue::from($EXPECTED); assert_eq!(expected, actual); - - Ok(()) as Result<(), DataFusionError> }}; } @@ -426,8 +526,6 @@ mod tests { )); let actual = aggregate_new(&batch, agg)?; assert_eq!($EXPECTED, &actual); - - Ok(()) as Result<(), DataFusionError> }}; } @@ -453,7 +551,30 @@ mod tests { ); let expected: ArrayRef = Arc::new(list); - test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32) + test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32); + + Ok(()) + } + + #[test] + fn array_agg_str() -> Result<()> { + let a: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3", "4", "5"])); + + let mut list_builder = ListBuilder::with_capacity(StringBuilder::new(), 5); + list_builder.values().append_value("1"); + list_builder.values().append_value("2"); + list_builder.values().append_value("3"); + list_builder.values().append_value("4"); + list_builder.values().append_value("5"); + list_builder.append(true); + + let list = list_builder.finish(); + let expected = ScalarValue::List(Arc::new(list.clone())); + + let expected: ArrayRef = Arc::new(list); + test_op_new!(a, DataType::Utf8, ArrayAgg, &expected, DataType::Utf8); + + Ok(()) } #[test] @@ -519,6 +640,8 @@ mod tests { ArrayAgg, list, DataType::List(Arc::new(Field::new("item", DataType::Int32, true,))) - ) + ); + + Ok(()) } } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 3afa759d4da1..7e4c462eb5c7 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -20,8 +20,7 @@ //! 
[`GroupsAccumulator`]: crate::GroupsAccumulator use arrow::datatypes::ArrowPrimitiveType; -use arrow_array::cast::AsArray; -use arrow_array::{Array, BooleanArray, GenericListArray, ListArray, PrimitiveArray}; +use arrow_array::{Array, ArrayRef, BooleanArray, ListArray, PrimitiveArray, StringArray}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use crate::EmitTo; @@ -327,7 +326,7 @@ impl NullState { /// /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for /// more details on other arguments. - pub fn accumulate_array( + pub fn accumulate_array( &mut self, group_indices: &[usize], values: &ListArray, @@ -335,8 +334,88 @@ impl NullState { total_num_groups: usize, mut value_fn: F, ) where - T: ArrowPrimitiveType + Send, - F: FnMut(usize, &PrimitiveArray) + Send, + F: FnMut(usize, ArrayRef) + Send, + { + assert_eq!(values.len(), group_indices.len()); + + // ensure the seen_values is big enough (start everything at + // "not seen" valid) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(values.iter()); + for (&group_index, new_value) in iter { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + group_indices + .iter() + .zip(values.iter()) + .zip(nulls.iter()) + .for_each(|((&group_index, new_value), is_valid)| { + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + }) + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + }); + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + } + }); + } + } + } + + + /// Invokes `value_fn(group_index, value)` for each non-null, + /// non-filtered value in `values`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs, for + /// [`ListArray`]s. + /// + /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for + /// more details on other arguments. 
+ pub fn accumulate_string( + &mut self, + group_indices: &[usize], + values: &StringArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + F: FnMut(usize, &str) + Send, { assert_eq!(values.len(), group_indices.len()); @@ -351,7 +430,7 @@ impl NullState { let iter = group_indices.iter().zip(values.iter()); for (&group_index, new_value) in iter { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value.unwrap().as_primitive()); + value_fn(group_index, new_value.unwrap()); } } // nulls, no filter @@ -364,7 +443,7 @@ impl NullState { .for_each(|((&group_index, new_value), is_valid)| { if is_valid { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value.unwrap().as_primitive()); + value_fn(group_index, new_value.unwrap()); } }) } @@ -378,7 +457,7 @@ impl NullState { .for_each(|((&group_index, new_value), filter_value)| { if let Some(true) = filter_value { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value.unwrap().as_primitive()); + value_fn(group_index, new_value.unwrap()); } }); } @@ -393,7 +472,7 @@ impl NullState { if let Some(true) = filter_value { if let Some(new_value) = new_value { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value.as_primitive()); + value_fn(group_index, new_value); } } }); From c4c2d96627043c99e093a65addea3e8b843a8668 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Wed, 10 Apr 2024 17:54:24 +0300 Subject: [PATCH 15/21] fmt --- .../physical-expr/src/aggregate/array_agg.rs | 74 ++++++++++++------- .../groups_accumulator/accumulate.rs | 5 +- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 5a228d75d8b3..0c869badca9c 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -111,15 +111,33 @@ impl AggregateExpr for ArrayAgg { fn create_groups_accumulator(&self) -> Result> { match self.input_data_type { DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::UInt64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Float32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), - DataType::Float64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int16 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Int32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Int64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt8 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt16 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::UInt64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Float32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } + DataType::Float64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + } DataType::Utf8 => 
Ok(Box::new(StringArrayAggGroupsAccumulator::new())), _ => Err(DataFusionError::Internal(format!( "ArrayAggGroupsAccumulator not supported for data type {:?}", @@ -369,8 +387,7 @@ impl StringArrayAggGroupsAccumulator { assert_eq!(array.len(), nulls.len()); - let mut builder = - ListBuilder::with_capacity(StringBuilder::new(), nulls.len()); + let mut builder = ListBuilder::with_capacity(StringBuilder::new(), nulls.len()); for (is_valid, arr) in nulls.iter().zip(array.iter()) { if is_valid { for value in arr.iter() { @@ -422,7 +439,8 @@ impl GroupsAccumulator for StringArrayAggGroupsAccumulator { assert_eq!(values.len(), 1, "single argument to merge_batch"); let values = values[0].as_list(); - self.values.resize(total_num_groups, Vec::>::new()); + self.values + .resize(total_num_groups, Vec::>::new()); self.null_state.accumulate_array( group_indices, @@ -432,11 +450,13 @@ impl GroupsAccumulator for StringArrayAggGroupsAccumulator { |group_index, new_value: ArrayRef| { let new_value = new_value.as_string::(); - self.values[group_index].append(new_value - .into_iter() - .map(|s| s.map(|s| s.to_string())) - .collect::>>() - .as_mut()); + self.values[group_index].append( + new_value + .into_iter() + .map(|s| s.map(|s| s.to_string())) + .collect::>>() + .as_mut(), + ); }, ); @@ -452,20 +472,20 @@ impl GroupsAccumulator for StringArrayAggGroupsAccumulator { } fn size(&self) -> usize { - self.values.capacity() + - self.values.iter().map( - |arr| - arr.iter().map( - |e| - e.as_ref().map(|s| s.len()).unwrap_or(0) - ).sum::() - ).sum::() - + self.values.capacity() + + self + .values + .iter() + .map(|arr| { + arr.iter() + .map(|e| e.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum::() + }) + .sum::() + self.null_state.size() } } - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 7e4c462eb5c7..941588b82026 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -20,7 +20,9 @@ //! [`GroupsAccumulator`]: crate::GroupsAccumulator use arrow::datatypes::ArrowPrimitiveType; -use arrow_array::{Array, ArrayRef, BooleanArray, ListArray, PrimitiveArray, StringArray}; +use arrow_array::{ + Array, ArrayRef, BooleanArray, ListArray, PrimitiveArray, StringArray, +}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use crate::EmitTo; @@ -399,7 +401,6 @@ impl NullState { } } - /// Invokes `value_fn(group_index, value)` for each non-null, /// non-filtered value in `values`, while tracking which groups have /// seen null inputs and which groups have seen any inputs, for From 1913c6493ea83a1e01c8c70cf13a1afdee493680 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Thu, 11 Apr 2024 11:23:36 +0300 Subject: [PATCH 16/21] Remove comment --- datafusion/physical-expr/src/aggregate/array_agg.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 0c869badca9c..166ea68ba40e 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -287,12 +287,6 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator where T: ArrowPrimitiveType + Send + Sync, { - // TODO: - // 1. Implement support for null state - // 2. 
Implement support for low level ListArray creation api with offsets and nulls - // 3. Implement support for variable size types such as Utf8 - // 4. Implement support for accumulating Lists of any level of nesting - // 5. Use this group accumulator in array_agg_distinct.rs fn update_batch( &mut self, From 917f63284c0d67c84421d9b723b33f632f7744bc Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Fri, 12 Apr 2024 14:18:54 +0300 Subject: [PATCH 17/21] Not compiling generics problem --- datafusion/core/src/main.rs | 5 + .../physical-expr/src/aggregate/array_agg.rs | 169 +++++++++++++++--- 2 files changed, 151 insertions(+), 23 deletions(-) create mode 100644 datafusion/core/src/main.rs diff --git a/datafusion/core/src/main.rs b/datafusion/core/src/main.rs new file mode 100644 index 000000000000..146ba8b47a17 --- /dev/null +++ b/datafusion/core/src/main.rs @@ -0,0 +1,5 @@ +fn main() { + println!("Hello, world!"); + + +} \ No newline at end of file diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 166ea68ba40e..a03574142962 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -17,18 +17,18 @@ //! Defines physical expressions that can evaluated at runtime during query execution +#![feature(specialization)] + + use crate::aggregate::groups_accumulator::accumulate::NullState; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::builder::{ListBuilder, PrimitiveBuilder, StringBuilder}; +use arrow_array::builder::{GenericListBuilder, ListBuilder, PrimitiveBuilder, StringBuilder}; use arrow_array::cast::AsArray; -use arrow_array::types::{ - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, -}; +use arrow_array::types::{ArrowTimestampType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray, StringArray}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; @@ -37,6 +37,7 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; +use arrow_schema::TimeUnit; /// ARRAY_AGG aggregate expression #[derive(Debug)] @@ -105,38 +106,60 @@ impl AggregateExpr for ArrayAgg { } fn groups_accumulator_supported(&self) -> bool { - self.input_data_type.is_primitive() + self.input_data_type.is_primitive() || + match self.input_data_type { + DataType::Utf8 => true, + _ => false, + } } fn create_groups_accumulator(&self) -> Result> { match self.input_data_type { - DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new())), + DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))), DataType::Int16 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::Int32 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::Int64 => { - 
Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::UInt8 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::UInt16 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::UInt32 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::UInt64 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::Float32 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::Float64 => { - Ok(Box::new(ArrayAggGroupsAccumulator::::new())) + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Date32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Date64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Timestamp(TimeUnit::Second, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } DataType::Utf8 => Ok(Box::new(StringArrayAggGroupsAccumulator::new())), _ => Err(DataFusionError::Internal(format!( @@ -241,6 +264,7 @@ where T: ArrowPrimitiveType + Send, { values: Vec::Native>>>, + data_type: DataType, null_state: NullState, } @@ -248,26 +272,94 @@ impl ArrayAggGroupsAccumulator where T: ArrowPrimitiveType + Send, { - pub fn new() -> Self { + pub fn new(data_type: &DataType) -> Self { Self { values: vec![], + data_type: data_type.clone(), null_state: NullState::new(), } } } -impl ArrayAggGroupsAccumulator -where - T: ArrowPrimitiveType + Send, + +// trait Builder { +// fn new_builder(self, len: usize) -> ListBuilder; +// } +// +// impl Builder for ArrayAggGroupsAccumulator { +// fn new_builder(self, len: usize) -> ListBuilder { +// ListBuilder::with_capacity(PrimitiveBuilder::::new(), len) +// } +// } +// +// impl Builder for ArrayAggGroupsAccumulator { +// fn new_builder(self, len: usize) -> ListBuilder { +// match &self.data_type { +// DataType::Timestamp(TimeUnit::Nanosecond, tz) => +// ListBuilder::with_capacity( +// PrimitiveBuilder::::new() +// .with_timezone_opt(tz.clone()), len, +// ), +// DataType::Timestamp(TimeUnit::Microsecond, tz) => +// ListBuilder::with_capacity( +// PrimitiveBuilder::::new() +// .with_timezone_opt(tz.clone()), len, +// ), +// DataType::Timestamp(TimeUnit::Millisecond, tz) => +// ListBuilder::with_capacity( +// PrimitiveBuilder::::new() +// .with_timezone_opt(tz.clone()), len, +// ), +// DataType::Timestamp(TimeUnit::Second, tz) => +// ListBuilder::with_capacity( +// PrimitiveBuilder::::new() +// .with_timezone_opt(tz.clone()), len, +// ), +// _ => +// ListBuilder::with_capacity(PrimitiveBuilder::::new(), len), +// } +// } +// } + +impl ArrayAggGroupsAccumulator { fn build_list(&mut self, emit_to: EmitTo) -> Result { let array = emit_to.take_needed(&mut 
self.values); let nulls = self.null_state.build(emit_to); - assert_eq!(array.len(), nulls.len()); + let len = nulls.len(); + assert_eq!(array.len(), len); + + let mut builder: GenericListBuilder> = + match &self.data_type { + DataType::Timestamp(TimeUnit::Nanosecond, tz) => + ListBuilder::with_capacity( + (PrimitiveBuilder::::new() as PrimitiveBuilder) + .with_timezone_opt(tz.clone()) as PrimitiveBuilder, + len, + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => + ListBuilder::with_capacity( + (PrimitiveBuilder::::new() as PrimitiveBuilder) + .with_timezone_opt(tz.clone()) as PrimitiveBuilder, + len, + ), + DataType::Timestamp(TimeUnit::Millisecond, tz) => + ListBuilder::with_capacity( + (PrimitiveBuilder::::new() as PrimitiveBuilder) + .with_timezone_opt(tz.clone()) as PrimitiveBuilder, + len, + ), + DataType::Timestamp(TimeUnit::Second, tz) => + ListBuilder::with_capacity( + (PrimitiveBuilder::::new() as PrimitiveBuilder) + .with_timezone_opt(tz.clone()) as PrimitiveBuilder, + len, + ), + _ => + ListBuilder::with_capacity(PrimitiveBuilder::::new(), len), + }; - let mut builder = - ListBuilder::with_capacity(PrimitiveBuilder::::new(), nulls.len()); for (is_valid, arr) in nulls.iter().zip(array.iter()) { if is_valid { for value in arr.iter() { @@ -283,6 +375,37 @@ where } } + +impl ArrayAggGroupsAccumulator { + fn timestamp_builder(&mut self, len: usize) -> GenericListBuilder> { + match &self.data_type { + DataType::Timestamp(TimeUnit::Nanosecond, tz) => + ListBuilder::with_capacity( + PrimitiveBuilder::::new() + .with_timezone_opt(tz.clone()), len, + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => + ListBuilder::with_capacity( + PrimitiveBuilder::::new() + .with_timezone_opt(tz.clone()), len, + ), + DataType::Timestamp(TimeUnit::Millisecond, tz) => + ListBuilder::with_capacity( + PrimitiveBuilder::::new() + .with_timezone_opt(tz.clone()), len, + ), + DataType::Timestamp(TimeUnit::Second, tz) => + ListBuilder::with_capacity( + PrimitiveBuilder::::new() + .with_timezone_opt(tz.clone()), len, + ), + _ => + ListBuilder::with_capacity(PrimitiveBuilder::::new(), len), + } + } +} + + impl GroupsAccumulator for ArrayAggGroupsAccumulator where T: ArrowPrimitiveType + Send + Sync, From d79c5347e9515b4f8768f7443897b434d7edb63c Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Fri, 12 Apr 2024 15:10:14 +0300 Subject: [PATCH 18/21] Pass datatype to PrimitiveBuilder --- .../physical-expr/src/aggregate/array_agg.rs | 108 +----------------- 1 file changed, 6 insertions(+), 102 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index a03574142962..1e6bddff2830 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -26,7 +26,7 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::builder::{GenericListBuilder, ListBuilder, PrimitiveBuilder, StringBuilder}; +use arrow_array::builder::{ListBuilder, PrimitiveBuilder, StringBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::{ArrowTimestampType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; use arrow_array::{Array, ArrowPrimitiveType, 
BooleanArray, ListArray, StringArray}; @@ -281,46 +281,6 @@ where } } - -// trait Builder { -// fn new_builder(self, len: usize) -> ListBuilder; -// } -// -// impl Builder for ArrayAggGroupsAccumulator { -// fn new_builder(self, len: usize) -> ListBuilder { -// ListBuilder::with_capacity(PrimitiveBuilder::::new(), len) -// } -// } -// -// impl Builder for ArrayAggGroupsAccumulator { -// fn new_builder(self, len: usize) -> ListBuilder { -// match &self.data_type { -// DataType::Timestamp(TimeUnit::Nanosecond, tz) => -// ListBuilder::with_capacity( -// PrimitiveBuilder::::new() -// .with_timezone_opt(tz.clone()), len, -// ), -// DataType::Timestamp(TimeUnit::Microsecond, tz) => -// ListBuilder::with_capacity( -// PrimitiveBuilder::::new() -// .with_timezone_opt(tz.clone()), len, -// ), -// DataType::Timestamp(TimeUnit::Millisecond, tz) => -// ListBuilder::with_capacity( -// PrimitiveBuilder::::new() -// .with_timezone_opt(tz.clone()), len, -// ), -// DataType::Timestamp(TimeUnit::Second, tz) => -// ListBuilder::with_capacity( -// PrimitiveBuilder::::new() -// .with_timezone_opt(tz.clone()), len, -// ), -// _ => -// ListBuilder::with_capacity(PrimitiveBuilder::::new(), len), -// } -// } -// } - impl ArrayAggGroupsAccumulator { fn build_list(&mut self, emit_to: EmitTo) -> Result { @@ -330,35 +290,11 @@ impl ArrayAggGroupsAccumulator let len = nulls.len(); assert_eq!(array.len(), len); - let mut builder: GenericListBuilder> = - match &self.data_type { - DataType::Timestamp(TimeUnit::Nanosecond, tz) => - ListBuilder::with_capacity( - (PrimitiveBuilder::::new() as PrimitiveBuilder) - .with_timezone_opt(tz.clone()) as PrimitiveBuilder, - len, - ), - DataType::Timestamp(TimeUnit::Microsecond, tz) => - ListBuilder::with_capacity( - (PrimitiveBuilder::::new() as PrimitiveBuilder) - .with_timezone_opt(tz.clone()) as PrimitiveBuilder, - len, - ), - DataType::Timestamp(TimeUnit::Millisecond, tz) => - ListBuilder::with_capacity( - (PrimitiveBuilder::::new() as PrimitiveBuilder) - .with_timezone_opt(tz.clone()) as PrimitiveBuilder, - len, - ), - DataType::Timestamp(TimeUnit::Second, tz) => - ListBuilder::with_capacity( - (PrimitiveBuilder::::new() as PrimitiveBuilder) - .with_timezone_opt(tz.clone()) as PrimitiveBuilder, - len, - ), - _ => - ListBuilder::with_capacity(PrimitiveBuilder::::new(), len), - }; + let mut builder = + ListBuilder::with_capacity( + PrimitiveBuilder::::new().with_data_type(self.data_type.clone()), + len + ); for (is_valid, arr) in nulls.iter().zip(array.iter()) { if is_valid { @@ -375,42 +311,10 @@ impl ArrayAggGroupsAccumulator } } - -impl ArrayAggGroupsAccumulator { - fn timestamp_builder(&mut self, len: usize) -> GenericListBuilder> { - match &self.data_type { - DataType::Timestamp(TimeUnit::Nanosecond, tz) => - ListBuilder::with_capacity( - PrimitiveBuilder::::new() - .with_timezone_opt(tz.clone()), len, - ), - DataType::Timestamp(TimeUnit::Microsecond, tz) => - ListBuilder::with_capacity( - PrimitiveBuilder::::new() - .with_timezone_opt(tz.clone()), len, - ), - DataType::Timestamp(TimeUnit::Millisecond, tz) => - ListBuilder::with_capacity( - PrimitiveBuilder::::new() - .with_timezone_opt(tz.clone()), len, - ), - DataType::Timestamp(TimeUnit::Second, tz) => - ListBuilder::with_capacity( - PrimitiveBuilder::::new() - .with_timezone_opt(tz.clone()), len, - ), - _ => - ListBuilder::with_capacity(PrimitiveBuilder::::new(), len), - } - } -} - - impl GroupsAccumulator for ArrayAggGroupsAccumulator where T: ArrowPrimitiveType + Send + Sync, { - fn update_batch( &mut self, new_values: 
&[ArrayRef], From 4e81e04a03aa9e4de222fc0dd659caa73ec0b550 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Fri, 12 Apr 2024 15:14:19 +0300 Subject: [PATCH 19/21] remove feature --- datafusion/physical-expr/src/aggregate/array_agg.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 1e6bddff2830..3c291d74da7e 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -17,9 +17,6 @@ //! Defines physical expressions that can evaluated at runtime during query execution -#![feature(specialization)] - - use crate::aggregate::groups_accumulator::accumulate::NullState; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; From fa702e88426ca78a422bf48c3ba662ebe172fa19 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Fri, 12 Apr 2024 17:04:21 +0300 Subject: [PATCH 20/21] Delete temp file --- datafusion/core/src/main.rs | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 datafusion/core/src/main.rs diff --git a/datafusion/core/src/main.rs b/datafusion/core/src/main.rs deleted file mode 100644 index 146ba8b47a17..000000000000 --- a/datafusion/core/src/main.rs +++ /dev/null @@ -1,5 +0,0 @@ -fn main() { - println!("Hello, world!"); - - -} \ No newline at end of file From 85b0db56824e5f42632f700b42353cf66bad76e4 Mon Sep 17 00:00:00 2001 From: Gediminas Aleknavicius Date: Fri, 12 Apr 2024 17:49:12 +0300 Subject: [PATCH 21/21] Respond to review --- .../physical-expr/src/aggregate/array_agg.rs | 49 ++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 3c291d74da7e..1dbe08158590 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -25,7 +25,7 @@ use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; use arrow_array::builder::{ListBuilder, PrimitiveBuilder, StringBuilder}; use arrow_array::cast::AsArray; -use arrow_array::types::{ArrowTimestampType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; +use arrow_array::types::{ArrowTimestampType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; use arrow_array::{Array, ArrowPrimitiveType, BooleanArray, ListArray, StringArray}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; @@ -34,6 +34,7 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; +use arrow::ipc::{Duration, Interval}; use arrow_schema::TimeUnit; /// ARRAY_AGG aggregate expression @@ -103,11 +104,7 @@ impl AggregateExpr for ArrayAgg { } fn groups_accumulator_supported(&self) -> bool { - self.input_data_type.is_primitive() || - match self.input_data_type { - DataType::Utf8 => true, - _ => false, - } + 
self.input_data_type.is_primitive() || self.input_data_type == DataType::Utf8 } fn create_groups_accumulator(&self) -> Result> { @@ -140,6 +137,12 @@ impl AggregateExpr for ArrayAgg { DataType::Float64 => { Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } + DataType::Decimal128(_, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Decimal256(_, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } DataType::Date32 => { Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } @@ -158,8 +161,26 @@ impl AggregateExpr for ArrayAgg { DataType::Timestamp(TimeUnit::Nanosecond, _) => { Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) } + DataType::Time32(TimeUnit::Second) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Time32(TimeUnit::Millisecond) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Time64(TimeUnit::Microsecond) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Time64(TimeUnit::Nanosecond) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Duration(_) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } + DataType::Interval(_) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new(&self.input_data_type))) + } DataType::Utf8 => Ok(Box::new(StringArrayAggGroupsAccumulator::new())), - _ => Err(DataFusionError::Internal(format!( + _ => Err(DataFusionError::Internal(format!( "ArrayAggGroupsAccumulator not supported for data type {:?}", self.input_data_type ))), @@ -293,12 +314,9 @@ impl ArrayAggGroupsAccumulator len ); - for (is_valid, arr) in nulls.iter().zip(array.iter()) { + for (is_valid, arr) in nulls.iter().zip(array.into_iter()) { if is_valid { - for value in arr.iter() { - builder.values().append_option(*value); - } - builder.append(true); + builder.append_value(arr); } else { builder.append_null(); } @@ -406,12 +424,9 @@ impl StringArrayAggGroupsAccumulator { assert_eq!(array.len(), nulls.len()); let mut builder = ListBuilder::with_capacity(StringBuilder::new(), nulls.len()); - for (is_valid, arr) in nulls.iter().zip(array.iter()) { + for (is_valid, arr) in nulls.iter().zip(array.into_iter()) { if is_valid { - for value in arr.iter() { - builder.values().append_option(value.as_deref()); - } - builder.append(true); + builder.append_value(arr); } else { builder.append_null(); }
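// Illustrative usage sketch (not part of the patch series): after these changes, grouped
// ARRAY_AGG over the primitive types listed above and Utf8 is evaluated through the new
// ArrayAggGroupsAccumulator / StringArrayAggGroupsAccumulator paths. A minimal program
// exercising that path might look like the following; it assumes a binary crate that
// depends on this DataFusion workspace plus tokio, and that the SessionContext APIs shown
// (register_batch, sql, collect) match the DataFusion version this series targets.
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::arrow::util::pretty::print_batches;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("k", DataType::Utf8, false),
        Field::new("v", DataType::Int32, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(StringArray::from(vec!["a", "a", "b"])) as ArrayRef,
            Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef,
        ],
    )?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;

    // GROUP BY forces the grouped (GroupsAccumulator) code path for array_agg.
    let results = ctx
        .sql("SELECT k, array_agg(v) AS vals FROM t GROUP BY k ORDER BY k")
        .await?
        .collect()
        .await?;
    print_batches(&results)?;
    Ok(())
}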