From 06f6f29ae08aed4bb0a82ccd59358452d7393f86 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Wed, 3 Jul 2024 19:10:33 +0000 Subject: [PATCH] Rebasing on main after several other aggregate functions were removed --- .../examples/dataframe_subquery.rs | 2 +- datafusion/core/src/dataframe/mod.rs | 7 +- .../aggregate_statistics.rs | 31 +--- datafusion/core/tests/dataframe/mod.rs | 4 +- datafusion/expr/src/aggregate_function.rs | 24 +-- datafusion/expr/src/expr.rs | 6 - datafusion/expr/src/expr_rewriter/order_by.rs | 3 +- datafusion/expr/src/test/function_stub.rs | 4 +- .../expr/src/type_coercion/aggregates.rs | 57 ------- datafusion/functions-aggregate/Cargo.toml | 2 +- datafusion/functions-aggregate/src/lib.rs | 2 +- datafusion/functions-aggregate/src/min_max.rs | 35 ++--- .../optimizer/src/eliminate_distinct.rs | 140 ------------------ datafusion/optimizer/src/lib.rs | 1 - .../physical-expr/src/aggregate/build_in.rs | 102 ------------- datafusion/proto/src/generated/pbjson.rs | 2 - datafusion/proto/src/generated/prost.rs | 7 - .../proto/src/logical_plan/from_proto.rs | 6 - datafusion/proto/src/logical_plan/to_proto.rs | 11 -- .../proto/src/physical_plan/to_proto.rs | 15 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- 21 files changed, 36 insertions(+), 429 deletions(-) delete mode 100644 datafusion/optimizer/src/eliminate_distinct.rs diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index 93dcdd7ee893..b9c2a3ff9092 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -19,8 +19,8 @@ use arrow_schema::DataType; use std::sync::Arc; use datafusion::error::Result; -use datafusion::logical_expr::test::function_stub::max; use datafusion::functions_aggregate::average::avg; +use datafusion::logical_expr::test::function_stub::max; use datafusion::prelude::*; use datafusion::test_util::arrow_test_data; use datafusion_common::ScalarValue; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 14f63754a118..9d7c810d22d9 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -53,10 +53,11 @@ use datafusion_common::{ }; use datafusion_expr::{case, is_null, lit}; use datafusion_expr::{ - avg, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, + utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, +}; +use datafusion_functions_aggregate::expr_fn::{ + avg, count, max, median, min, stddev, sum, }; -use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::{count, max, median, min, stddev, sum}; use async_trait::async_trait; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index a689d36432ea..66067d8cb5c4 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -171,23 +171,6 @@ fn take_optimizable_column_and_table_count( None } -fn unwrap_min(agg_expr: &dyn AggregateExpr) -> Option<&AggregateFunctionExpr> { - if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { - if casted_expr.fun().name() == "MIN" { - return Some(casted_expr); - } - } - None -} - -fn unwrap_max(agg_expr: &dyn AggregateExpr) -> Option<&AggregateFunctionExpr> { - if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { - if casted_expr.fun().name() == "MAX" { - return Some(casted_expr); - } - } - None -} /// If this agg_expr is a min that is exactly defined in the statistics, return it. fn take_optimizable_min( agg_expr: &dyn AggregateExpr, @@ -197,7 +180,7 @@ fn take_optimizable_min( match *num_rows { 0 => { // MIN/MAX with 0 rows is always null - if let Some(casted_expr) = unwrap_min(agg_expr) { + if is_min(agg_expr) { if let Ok(min_data_type) = ScalarValue::try_from(agg_expr.field().unwrap().data_type()) { @@ -207,8 +190,9 @@ fn take_optimizable_min( } value if value > 0 => { let col_stats = &stats.column_statistics; - if let Some(casted_expr) = unwrap_min(agg_expr) { - if casted_expr.expressions().len() == 1 { + if is_min(agg_expr) { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = exprs[0].as_any().downcast_ref::() @@ -242,7 +226,7 @@ fn take_optimizable_max( match *num_rows { 0 => { // MIN/MAX with 0 rows is always null - if let Some(casted_expr) = unwrap_max(agg_expr){ + if is_max(agg_expr) { if let Ok(max_data_type) = ScalarValue::try_from(agg_expr.field().unwrap().data_type()) { @@ -252,8 +236,9 @@ fn take_optimizable_max( } value if value > 0 => { let col_stats = &stats.column_statistics; - if let Some(casted_expr) = unwrap_max(agg_expr){ - if casted_expr.expressions().len() == 1 { + if is_max(agg_expr) { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = exprs[0].as_any().downcast_ref::() diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 240aecaa3791..be5fc4e06b74 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -53,11 +53,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, + array_agg, cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{count, max, sum}; +use datafusion_functions_aggregate::expr_fn::{avg, count, max, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index eb140e3ac797..f1a2ea4b6a98 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -33,8 +33,6 @@ use strum_macros::EnumIter; // https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { - /// Average - Avg, /// Aggregation into an array ArrayAgg, } @@ -43,7 +41,6 @@ impl AggregateFunction { pub fn name(&self) -> &str { use AggregateFunction::*; match self { - Avg => "AVG", ArrayAgg => "ARRAY_AGG", } } @@ -59,11 +56,6 @@ impl FromStr for AggregateFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name { - // general - "avg" => AggregateFunction::Avg, - "bool_and" => AggregateFunction::BoolAnd, - "bool_or" => AggregateFunction::BoolOr, - "mean" => AggregateFunction::Avg, "array_agg" => AggregateFunction::ArrayAgg, _ => { return plan_err!("There is no built-in function named {name}"); @@ -99,10 +91,6 @@ impl AggregateFunction { })?; match self { - AggregateFunction::Correlation => { - correlation_return_type(&coerced_data_types[0]) - } - AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", coerced_data_types[0].clone(), @@ -115,7 +103,6 @@ impl AggregateFunction { /// nullability pub fn nullable(&self) -> Result { match self { - AggregateFunction::Max | AggregateFunction::Min => Ok(true), AggregateFunction::ArrayAgg => Ok(true), } } @@ -126,16 +113,7 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match self { - AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { - Signature::any(1, Volatility::Immutable) - } - AggregateFunction::Avg => { - Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) - } - AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::Correlation => { - Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) - } + AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ae3a5ff1d3c7..5ac707c46f08 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2540,12 +2540,6 @@ mod test { #[test] fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Avg - )) - ); assert_eq!( find_df_window_func("cume_dist"), Some(WindowFunctionDefinition::BuiltInWindowFunction( diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 0bc9019b6568..99296cb389a0 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -156,12 +156,13 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - avg, cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast, + cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast, LogicalPlanBuilder, }; use super::*; use crate::test::function_stub::min; + use crate::test::function_stub::avg; #[test] fn rewrite_sort_cols_by_agg() { diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index c206879765a2..19822c92d690 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -328,7 +328,7 @@ impl Default for Min { impl Min { pub fn new() -> Self { Self { - aliases: vec!["count".to_string()], + aliases: vec!["min".to_string()], signature: Signature::variadic_any(Volatility::Immutable), } } @@ -412,7 +412,7 @@ impl Default for Max { impl Max { pub fn new() -> Self { Self { - aliases: vec!["count".to_string()], + aliases: vec!["max".to_string()], signature: Signature::variadic_any(Volatility::Immutable), } } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 00180be55e2f..adad003d98f8 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -24,8 +24,6 @@ use arrow::datatypes::{ use datafusion_common::{internal_err, plan_err, Result}; -use crate::{AggregateFunction, Signature, TypeSignature}; - pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; pub static SIGNED_INTEGERS: &[DataType] = &[ @@ -93,53 +91,8 @@ pub fn coerce_types( ) -> Result> { // Validate input_types matches (at least one of) the func signature. check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; - match agg_fun { AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), - AggregateFunction::Avg => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval - let v = match &input_types[0] { - Decimal128(p, s) => Decimal128(*p, *s), - Decimal256(p, s) => Decimal256(*p, *s), - d if d.is_numeric() => Float64, - Dictionary(_, v) => { - return coerce_types(agg_fun, &[v.as_ref().clone()], signature) - } - _ => { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ) - } - }; - Ok(vec![v]) - } - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval. - if !is_bool_and_or_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Correlation => { - if !is_correlation_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } - AggregateFunction::NthValue => Ok(input_types.to_vec()), - AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), } } @@ -374,16 +327,6 @@ mod tests { use super::*; #[test] fn test_aggregate_coerce_types() { - let fun = AggregateFunction::Avg; - // test input args is invalid data type for avg - let input_types = vec![DataType::Utf8]; - let signature = fun.signature(); - let result = coerce_types(&fun, &input_types, &signature); - assert_eq!( - "Error during planning: The function Avg does not support inputs of type Utf8.", - result.unwrap_err().strip_backtrace() - ); - // test count, array_agg, approx_distinct. // the coerced types is same with input types let funs = vec![AggregateFunction::ArrayAgg]; diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 05b627da3467..43ddd37cfb6f 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -50,4 +50,4 @@ paste = "1.0.14" sqlparser = { workspace = true } [dev-dependencies] -rand = { workspace = true } \ No newline at end of file +rand = { workspace = true } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 156d7ac71f5d..ce0e0c0c0d7a 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -194,7 +194,7 @@ mod tests { let migrated_functions = vec!["count", "max", "min"]; for func in all_default_aggregate_functions() { // TODO: remove this - // These functions are in intermidiate migration state, skip them + // These functions are in intermediate migration state, skip them if migrated_functions.contains(&func.name().to_lowercase().as_str()) { continue; } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index a5e52c5d5113..bcd7581cb2e3 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -65,21 +65,6 @@ use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, }; -// min/max of two non-string scalar values. -macro_rules! typed_min_max { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ - ScalarValue::$SCALAR( - match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(*a), - (None, Some(b)) => Some(*b), - (Some(a), Some(b)) => Some((*a).$OP(*b)), - }, - $($EXTRA_ARGS.clone()),* - ) - }}; -} - macro_rules! typed_min_max_float { ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ ScalarValue::$SCALAR(match ($VALUE, $DELTA) { @@ -783,7 +768,6 @@ impl MovingMax { } } - make_udaf_expr_and_func!( Max, max, @@ -961,10 +945,12 @@ impl AggregateUDFImpl for Max { } } - fn create_sliding_accumulator(&self, args:AccumulatorArgs) -> Result> { + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(SlidingMaxAccumulator::try_new(args.data_type)?)) } - } /// An accumulator to compute the maximum value @@ -1161,11 +1147,12 @@ impl AggregateUDFImpl for Min { } } - - fn create_sliding_accumulator(&self, args:AccumulatorArgs) -> Result> { + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(SlidingMinAccumulator::try_new(args.data_type)?)) } - } /// An accumulator to compute the minimum value #[derive(Debug)] @@ -1209,8 +1196,6 @@ impl Accumulator for MinAccumulator { } } - - #[derive(Debug)] pub struct SlidingMinAccumulator { min: ScalarValue, @@ -1218,7 +1203,6 @@ pub struct SlidingMinAccumulator { } impl SlidingMinAccumulator { - pub fn try_new(datatype: &DataType) -> Result { Ok(Self { min: ScalarValue::try_from(datatype)?, @@ -1338,6 +1322,7 @@ impl Accumulator for SlidingMaxAccumulator { mod tests { use super::*; use std::sync::Arc; + use rand::*; #[test] fn float_min_max_with_nans() { @@ -1372,7 +1357,6 @@ mod tests { check(&mut max(), &[&[zero, neg_inf]], zero); } - use datafusion_common::Result; use rand::Rng; @@ -1440,5 +1424,4 @@ mod tests { moving_max_i32(100, 100)?; Ok(()) } - } diff --git a/datafusion/optimizer/src/eliminate_distinct.rs b/datafusion/optimizer/src/eliminate_distinct.rs deleted file mode 100644 index f1d5877b1b49..000000000000 --- a/datafusion/optimizer/src/eliminate_distinct.rs +++ /dev/null @@ -1,140 +0,0 @@ - -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`EliminateDistinctFromMinMax`] Removes redundant distinct in min and max - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result}; -use datafusion_expr::expr::AggregateFunction; -use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Aggregate, Expr}; -use std::sync::OnceLock; - -/// Optimization rule that eliminate redundant distinct in min and max expr. -#[derive(Default)] -pub struct EliminateDistinct; - -impl EliminateDistinct { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} -static WORKSPACE_ROOT_LOCK: OnceLock> = OnceLock::new(); - -fn rewrite_aggr_expr(expr:Expr) -> (bool, Expr) { - match expr { - Expr::AggregateFunction(ref fun) => { - let fn_name = fun.func_def.name().to_lowercase(); - if fun.distinct && WORKSPACE_ROOT_LOCK.get_or_init(|| vec!["min".to_string(), "max".to_string()]).contains(&fn_name) { - (true, Expr::AggregateFunction(AggregateFunction{ - func_def:fun.func_def.clone(), - args:fun.args.clone(), - distinct:false, - filter:fun.filter.clone(), - order_by:fun.order_by.clone(), - null_treatment: fun.null_treatment - })) - } else { - (false, expr) - } - }, - _ => (false, expr) - } -} -impl OptimizerRule for EliminateDistinct { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateDistinct::rewrite") - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - - fn supports_rewrite(&self) -> bool { - true - } - - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Aggregate(agg) => { - let mut aggr_expr = vec![]; - let mut transformed = false; - for expr in agg.aggr_expr { - let rewrite_result = rewrite_aggr_expr(expr); - transformed = transformed || rewrite_result.0; - aggr_expr.push(rewrite_result.1); - } - - println!("Transformed yes {}", transformed); - let transformed = if transformed { - Transformed::yes - } else { - Transformed::no - }; - Aggregate::try_new(agg.input, agg.group_expr, aggr_expr) - .map(|f| transformed(LogicalPlan::Aggregate(f))) - } - _ => Ok(Transformed::no(plan)), - } - } - fn name(&self) -> &str { - "eliminate_distinct" - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test::*; - use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; - use datafusion_expr::AggregateExt; - use datafusion_expr::test::function_stub::min; - use std::sync::Arc; - - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - crate::test::assert_optimized_plan_eq( - Arc::new(EliminateDistinct::new()), - plan, - expected, - ) - } - - #[test] - fn eliminate_distinct_from_min_expr() -> Result<()> { - let table_scan = test_table_scan().unwrap(); - let aggr_expr = min(col("b")).distinct().build()?; - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("a")], vec![aggr_expr])? - .build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) - } -} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 60a302db7747..332d3e9fe54e 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -35,7 +35,6 @@ pub mod common_subexpr_eliminate; pub mod decorrelate; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; -pub mod eliminate_distinct; pub mod eliminate_duplicated_expr; pub mod eliminate_filter; pub mod eliminate_group_by_constant; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 08a6498a1f83..08740277e32c 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -92,43 +92,6 @@ pub fn create_aggregate_expr( is_expr_nullable, )) } - (AggregateFunction::Avg, false) => { - Arc::new(Avg::new(input_phy_exprs[0].clone(), name, data_type)) - } - (AggregateFunction::Avg, true) => { - return not_impl_err!("AVG(DISTINCT) aggregations are not available"); - } - (AggregateFunction::Correlation, false) => { - Arc::new(expressions::Correlation::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )) - } - (AggregateFunction::Correlation, true) => { - return not_impl_err!("CORR(DISTINCT) aggregations are not available"); - } - (AggregateFunction::NthValue, _) => { - let expr = &input_phy_exprs[0]; - let Some(n) = input_phy_exprs[1] - .as_any() - .downcast_ref::() - .map(|literal| literal.value()) - else { - return exec_err!("Second argument of NTH_VALUE needs to be a literal"); - }; - let nullable = expr.nullable(input_schema)?; - Arc::new(expressions::NthValueAgg::new( - expr.clone(), - n.clone().try_into()?, - name, - input_phy_types[0].clone(), - nullable, - ordering_types, - ordering_req.to_vec(), - )) - } }) } @@ -204,71 +167,6 @@ mod tests { Ok(()) } - #[test] - fn test_avg_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Avg]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Avg { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ); - }; - } - } - Ok(()) - } - - #[test] - fn test_avg_return_type() -> Result<()> { - let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Avg.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Avg.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Float64, observed); - - let observed = - AggregateFunction::Avg.return_type(&[DataType::Decimal128(10, 6)])?; - assert_eq!(DataType::Decimal128(14, 10), observed); - - let observed = - AggregateFunction::Avg.return_type(&[DataType::Decimal128(36, 6)])?; - assert_eq!(DataType::Decimal128(38, 10), observed); - Ok(()) - } - - #[test] - fn test_avg_no_utf8() { - let observed = AggregateFunction::Avg.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - // Helper function // Create aggregate expr with type coercion fn create_physical_agg_expr_for_test( diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 9187a90c39b3..4a4c56ad62bb 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -533,7 +533,6 @@ impl serde::Serialize for AggregateFunction { { let variant = match self { Self::Unused => "UNUSED", - Self::Avg => "AVG", Self::ArrayAgg => "ARRAY_AGG", }; serializer.serialize_str(variant) @@ -590,7 +589,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { { match value { "UNUSED" => Ok(AggregateFunction::Unused), - "AVG" => Ok(AggregateFunction::Avg), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8a927b03c883..279a9075de71 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1973,20 +1973,13 @@ impl AggregateFunction { pub fn as_str_name(&self) -> &'static str { match self { AggregateFunction::Unused => "UNUSED", - AggregateFunction::Avg => "AVG", AggregateFunction::ArrayAgg => "ARRAY_AGG", - AggregateFunction::Correlation => "CORRELATION", - AggregateFunction::Grouping => "GROUPING", - AggregateFunction::BoolAnd => "BOOL_AND", - AggregateFunction::BoolOr => "BOOL_OR", - AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "UNUSED" => Some(Self::Unused), - "AVG" => Some(Self::Avg), "ARRAY_AGG" => Some(Self::ArrayAgg), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 82eba2f6d3f8..09e1020ce06e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -142,13 +142,7 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { impl From for AggregateFunction { fn from(agg_fun: protobuf::AggregateFunction) -> Self { match agg_fun { - protobuf::AggregateFunction::Avg => Self::Avg, - protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, - protobuf::AggregateFunction::BoolOr => Self::BoolOr, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, - protobuf::AggregateFunction::Correlation => Self::Correlation, - protobuf::AggregateFunction::Grouping => Self::Grouping, - protobuf::AggregateFunction::NthValueAgg => Self::NthValue, protobuf::AggregateFunction::Unused => panic!("This should never happen, we are retiring this but protobuf doesn't support enum with no 0 values"), } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 48a60b1c3861..a5a5c98679db 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -114,9 +114,6 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { impl From<&AggregateFunction> for protobuf::AggregateFunction { fn from(value: &AggregateFunction) -> Self { match value { - AggregateFunction::Avg => Self::Avg, - AggregateFunction::BoolAnd => Self::BoolAnd, - AggregateFunction::BoolOr => Self::BoolOr, AggregateFunction::ArrayAgg => Self::ArrayAgg, } } @@ -375,14 +372,6 @@ pub fn serialize_expr( AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::NthValue => { - protobuf::AggregateFunction::NthValueAgg - } }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 8d38a2a39eaf..3a0cef742a84 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,9 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, - DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, - NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, RowNumber, TryCastExpr, WindowShift, + ArrayAgg, BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, + InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, + Ntile, Rank, RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -249,14 +248,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { } else if aggr_expr.downcast_ref::().is_some() { distinct = true; protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Avg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Correlation - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::NthValueAgg } else { return not_impl_err!("Aggregate function not supported: {expr:?}"); }; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index bc10b6966122..55333e35e368 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -45,8 +45,8 @@ use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, covar_samp, first_value, max, grouping, median, min, stddev, - stddev_pop, sum, var_pop, var_sample, + count_distinct, covar_pop, covar_samp, first_value, grouping, max, median, min, + stddev, stddev_pop, sum, var_pop, var_sample, }; use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::prelude::*;