Skip to content

Commit

Permalink
Rebasing on main after several other aggregate functions were removed
Browse files Browse the repository at this point in the history
  • Loading branch information
edmondop committed Jul 14, 2024
1 parent c6eb03a commit 06f6f29
Show file tree
Hide file tree
Showing 21 changed files with 36 additions and 429 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/dataframe_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use arrow_schema::DataType;
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::logical_expr::test::function_stub::max;
use datafusion::functions_aggregate::average::avg;
use datafusion::logical_expr::test::function_stub::max;
use datafusion::prelude::*;
use datafusion::test_util::arrow_test_data;
use datafusion_common::ScalarValue;
Expand Down
7 changes: 4 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ use datafusion_common::{
};
use datafusion_expr::{case, is_null, lit};
use datafusion_expr::{
avg, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_functions_aggregate::expr_fn::{
avg, count, max, median, min, stddev, sum,
};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::{count, max, median, min, stddev, sum};

use async_trait::async_trait;

Expand Down
31 changes: 8 additions & 23 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,23 +171,6 @@ fn take_optimizable_column_and_table_count(
None
}

fn unwrap_min(agg_expr: &dyn AggregateExpr) -> Option<&AggregateFunctionExpr> {
if let Some(casted_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if casted_expr.fun().name() == "MIN" {
return Some(casted_expr);
}
}
None
}

fn unwrap_max(agg_expr: &dyn AggregateExpr) -> Option<&AggregateFunctionExpr> {
if let Some(casted_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if casted_expr.fun().name() == "MAX" {
return Some(casted_expr);
}
}
None
}
/// If this agg_expr is a min that is exactly defined in the statistics, return it.
fn take_optimizable_min(
agg_expr: &dyn AggregateExpr,
Expand All @@ -197,7 +180,7 @@ fn take_optimizable_min(
match *num_rows {
0 => {
// MIN/MAX with 0 rows is always null
if let Some(casted_expr) = unwrap_min(agg_expr) {
if is_min(agg_expr) {
if let Ok(min_data_type) =
ScalarValue::try_from(agg_expr.field().unwrap().data_type())
{
Expand All @@ -207,8 +190,9 @@ fn take_optimizable_min(
}
value if value > 0 => {
let col_stats = &stats.column_statistics;
if let Some(casted_expr) = unwrap_min(agg_expr) {
if casted_expr.expressions().len() == 1 {
if is_min(agg_expr) {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
Expand Down Expand Up @@ -242,7 +226,7 @@ fn take_optimizable_max(
match *num_rows {
0 => {
// MIN/MAX with 0 rows is always null
if let Some(casted_expr) = unwrap_max(agg_expr){
if is_max(agg_expr) {
if let Ok(max_data_type) =
ScalarValue::try_from(agg_expr.field().unwrap().data_type())
{
Expand All @@ -252,8 +236,9 @@ fn take_optimizable_max(
}
value if value > 0 => {
let col_stats = &stats.column_statistics;
if let Some(casted_expr) = unwrap_max(agg_expr){
if casted_expr.expressions().len() == 1 {
if is_max(agg_expr) {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder,
array_agg, cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{count, max, sum};
use datafusion_functions_aggregate::expr_fn::{avg, count, max, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
24 changes: 1 addition & 23 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ use strum_macros::EnumIter;
// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// Average
Avg,
/// Aggregation into an array
ArrayAgg,
}
Expand All @@ -43,7 +41,6 @@ impl AggregateFunction {
pub fn name(&self) -> &str {
use AggregateFunction::*;
match self {
Avg => "AVG",
ArrayAgg => "ARRAY_AGG",
}
}
Expand All @@ -59,11 +56,6 @@ impl FromStr for AggregateFunction {
type Err = DataFusionError;
fn from_str(name: &str) -> Result<AggregateFunction> {
Ok(match name {
// general
"avg" => AggregateFunction::Avg,
"bool_and" => AggregateFunction::BoolAnd,
"bool_or" => AggregateFunction::BoolOr,
"mean" => AggregateFunction::Avg,
"array_agg" => AggregateFunction::ArrayAgg,
_ => {
return plan_err!("There is no built-in function named {name}");
Expand Down Expand Up @@ -99,10 +91,6 @@ impl AggregateFunction {
})?;

match self {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
Expand All @@ -115,7 +103,6 @@ impl AggregateFunction {
/// nullability
pub fn nullable(&self) -> Result<bool> {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
AggregateFunction::ArrayAgg => Ok(true),
}
}
Expand All @@ -126,16 +113,7 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::Grouping | AggregateFunction::ArrayAgg => {
Signature::any(1, Volatility::Immutable)
}
AggregateFunction::Avg => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
}
}
}
Expand Down
6 changes: 0 additions & 6 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2540,12 +2540,6 @@ mod test {

#[test]
fn test_find_df_window_function() {
assert_eq!(
find_df_window_func("avg"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Avg
))
);
assert_eq!(
find_df_window_func("cume_dist"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,13 @@ mod test {
use arrow::datatypes::{DataType, Field, Schema};

use crate::{
avg, cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast,
cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast,
LogicalPlanBuilder,
};

use super::*;
use crate::test::function_stub::min;
use crate::test::function_stub::avg;

#[test]
fn rewrite_sort_cols_by_agg() {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/test/function_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ impl Default for Min {
impl Min {
pub fn new() -> Self {
Self {
aliases: vec!["count".to_string()],
aliases: vec!["min".to_string()],
signature: Signature::variadic_any(Volatility::Immutable),
}
}
Expand Down Expand Up @@ -412,7 +412,7 @@ impl Default for Max {
impl Max {
pub fn new() -> Self {
Self {
aliases: vec!["count".to_string()],
aliases: vec!["max".to_string()],
signature: Signature::variadic_any(Volatility::Immutable),
}
}
Expand Down
57 changes: 0 additions & 57 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ use arrow::datatypes::{

use datafusion_common::{internal_err, plan_err, Result};

use crate::{AggregateFunction, Signature, TypeSignature};

pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];

pub static SIGNED_INTEGERS: &[DataType] = &[
Expand Down Expand Up @@ -93,53 +91,8 @@ pub fn coerce_types(
) -> Result<Vec<DataType>> {
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Avg => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval
let v = match &input_types[0] {
Decimal128(p, s) => Decimal128(*p, *s),
Decimal256(p, s) => Decimal256(*p, *s),
d if d.is_numeric() => Float64,
Dictionary(_, v) => {
return coerce_types(agg_fun, &[v.as_ref().clone()], signature)
}
_ => {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
)
}
};
Ok(vec![v])
}
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
if !is_bool_and_or_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(input_types.to_vec())
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64, Float64])
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
}
}

Expand Down Expand Up @@ -374,16 +327,6 @@ mod tests {
use super::*;
#[test]
fn test_aggregate_coerce_types() {
let fun = AggregateFunction::Avg;
// test input args is invalid data type for avg
let input_types = vec![DataType::Utf8];
let signature = fun.signature();
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!(
"Error during planning: The function Avg does not support inputs of type Utf8.",
result.unwrap_err().strip_backtrace()
);

// test count, array_agg, approx_distinct.
// the coerced types is same with input types
let funs = vec![AggregateFunction::ArrayAgg];
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ paste = "1.0.14"
sqlparser = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
rand = { workspace = true }
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ mod tests {
let migrated_functions = vec!["count", "max", "min"];
for func in all_default_aggregate_functions() {
// TODO: remove this
// These functions are in intermidiate migration state, skip them
// These functions are in intermediate migration state, skip them
if migrated_functions.contains(&func.name().to_lowercase().as_str()) {
continue;
}
Expand Down
Loading

0 comments on commit 06f6f29

Please sign in to comment.