diff --git a/packages/cubejs-backend-native/Cargo.lock b/packages/cubejs-backend-native/Cargo.lock index f3b911b4b5687..0ef353b259fac 100644 --- a/packages/cubejs-backend-native/Cargo.lock +++ b/packages/cubejs-backend-native/Cargo.lock @@ -681,7 +681,7 @@ dependencies = [ [[package]] name = "cube-ext" version = "1.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "arrow", "chrono", @@ -838,7 +838,7 @@ checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" [[package]] name = "datafusion" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -871,7 +871,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "arrow", "ordered-float 2.10.1", @@ -882,7 +882,7 @@ dependencies = [ [[package]] name = "datafusion-data-access" version = "1.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "async-trait", "chrono", @@ -895,7 +895,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -906,7 +906,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -3015,7 +3015,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" version = "0.16.0" -source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=6a54d27d3b75a04b9f9cbe309a83078aa54b32fd#6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" +source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=5fe1b77d1a91b80529a0b7af0b89411d3cba5137#5fe1b77d1a91b80529a0b7af0b89411d3cba5137" dependencies = [ "log", ] diff --git a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js index 3bb6734aaddef..04f7725d7607f 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js @@ -3193,6 +3193,8 @@ export class BaseQuery { // DATEADD is being rewritten to DATE_ADD // DATEADD: 'DATEADD({{ date_part }}, {{ interval }}, {{ args[2] }})', DATE: 'DATE({{ args_concat }})', + + PERCENTILECONT: 'PERCENTILE_CONT({{ args_concat }})', }, statements: { select: 'SELECT {% if distinct %}DISTINCT {% endif %}' + @@ -3228,6 +3230,7 @@ export class BaseQuery { like: '{{ expr }} {% if negated %}NOT {% endif %}LIKE {{ pattern }}', ilike: '{{ expr }} {% if negated %}NOT {% endif %}ILIKE {{ pattern }}', like_escape: '{{ like_expr }} ESCAPE {{ escape_char }}', + within_group: '{{ fun_sql }} WITHIN GROUP (ORDER BY {{ within_group_concat }})', }, quotes: { identifiers: '"', diff --git a/packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.ts b/packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.ts index 7c34a4a0d8e9a..1119f23a3fcd6 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.ts +++ b/packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.ts @@ -245,6 +245,7 @@ export class BigqueryQuery extends BaseQuery { // templates.functions.DATEADD = 'DATETIME_ADD(CAST({{ args[2] }} AS DATETTIME), INTERVAL {{ interval }} {{ date_part }})'; templates.functions.CURRENTDATE = 'CURRENT_DATE'; delete templates.functions.TO_CHAR; + delete templates.functions.PERCENTILECONT; templates.expressions.binary = '{% if op == \'%\' %}MOD({{ left }}, {{ right }}){% else %}({{ left }} {{ op }} {{ right }}){% endif %}'; templates.expressions.interval = 'INTERVAL {{ interval }}'; templates.expressions.extract = 'EXTRACT({% if date_part == \'DOW\' %}DAYOFWEEK{% elif date_part == \'DOY\' %}DAYOFYEAR{% else %}{{ date_part }}{% endif %} FROM {{ expr }})'; diff --git a/packages/cubejs-schema-compiler/src/adapter/ClickHouseQuery.ts b/packages/cubejs-schema-compiler/src/adapter/ClickHouseQuery.ts index ed393b971a453..0de087f6e078f 100644 --- a/packages/cubejs-schema-compiler/src/adapter/ClickHouseQuery.ts +++ b/packages/cubejs-schema-compiler/src/adapter/ClickHouseQuery.ts @@ -271,6 +271,7 @@ export class ClickHouseQuery extends BaseQuery { templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})'; // TODO: Introduce additional filter in jinja? or parseDateTimeBestEffort? // https://github.com/ClickHouse/ClickHouse/issues/19351 + delete templates.functions.PERCENTILECONT; templates.expressions.timestamp_literal = 'parseDateTimeBestEffort(\'{{ value }}\')'; delete templates.expressions.like_escape; templates.quotes.identifiers = '`'; diff --git a/packages/cubejs-schema-compiler/src/adapter/MssqlQuery.ts b/packages/cubejs-schema-compiler/src/adapter/MssqlQuery.ts index 9b1bae70a04a9..8abfee919fbb0 100644 --- a/packages/cubejs-schema-compiler/src/adapter/MssqlQuery.ts +++ b/packages/cubejs-schema-compiler/src/adapter/MssqlQuery.ts @@ -223,6 +223,8 @@ export class MssqlQuery extends BaseQuery { const templates = super.sqlTemplates(); templates.functions.LEAST = 'LEAST({{ args_concat }})'; templates.functions.GREATEST = 'GREATEST({{ args_concat }})'; + // PERCENTILE_CONT works but requires PARTITION BY + delete templates.functions.PERCENTILECONT; delete templates.expressions.ilike; templates.types.string = 'VARCHAR'; templates.types.boolean = 'BIT'; diff --git a/packages/cubejs-schema-compiler/src/adapter/MysqlQuery.ts b/packages/cubejs-schema-compiler/src/adapter/MysqlQuery.ts index 6a439af332e69..eee360e2ed0eb 100644 --- a/packages/cubejs-schema-compiler/src/adapter/MysqlQuery.ts +++ b/packages/cubejs-schema-compiler/src/adapter/MysqlQuery.ts @@ -156,6 +156,8 @@ export class MysqlQuery extends BaseQuery { public sqlTemplates() { const templates = super.sqlTemplates(); + // PERCENTILE_CONT works but requires PARTITION BY + delete templates.functions.PERCENTILECONT; templates.quotes.identifiers = '`'; templates.quotes.escape = '\\`'; delete templates.expressions.ilike; diff --git a/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.ts b/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.ts index 6852f94e52f0a..a83e65b2695ff 100644 --- a/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.ts +++ b/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.ts @@ -112,6 +112,7 @@ export class PrestodbQuery extends BaseQuery { const templates = super.sqlTemplates(); templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})'; templates.functions.DATEPART = 'DATE_PART({{ args_concat }})'; + delete templates.functions.PERCENTILECONT; templates.statements.select = 'SELECT {{ select_concat | map(attribute=\'aliased\') | join(\', \') }} \n' + 'FROM (\n {{ from }}\n) AS {{ from_alias }} \n' + '{% if group_by %} GROUP BY {{ group_by }}{% endif %}' + diff --git a/rust/cubesql/Cargo.lock b/rust/cubesql/Cargo.lock index a832246ac7504..f89a076d9267d 100644 --- a/rust/cubesql/Cargo.lock +++ b/rust/cubesql/Cargo.lock @@ -721,7 +721,7 @@ dependencies = [ [[package]] name = "cube-ext" version = "1.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "arrow", "chrono", @@ -851,7 +851,7 @@ dependencies = [ [[package]] name = "datafusion" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -884,7 +884,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "arrow", "ordered-float 2.10.0", @@ -895,7 +895,7 @@ dependencies = [ [[package]] name = "datafusion-data-access" version = "1.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "async-trait", "chrono", @@ -908,7 +908,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -919,7 +919,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -2877,7 +2877,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" version = "0.16.0" -source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=6a54d27d3b75a04b9f9cbe309a83078aa54b32fd#6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" +source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=5fe1b77d1a91b80529a0b7af0b89411d3cba5137#5fe1b77d1a91b80529a0b7af0b89411d3cba5137" dependencies = [ "log", ] diff --git a/rust/cubesql/cubesql/Cargo.toml b/rust/cubesql/cubesql/Cargo.toml index e88ffd9fc6346..691f8ffeafe2f 100644 --- a/rust/cubesql/cubesql/Cargo.toml +++ b/rust/cubesql/cubesql/Cargo.toml @@ -10,12 +10,12 @@ homepage = "https://cube.dev" [dependencies] arc-swap = "1" -datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "dcf3e4aa26fd112043ef26fa4a78db5dbd443c86", default-features = false, features = ["regex_expressions", "unicode_expressions"] } +datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "11a4ed10b184b2f1b22f7458702ae0c63f011241", default-features = false, features = ["regex_expressions", "unicode_expressions"] } anyhow = "1.0" thiserror = "1.0.50" cubeclient = { path = "../cubeclient" } pg-srv = { path = "../pg-srv" } -sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" } +sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" } base64 = "0.13.0" tokio = { version = "^1.35", features = ["full", "rt", "tracing"] } serde = { version = "^1.0", features = ["derive"] } diff --git a/rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs b/rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs index 459a7525b524d..4384f19bf7fb2 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs @@ -242,15 +242,31 @@ pub fn rewrite(expr: &Expr, map: &HashMap>) -> Result args - .iter() - .map(|arg| rewrite(arg, map)) - .collect::>>>()? - .map(|args| Expr::AggregateFunction { - fun: fun.clone(), - args, - distinct: distinct.clone(), - }), + within_group, + } => { + let args = args + .iter() + .map(|arg| rewrite(arg, map)) + .collect::>>>()?; + let within_group = match within_group.as_ref() { + Some(within_group) => within_group + .iter() + .map(|expr| rewrite(expr, map)) + .collect::>>>()? + .map(|within_group| Some(within_group)), + None => Some(None), + }; + if let (Some(args), Some(within_group)) = (args, within_group) { + Some(Expr::AggregateFunction { + fun: fun.clone(), + args, + distinct: distinct.clone(), + within_group, + }) + } else { + None + } + } Expr::WindowFunction { fun, args, diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index b66fa831cab02..809fe40736bf6 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -1921,8 +1921,10 @@ impl CubeScanWrapperNode { fun, args, distinct, + within_group, } => { let mut sql_args = Vec::new(); + let mut sql_within_group = Vec::new(); for arg in args { if let AggregateFunction::Count = fun { if !distinct { @@ -1944,10 +1946,25 @@ impl CubeScanWrapperNode { sql_query = query; sql_args.push(sql); } + if let Some(within_group) = within_group { + for expr in within_group { + let (sql, query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + expr, + ungrouped_scan_node.clone(), + subqueries.clone(), + ) + .await?; + sql_query = query; + sql_within_group.push(sql); + } + } Ok(( sql_generator .get_sql_templates() - .aggregate_function(fun, sql_args, distinct) + .aggregate_function(fun, sql_args, distinct, sql_within_group) .map_err(|e| { DataFusionError::Internal(format!( "Can't generate SQL for aggregate function: {}", diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 10f7a4587ba28..f53a80203fa22 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -18588,4 +18588,36 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), Ok(()) } + + #[tokio::test] + async fn test_within_group_push_down() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + r#" + SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY taxful_total_price) AS pc + FROM KibanaSampleDataEcommerce + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let logical_plan = query_plan.as_logical_plan(); + let sql = logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql; + assert!(sql.contains("WITHIN GROUP (ORDER BY")); + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/converter.rs b/rust/cubesql/cubesql/src/compile/rewrite/converter.rs index fd908d44d21a9..9c92346fa5014 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/converter.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/converter.rs @@ -424,6 +424,7 @@ impl LogicalPlanToLanguageConverter { fun, args, distinct, + within_group, } => { let fun = add_expr_data_node!(graph, fun, AggregateFunctionExprFun); let args = add_expr_list_node!( @@ -434,8 +435,18 @@ impl LogicalPlanToLanguageConverter { flat_list ); let distinct = add_expr_data_node!(graph, distinct, AggregateFunctionExprDistinct); + let within_group = add_expr_list_node!( + graph, + within_group.as_ref().unwrap_or(&vec![]), + query_params, + AggregateFunctionExprWithinGroup, + flat_list + ); graph.add(LogicalPlanLanguage::AggregateFunctionExpr([ - fun, args, distinct, + fun, + args, + distinct, + within_group, ])) } Expr::WindowFunction { @@ -1145,10 +1156,20 @@ pub fn node_to_expr( let args = match_expr_list_node!(node_by_id, to_expr, params[1], AggregateFunctionExprArgs); let distinct = match_data_node!(node_by_id, params[2], AggregateFunctionExprDistinct); + let within_group = match_expr_list_node!( + node_by_id, + to_expr, + params[3], + AggregateFunctionExprWithinGroup + ); Expr::AggregateFunction { fun, args, distinct, + within_group: match within_group.len() { + 0 => None, + _ => Some(within_group), + }, } } LogicalPlanLanguage::WindowFunctionExpr(params) => { diff --git a/rust/cubesql/cubesql/src/compile/rewrite/language.rs b/rust/cubesql/cubesql/src/compile/rewrite/language.rs index 1bae740891932..c7d42e910ac44 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/language.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/language.rs @@ -312,6 +312,7 @@ macro_rules! variant_field_struct { AggregateFunction::ApproxMedian => "ApproxMedian", AggregateFunction::BoolAnd => "BoolAnd", AggregateFunction::BoolOr => "BoolOr", + AggregateFunction::PercentileCont => "PercentileCont", } ); }; diff --git a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs index a63b4d97da4a0..344e3bd4e3524 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs @@ -233,6 +233,7 @@ crate::plan_to_language! { fun: AggregateFunction, args: Vec, distinct: bool, + within_group: Vec, }, WindowFunctionExpr { fun: WindowFunction, @@ -1319,21 +1320,39 @@ fn scalar_fun_expr_args_empty_tail() -> String { fun_expr_args_empty_tail() } -fn agg_fun_expr(fun_name: impl Display, args: Vec, distinct: impl Display) -> String { +fn agg_fun_expr( + fun_name: impl Display, + args: Vec, + distinct: impl Display, + within_group: impl Display, +) -> String { let prefix = if fun_name.to_string().starts_with("?") { "" } else { "AggregateFunctionExprFun:" }; format!( - "(AggregateFunctionExpr {}{} {} {})", + "(AggregateFunctionExpr {}{} {} {} {})", prefix, fun_name, list_expr("AggregateFunctionExprArgs", args), - distinct + distinct, + within_group, ) } +fn agg_fun_expr_within_group(left: impl Display, right: impl Display) -> String { + format!("(AggregateFunctionExprWithinGroup {} {})", left, right) +} + +fn agg_fun_expr_within_group_list(order_by: Vec) -> String { + list_expr("AggregateFunctionExprWithinGroup", order_by) +} + +fn agg_fun_expr_within_group_empty_tail() -> String { + agg_fun_expr_within_group_list(Vec::::new()) +} + fn window_fun_expr_var_arg( fun_name: impl Display, arg_list: impl Display, diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/common.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/common.rs index 7860e92635edc..56ade34528986 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/common.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/common.rs @@ -1,6 +1,6 @@ use crate::{ compile::rewrite::{ - agg_fun_expr, alias_expr, + agg_fun_expr, agg_fun_expr_within_group_empty_tail, alias_expr, analysis::{ConstantFolding, LogicalPlanAnalysis, OriginalExpr}, binary_expr, column_expr, fun_expr, rewriter::{RewriteRules, Rewriter}, @@ -29,10 +29,16 @@ impl RewriteRules for CommonRules { "Sum", vec![binary_expr(column_expr("?column"), "/", "?literal")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), alias_expr( binary_expr( - agg_fun_expr("Sum", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "Sum", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), "/", "?literal", ), @@ -46,10 +52,16 @@ impl RewriteRules for CommonRules { "Sum", vec![binary_expr(column_expr("?column"), "*", "?literal")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), alias_expr( binary_expr( - agg_fun_expr("Sum", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "Sum", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), "*", "?literal", ), diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/flatten/pass_through.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/flatten/pass_through.rs index 26e560a1a3350..d306d624bf93d 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/flatten/pass_through.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/flatten/pass_through.rs @@ -1,7 +1,7 @@ use crate::compile::rewrite::{ - agg_fun_expr, alias_expr, analysis::LogicalPlanAnalysis, binary_expr, cast_expr, - flatten_pushdown_replacer, fun_expr_var_arg, is_not_null_expr, is_null_expr, rewrite, - rules::flatten::FlattenRules, udf_expr_var_arg, LogicalPlanLanguage, + agg_fun_expr, agg_fun_expr_within_group_empty_tail, alias_expr, analysis::LogicalPlanAnalysis, + binary_expr, cast_expr, flatten_pushdown_replacer, fun_expr_var_arg, is_not_null_expr, + is_null_expr, rewrite, rules::flatten::FlattenRules, udf_expr_var_arg, LogicalPlanLanguage, }; use egg::Rewrite; @@ -19,7 +19,14 @@ impl FlattenRules { ); self.single_arg_pass_through_rules( "agg-function", - |expr| agg_fun_expr("?fun", vec![expr], "?distinct"), + |expr| { + agg_fun_expr( + "?fun", + vec![expr], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, rules, ); self.single_arg_pass_through_rules( diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs index 79fa094a995ea..2c2820d50dc94 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs @@ -1,6 +1,6 @@ use crate::{ compile::rewrite::{ - agg_fun_expr, aggregate, alias_expr, all_members, + agg_fun_expr, agg_fun_expr_within_group_empty_tail, aggregate, alias_expr, all_members, analysis::{ ConstantFolding, LogicalPlanAnalysis, LogicalPlanData, MemberNamesToExpr, OriginalExpr, }, @@ -89,7 +89,12 @@ impl RewriteRules for MemberRules { ), self.measure_rewrite( "simple-count", - agg_fun_expr("?aggr_fun", vec![literal_expr("?literal")], "?distinct"), + agg_fun_expr( + "?aggr_fun", + vec![literal_expr("?literal")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), None, Some("?distinct"), Some("?aggr_fun"), @@ -97,7 +102,12 @@ impl RewriteRules for MemberRules { ), self.measure_rewrite( "named", - agg_fun_expr("?aggr_fun", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "?aggr_fun", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), Some("?column"), Some("?distinct"), Some("?aggr_fun"), @@ -114,6 +124,7 @@ impl RewriteRules for MemberRules { "?aggr_fun", vec![cast_expr(column_expr("?column"), "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), Some("?column"), Some("?distinct"), @@ -590,12 +601,22 @@ impl MemberRules { )); rules.extend(find_matching_old_member( "agg-fun", - agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), )); rules.extend(find_matching_old_member( "agg-fun-alias", alias_expr( - agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), "?alias", ), )); @@ -609,6 +630,7 @@ impl MemberRules { "Count", vec![literal_expr("?any")], "AggregateFunctionExprDistinct:false", + agg_fun_expr_within_group_empty_tail(), ), true, )); @@ -619,6 +641,7 @@ impl MemberRules { "Count", vec![literal_expr("?any")], "AggregateFunctionExprDistinct:false", + agg_fun_expr_within_group_empty_tail(), ), "?alias", ), @@ -631,6 +654,7 @@ impl MemberRules { "?fun_name", vec![cast_expr(column_expr("?column"), "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), )); rules.extend(find_matching_old_member( @@ -1014,7 +1038,12 @@ impl MemberRules { ) { rules.push(pushdown_measure_rewrite( "member-pushdown-replacer-agg-fun", - agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), measure_expr("?name", "?old_alias"), Some("?fun_name"), Some("?distinct"), @@ -1024,7 +1053,12 @@ impl MemberRules { rules.push(pushdown_measure_rewrite( "member-pushdown-replacer-agg-fun-alias", alias_expr( - agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), "?alias", ), measure_expr("?name", "?old_alias"), @@ -1039,6 +1073,7 @@ impl MemberRules { "?fun_name", vec![cast_expr(column_expr("?column"), "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), measure_expr("?name", "?old_alias"), Some("?fun_name"), @@ -1048,7 +1083,12 @@ impl MemberRules { )); rules.push(pushdown_measure_rewrite( "member-pushdown-replacer-agg-fun-on-dimension", - agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), dimension_expr("?name", "?old_alias"), Some("?fun_name"), Some("?distinct"), @@ -1075,7 +1115,12 @@ impl MemberRules { )); rules.push(pushdown_measure_rewrite( "member-pushdown-replacer-agg-fun-default-count", - agg_fun_expr("?fun_name", vec![literal_expr("?any")], "?distinct"), + agg_fun_expr( + "?fun_name", + vec![literal_expr("?any")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), measure_expr("?name", "?old_alias"), Some("?fun_name"), Some("?distinct"), @@ -1085,7 +1130,12 @@ impl MemberRules { rules.push(pushdown_measure_rewrite( "member-pushdown-replacer-agg-fun-default-count-alias", alias_expr( - agg_fun_expr("?fun_name", vec![literal_expr("?any")], "?distinct"), + agg_fun_expr( + "?fun_name", + vec![literal_expr("?any")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), "?alias", ), measure_expr("?name", "?old_alias"), diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs index 2b141f124ee75..a0ae4fa0da981 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs @@ -1,9 +1,9 @@ use super::utils; use crate::{ compile::rewrite::{ - agg_fun_expr, aggr_aggr_expr_empty_tail, aggr_aggr_expr_legacy as aggr_aggr_expr, - aggr_group_expr_empty_tail, aggr_group_expr_legacy as aggr_group_expr, aggregate, - alias_expr, + agg_fun_expr, agg_fun_expr_within_group_empty_tail, aggr_aggr_expr_empty_tail, + aggr_aggr_expr_legacy as aggr_aggr_expr, aggr_group_expr_empty_tail, + aggr_group_expr_legacy as aggr_group_expr, aggregate, alias_expr, analysis::{ConstantFolding, LogicalPlanAnalysis, OriginalExpr}, binary_expr, cast_expr, cast_expr_explicit, column_expr, cube_scan, event_notification, fun_expr, group_aggregate_split_replacer, group_expr_split_replacer, @@ -491,6 +491,7 @@ impl RewriteRules for OldSplitRules { "?fun", vec![cast_expr(literal_expr("?expr"), "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), "?cube", ), @@ -504,6 +505,7 @@ impl RewriteRules for OldSplitRules { "?fun", vec![cast_expr(literal_expr("?expr"), "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), "?cube", ), @@ -511,6 +513,7 @@ impl RewriteRules for OldSplitRules { "?fun", vec![cast_expr(literal_expr("?expr"), "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), self.transform_aggr_fun_with_literal("?fun", "?expr"), ), @@ -521,6 +524,7 @@ impl RewriteRules for OldSplitRules { "?fun", vec![cast_expr(literal_expr("?expr"), "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), "?cube", ), @@ -528,6 +532,7 @@ impl RewriteRules for OldSplitRules { "?fun", vec![cast_expr(literal_expr("?expr"), "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), self.transform_aggr_fun_with_literal("?fun", "?expr"), ), @@ -1453,7 +1458,7 @@ impl RewriteRules for OldSplitRules { transforming_chain_rewrite( "split-push-down-aggr-fun-with-date-trunc-inner-aggr-replacer", inner_aggregate_split_replacer( - agg_fun_expr("?fun", vec!["?expr".to_string()], "?distinct"), + agg_fun_expr("?fun", vec!["?expr".to_string()], "?distinct", agg_fun_expr_within_group_empty_tail()), "?cube", ), vec![( @@ -1486,7 +1491,7 @@ impl RewriteRules for OldSplitRules { transforming_chain_rewrite( "split-push-down-aggr-fun-with-date-trunc-outer-aggr-replacer", outer_aggregate_split_replacer( - agg_fun_expr("?fun", vec!["?expr".to_string()], "?distinct"), + agg_fun_expr("?fun", vec!["?expr".to_string()], "?distinct", agg_fun_expr_within_group_empty_tail()), "?cube", ), vec![( @@ -1500,6 +1505,7 @@ impl RewriteRules for OldSplitRules { "?fun", vec![alias_expr("?alias_column", "?alias")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), MemberRules::transform_original_expr_date_trunc( "?expr", @@ -1517,23 +1523,23 @@ impl RewriteRules for OldSplitRules { "?aggr_expr", "?cube", ), - vec![("?aggr_expr", agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"))], + vec![("?aggr_expr", agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct", agg_fun_expr_within_group_empty_tail()))], "?out_expr".to_string(), self.transform_inner_measure("?cube", Some("?column"), Some("?aggr_expr"), Some("?fun"), Some("?distinct"), Some("?out_expr")), ), transforming_rewrite( "split-push-down-aggr-fun-inner-replacer-simple-count", inner_aggregate_split_replacer( - agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct"), + agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct", agg_fun_expr_within_group_empty_tail()), "?cube", ), - agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct"), + agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct", agg_fun_expr_within_group_empty_tail()), self.transform_inner_measure("?cube", None, None, None, None, None), ), transforming_rewrite( "split-push-down-aggr-fun-inner-replacer-missing-count", inner_aggregate_split_replacer( - agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct"), + agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct", agg_fun_expr_within_group_empty_tail()), "?cube", ), aggr_aggr_expr_empty_tail(), @@ -1544,7 +1550,7 @@ impl RewriteRules for OldSplitRules { outer_projection_split_replacer("?expr", "?cube"), vec![( "?expr", - agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"), + agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct", agg_fun_expr_within_group_empty_tail()), )], "?alias".to_string(), self.transform_outer_projection_aggr_fun("?cube", "?expr", Some("?column"), "?alias"), @@ -1554,7 +1560,7 @@ impl RewriteRules for OldSplitRules { outer_projection_split_replacer("?expr", "?cube"), vec![( "?expr", - agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct"), + agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct", agg_fun_expr_within_group_empty_tail()), )], "?alias".to_string(), self.transform_outer_projection_aggr_fun("?cube", "?expr", None, "?alias"), @@ -1563,11 +1569,11 @@ impl RewriteRules for OldSplitRules { "split-push-down-aggr-fun-outer-aggr-replacer", outer_aggregate_split_replacer("?expr", "?cube"), vec![ - ("?expr", agg_fun_expr("?fun", vec!["?arg"], "?distinct")), + ("?expr", agg_fun_expr("?fun", vec!["?arg"], "?distinct", agg_fun_expr_within_group_empty_tail())), ("?arg", column_expr("?column")), ], alias_expr( - agg_fun_expr("?output_fun", vec!["?alias".to_string()], "?distinct"), + agg_fun_expr("?output_fun", vec!["?alias".to_string()], "?distinct", agg_fun_expr_within_group_empty_tail()), "?outer_alias", ), self.transform_outer_aggr_fun( @@ -1588,11 +1594,11 @@ impl RewriteRules for OldSplitRules { "split-push-down-aggr-fun-outer-aggr-replacer-simple-count", outer_aggregate_split_replacer("?expr", "?cube"), vec![ - ("?expr", agg_fun_expr("?fun", vec!["?arg"], "?distinct")), + ("?expr", agg_fun_expr("?fun", vec!["?arg"], "?distinct", agg_fun_expr_within_group_empty_tail())), ("?arg", literal_expr("?literal")), ], alias_expr( - agg_fun_expr("?output_fun", vec!["?alias".to_string()], "?distinct"), + agg_fun_expr("?output_fun", vec!["?alias".to_string()], "?distinct", agg_fun_expr_within_group_empty_tail()), "?outer_alias", ), self.transform_outer_aggr_fun( @@ -1612,17 +1618,17 @@ impl RewriteRules for OldSplitRules { transforming_rewrite( "split-push-down-aggr-fun-outer-aggr-replacer-missing-count", outer_aggregate_split_replacer( - agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct"), + agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct", agg_fun_expr_within_group_empty_tail()), "?cube", ), - agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct"), + agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct", agg_fun_expr_within_group_empty_tail()), self.transform_outer_aggr_fun_missing_count("?cube", "?fun"), ), transforming_chain_rewrite( "split-push-down-aggr-fun-dateadd-outer-aggr-replacer", outer_aggregate_split_replacer("?expr", "?cube"), vec![ - ("?expr", agg_fun_expr("?fun", vec!["?arg"], "?distinct")), + ("?expr", agg_fun_expr("?fun", vec!["?arg"], "?distinct", agg_fun_expr_within_group_empty_tail())), ( "?arg", udf_expr( @@ -1653,6 +1659,7 @@ impl RewriteRules for OldSplitRules { ], )], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), "?outer_alias", ), @@ -1681,11 +1688,12 @@ impl RewriteRules for OldSplitRules { ArrowDataType::Float64, )], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), "?alias_to_cube", ), inner_aggregate_split_replacer( - agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"), + agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct", agg_fun_expr_within_group_empty_tail()), "?alias_to_cube", ), ), @@ -1701,12 +1709,13 @@ impl RewriteRules for OldSplitRules { ArrowDataType::Float64, )], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), )], alias_expr( cast_expr_explicit( outer_projection_split_replacer( - agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"), + agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct", agg_fun_expr_within_group_empty_tail()), "?alias_to_cube", ), ArrowDataType::Float64, @@ -1727,12 +1736,13 @@ impl RewriteRules for OldSplitRules { ArrowDataType::Float64, )], "?distinct", + agg_fun_expr_within_group_empty_tail(), ), )], alias_expr( cast_expr_explicit( outer_aggregate_split_replacer( - agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"), + agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct", agg_fun_expr_within_group_empty_tail()), "?alias_to_cube", ), ArrowDataType::Float64, @@ -1747,7 +1757,7 @@ impl RewriteRules for OldSplitRules { transforming_chain_rewrite( "split-push-down-aggr-min-max-date-trunc-fun-inner-replacer", inner_aggregate_split_replacer( - agg_fun_expr("?fun", vec!["?arg"], "?distinct"), + agg_fun_expr("?fun", vec!["?arg"], "?distinct", agg_fun_expr_within_group_empty_tail()), "?cube", ), vec![("?arg", column_expr("?column"))], @@ -1765,7 +1775,7 @@ impl RewriteRules for OldSplitRules { transforming_chain_rewrite( "split-push-down-aggr-min-max-dimension-fun-inner-replacer", inner_aggregate_split_replacer( - agg_fun_expr("?fun", vec!["?arg"], "?distinct"), + agg_fun_expr("?fun", vec!["?arg"], "?distinct", agg_fun_expr_within_group_empty_tail()), "?cube", ), vec![("?arg", column_expr("?column"))], @@ -1777,7 +1787,7 @@ impl RewriteRules for OldSplitRules { transforming_chain_rewrite( "split-push-down-aggr-min-max-dimension-fun-dateadd-inner-replacer", inner_aggregate_split_replacer( - agg_fun_expr("?fun", vec!["?arg"], "?distinct"), + agg_fun_expr("?fun", vec!["?arg"], "?distinct", agg_fun_expr_within_group_empty_tail()), "?cube", ), vec![( @@ -1804,6 +1814,7 @@ impl RewriteRules for OldSplitRules { "ApproxDistinct", vec![column_expr("?column")], "AggregateFunctionExprDistinct:false", + agg_fun_expr_within_group_empty_tail(), ), "?alias_to_cube", ), @@ -1817,6 +1828,7 @@ impl RewriteRules for OldSplitRules { "ApproxDistinct", vec![column_expr("?column")], "AggregateFunctionExprDistinct:false", + agg_fun_expr_within_group_empty_tail(), ), "?alias_to_cube", ), @@ -1827,6 +1839,7 @@ impl RewriteRules for OldSplitRules { "?alias_to_cube", )], "AggregateFunctionExprDistinct:false", + agg_fun_expr_within_group_empty_tail(), ), self.transform_outer_aggr_dimension("?alias_to_cube", "?column"), ), @@ -1837,6 +1850,7 @@ impl RewriteRules for OldSplitRules { "Count", vec![column_expr("?column")], "AggregateFunctionExprDistinct:true", + agg_fun_expr_within_group_empty_tail(), ), "?alias_to_cube", ), @@ -1850,6 +1864,7 @@ impl RewriteRules for OldSplitRules { "Count", vec![column_expr("?column")], "AggregateFunctionExprDistinct:true", + agg_fun_expr_within_group_empty_tail(), ), "?alias_to_cube", ), @@ -1860,6 +1875,7 @@ impl RewriteRules for OldSplitRules { "?alias_to_cube", )], "AggregateFunctionExprDistinct:true", + agg_fun_expr_within_group_empty_tail(), ), self.transform_outer_aggr_dimension("?alias_to_cube", "?column"), ), @@ -2340,7 +2356,7 @@ impl RewriteRules for OldSplitRules { "split-count-distinct-to-sum-notification", outer_aggregate_split_replacer("?agg_fun", "?cube"), vec![ - ("?agg_fun", agg_fun_expr("?fun", vec!["?arg"], "?distinct")), + ("?agg_fun", agg_fun_expr("?fun", vec!["?arg"], "?distinct", agg_fun_expr_within_group_empty_tail())), ("?arg", column_expr("?column")), ("?fun", "AggregateFunctionExprFun:Count".to_string()), ( @@ -2355,6 +2371,7 @@ impl RewriteRules for OldSplitRules { "Sum", vec!["?alias".to_string()], "AggregateFunctionExprDistinct:false".to_string(), + agg_fun_expr_within_group_empty_tail(), ), "?outer_alias", ), @@ -4733,6 +4750,7 @@ impl OldSplitRules { "?output_fun", vec![applier(group_aggregate_split_replacer)], "?distinct", + agg_fun_expr_within_group_empty_tail(), ) }, self.transform_group_aggregate_measure( @@ -4761,9 +4779,21 @@ impl OldSplitRules { ), transforming_chain_rewrite( &format!("{}-unwrap-group-aggr-agg-fun", base_name), - applier(|expr, _| agg_fun_expr("?output_fun", vec![expr], "?distinct")), + applier(|expr, _| { + agg_fun_expr( + "?output_fun", + vec![expr], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }), unwrap_agg_chain, - agg_fun_expr("?output_fun", vec![column_expr("?column")], "?distinct"), + agg_fun_expr( + "?output_fun", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ), |_, _| true, ), ] @@ -5191,9 +5221,14 @@ impl OldSplitRules { vec![column_expr, tail], ), ); + let within_group = egraph.add( + LogicalPlanLanguage::AggregateFunctionExprWithinGroup( + vec![], + ), + ); let aggr_expr = egraph.add( LogicalPlanLanguage::AggregateFunctionExpr( - [measure_fun, args, measure_distinct], + [measure_fun, args, measure_distinct, within_group], ), ); let alias = egraph.add( @@ -5521,6 +5556,7 @@ impl OldSplitRules { fun: utils::reaggragate_fun(&agg_type)?, args: vec![expr], distinct: false, + within_group: None, }; let expr_name = aggr_expr.name(&DFSchema::empty()).ok()?; diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/split/aggregate_function.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/split/aggregate_function.rs index ca01488d6e278..bcbb1503be4e6 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/split/aggregate_function.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/split/aggregate_function.rs @@ -1,6 +1,6 @@ use crate::{ compile::rewrite::{ - agg_fun_expr, alias_expr, + agg_fun_expr, agg_fun_expr_within_group_empty_tail, alias_expr, analysis::{ConstantFolding, LogicalPlanAnalysis}, case_expr, cast_expr, column_expr, is_null_expr, literal_expr, literal_int, rules::{members::MemberRules, split::SplitRules}, @@ -21,10 +21,31 @@ impl SplitRules { ) { self.single_arg_split_point_rules_aggregate_function( "aggregate-function", - || agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), - || agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), + || { + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, + || { + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, // ?distinct would always match - |alias_column| agg_fun_expr("?output_fun_name", vec![alias_column], "?distinct"), + |alias_column| { + agg_fun_expr( + "?output_fun_name", + vec![alias_column], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, |alias_column| alias_column, self.transform_aggregate_function( Some("?fun_name"), @@ -54,15 +75,24 @@ impl SplitRules { "?data_type", )], "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, + || { + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), ) }, - || agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), // ?distinct would always match |alias_column| { agg_fun_expr( "?output_fun_name", vec![cast_expr(alias_column, "?data_type")], "?distinct", + agg_fun_expr_within_group_empty_tail(), ) }, |alias_column| cast_expr(alias_column, "?data_type"), @@ -86,10 +116,31 @@ impl SplitRules { ); self.single_arg_split_point_rules_aggregate_function( "aggregate-function-simple-count", - || agg_fun_expr("?fun_name", vec![literal_expr("?literal")], "?distinct"), - || agg_fun_expr("?fun_name", vec![literal_expr("?literal")], "?distinct"), + || { + agg_fun_expr( + "?fun_name", + vec![literal_expr("?literal")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, + || { + agg_fun_expr( + "?fun_name", + vec![literal_expr("?literal")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, // ?distinct would always match - |alias_column| agg_fun_expr("?output_fun_name", vec![alias_column], "?distinct"), + |alias_column| { + agg_fun_expr( + "?output_fun_name", + vec![alias_column], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, |alias_column| alias_column, self.transform_aggregate_function( Some("?fun_name"), @@ -111,16 +162,31 @@ impl SplitRules { ); self.single_arg_split_point_rules_aggregate_function( "aggregate-function-non-matching-count", - || agg_fun_expr("?fun_name", vec![column_expr("?column")], "?distinct"), + || { + agg_fun_expr( + "?fun_name", + vec![column_expr("?column")], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, || { agg_fun_expr( "Count", vec![literal_int(1)], "AggregateFunctionExprDistinct:false", + agg_fun_expr_within_group_empty_tail(), ) }, // ?distinct would always match - |alias_column| agg_fun_expr("?output_fun_name", vec![alias_column], "?distinct"), + |alias_column| { + agg_fun_expr( + "?output_fun_name", + vec![alias_column], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, |alias_column| alias_column, self.transform_aggregate_function( Some("?fun_name"), @@ -142,15 +208,30 @@ impl SplitRules { ); self.single_arg_split_point_rules_aggregate_function( "aggregate-function-sum-count-constant", - || agg_fun_expr("?fun_name", vec![literal_int(1)], "?distinct"), + || { + agg_fun_expr( + "?fun_name", + vec![literal_int(1)], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, || { agg_fun_expr( "Count", vec![literal_int(1)], "AggregateFunctionExprDistinct:false", + agg_fun_expr_within_group_empty_tail(), + ) + }, + |alias_column| { + agg_fun_expr( + "?output_fun_name", + vec![alias_column], + "?distinct", + agg_fun_expr_within_group_empty_tail(), ) }, - |alias_column| agg_fun_expr("?output_fun_name", vec![alias_column], "?distinct"), |alias_column| alias_column, self.transform_aggregate_function( Some("?fun_name"), @@ -172,10 +253,24 @@ impl SplitRules { ); self.single_arg_split_point_rules( "aggregate-function-invariant-constant", - || agg_fun_expr("?fun_name", vec!["?constant"], "?distinct"), + || { + agg_fun_expr( + "?fun_name", + vec!["?constant"], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, || "?constant".to_string(), // ?distinct would always match - |alias_column| agg_fun_expr("?fun_name", vec![alias_column], "?distinct"), + |alias_column| { + agg_fun_expr( + "?fun_name", + vec![alias_column], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, self.transform_invariant_constant("?fun_name", "?distinct", "?constant"), false, rules, @@ -193,10 +288,18 @@ impl SplitRules { Some(literal_int(0)), )], "?distinct", + agg_fun_expr_within_group_empty_tail(), ) }, || literal_int(0), - |alias_column| agg_fun_expr("Max", vec![alias_column], "?distinct"), + |alias_column| { + agg_fun_expr( + "Max", + vec![alias_column], + "?distinct", + agg_fun_expr_within_group_empty_tail(), + ) + }, |alias_column| alias_column, self.transform_powerbi_max_case("?column", "?alias_to_cube"), self.transform_powerbi_max_case("?column", "?alias_to_cube"), diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs index 3d5280f017912..26107983adc21 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs @@ -1,9 +1,9 @@ use crate::{ compile::rewrite::{ - agg_fun_expr, analysis::LogicalPlanAnalysis, rewrite, rules::wrapper::WrapperRules, - transforming_rewrite, wrapper_pullup_replacer, wrapper_pushdown_replacer, - AggregateFunctionExprDistinct, AggregateFunctionExprFun, LogicalPlanLanguage, - WrapperPullupReplacerAliasToCube, + agg_fun_expr, agg_fun_expr_within_group, agg_fun_expr_within_group_empty_tail, + analysis::LogicalPlanAnalysis, rewrite, rules::wrapper::WrapperRules, transforming_rewrite, + wrapper_pullup_replacer, wrapper_pushdown_replacer, AggregateFunctionExprDistinct, + AggregateFunctionExprFun, LogicalPlanLanguage, WrapperPullupReplacerAliasToCube, }, var, var_iter, }; @@ -11,7 +11,7 @@ use datafusion::physical_plan::aggregates::AggregateFunction; use egg::{EGraph, Rewrite, Subst}; impl WrapperRules { - pub fn window_function_rules( + pub fn aggregate_function_rules( &self, rules: &mut Vec>, ) { @@ -19,7 +19,7 @@ impl WrapperRules { rewrite( "wrapper-push-down-aggregate-function", wrapper_pushdown_replacer( - agg_fun_expr("?fun", vec!["?expr"], "?distinct"), + agg_fun_expr("?fun", vec!["?expr"], "?distinct", "?within_group"), "?alias_to_cube", "?ungrouped", "?in_projection", @@ -35,6 +35,13 @@ impl WrapperRules { "?cube_members", )], "?distinct", + wrapper_pushdown_replacer( + "?within_group", + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), ), ), transforming_rewrite( @@ -49,15 +56,91 @@ impl WrapperRules { "?cube_members", )], "?distinct", + wrapper_pullup_replacer( + "?within_group", + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + ), + wrapper_pullup_replacer( + agg_fun_expr("?fun", vec!["?expr"], "?distinct", "?within_group"), + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + self.transform_agg_fun_expr("?fun", "?distinct", "?alias_to_cube", "?within_group"), + ), + rewrite( + "wrapper-push-down-aggregate-function-within-group", + wrapper_pushdown_replacer( + agg_fun_expr_within_group("?left", "?right"), + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + agg_fun_expr_within_group( + wrapper_pushdown_replacer( + "?left", + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + wrapper_pushdown_replacer( + "?right", + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + ), + ), + rewrite( + "wrapper-push-down-aggregate-function-within-group-empty-tail", + wrapper_pushdown_replacer( + agg_fun_expr_within_group_empty_tail(), + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", ), wrapper_pullup_replacer( - agg_fun_expr("?fun", vec!["?expr"], "?distinct"), + agg_fun_expr_within_group_empty_tail(), + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + ), + rewrite( + "wrapper-pull-up-aggregate-function-within-group", + agg_fun_expr_within_group( + wrapper_pullup_replacer( + "?left", + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + wrapper_pullup_replacer( + "?right", + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + ), + wrapper_pullup_replacer( + agg_fun_expr_within_group("?left", "?right"), "?alias_to_cube", "?ungrouped", "?in_projection", "?cube_members", ), - self.transform_agg_fun_expr("?fun", "?distinct", "?alias_to_cube"), ), ]); } @@ -67,10 +150,12 @@ impl WrapperRules { fun_var: &'static str, distinct_var: &'static str, alias_to_cube_var: &'static str, + within_group_var: &'static str, ) -> impl Fn(&mut EGraph, &mut Subst) -> bool { let fun_var = var!(fun_var); let distinct_var = var!(distinct_var); let alias_to_cube_var = var!(alias_to_cube_var); + let within_group_var = var!(within_group_var); let meta = self.meta_context.clone(); move |egraph, subst| { for alias_to_cube in var_iter!( @@ -91,12 +176,32 @@ impl WrapperRules { fun.to_string() }; - if sql_generator + if !sql_generator .get_sql_templates() .templates .contains_key(&format!("functions/{}", fun.as_str())) { - return true; + continue; + } + + for within_group_node in &egraph[subst[within_group_var]].nodes { + match within_group_node { + LogicalPlanLanguage::AggregateFunctionExprWithinGroup( + nodes, + ) => { + if nodes.len() > 0 { + if !sql_generator + .get_sql_templates() + .templates + .contains_key("expressions/within_group") + { + continue; + } + } + return true; + } + _ => (), + } } } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window_function.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window_function.rs index 73bb990bc4b4c..fcbd1fefced23 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window_function.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window_function.rs @@ -10,7 +10,7 @@ use datafusion::physical_plan::windows::WindowFunction; use egg::{EGraph, Rewrite, Subst}; impl WrapperRules { - pub fn aggregate_function_rules( + pub fn window_function_rules( &self, rules: &mut Vec>, ) { diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index 544ff648bdf4b..3254de10d1e12 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -491,6 +491,7 @@ pub fn sql_generator( ("functions/RIGHT".to_string(), "RIGHT({{ args_concat }})".to_string()), ("functions/LOWER".to_string(), "LOWER({{ args_concat }})".to_string()), ("functions/UPPER".to_string(), "UPPER({{ args_concat }})".to_string()), + ("functions/PERCENTILECONT".to_string(), "PERCENTILE_CONT({{ args_concat }})".to_string()), ("expressions/extract".to_string(), "EXTRACT({{ date_part }} FROM {{ expr }})".to_string()), ( "statements/select".to_string(), @@ -535,6 +536,7 @@ OFFSET {{ offset }}{% endif %}"#.to_string(), ("expressions/like".to_string(), "{{ expr }} {% if negated %}NOT {% endif %}LIKE {{ pattern }}".to_string()), ("expressions/ilike".to_string(), "{{ expr }} {% if negated %}NOT {% endif %}ILIKE {{ pattern }}".to_string()), ("expressions/like_escape".to_string(), "{{ like_expr }} ESCAPE {{ escape_char }}".to_string()), + ("expressions/within_group".to_string(), "{{ fun_sql }} WITHIN GROUP (ORDER BY {{ within_group_concat }})".to_string()), ("quotes/identifiers".to_string(), "\"".to_string()), ("quotes/escape".to_string(), "\"\"".to_string()), ("params/param".to_string(), "${{ param_index + 1 }}".to_string()), diff --git a/rust/cubesql/cubesql/src/sql/statement.rs b/rust/cubesql/cubesql/src/sql/statement.rs index 74ac31e4b16f2..06ea4292a2338 100644 --- a/rust/cubesql/cubesql/src/sql/statement.rs +++ b/rust/cubesql/cubesql/src/sql/statement.rs @@ -7,6 +7,7 @@ use pg_srv::{ }; use sqlparser::ast::{ self, ArrayAgg, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value, + WithinGroup, }; use std::{collections::HashMap, error::Error}; @@ -177,9 +178,6 @@ trait Visitor<'ast, E: Error> { } } } - for order_expr in list_agg.within_group.iter_mut() { - self.visit_expr(&mut order_expr.expr)?; - } } Expr::GroupingSets(vec) | Expr::Cube(vec) | Expr::Rollup(vec) => { for v in vec.iter_mut() { @@ -229,6 +227,12 @@ trait Visitor<'ast, E: Error> { self.visit_expr(limit)?; } } + Expr::WithinGroup(WithinGroup { expr, order_by }) => { + self.visit_expr(expr)?; + for order_by_expr in order_by { + self.visit_expr(&mut order_by_expr.expr)?; + } + } }; Ok(()) diff --git a/rust/cubesql/cubesql/src/transport/service.rs b/rust/cubesql/cubesql/src/transport/service.rs index df0b384600594..8d226d17cd27c 100644 --- a/rust/cubesql/cubesql/src/transport/service.rs +++ b/rust/cubesql/cubesql/src/transport/service.rs @@ -542,12 +542,22 @@ impl SqlTemplates { aggregate_function: AggregateFunction, args: Vec, distinct: bool, + within_group: Vec, ) -> Result { let function = self.aggregate_function_name(aggregate_function, distinct); let args_concat = args.join(", "); - self.render_template( + let sql = self.render_template( &format!("functions/{}", function), context! { args_concat => args_concat, args => args, distinct => distinct }, + ); + if within_group.len() == 0 { + return sql; + } + + let within_group_concat = within_group.join(", "); + self.render_template( + "expressions/within_group", + context! { fun_sql => sql?, within_group_concat => within_group_concat }, ) } diff --git a/rust/cubesqlplanner/Cargo.lock b/rust/cubesqlplanner/Cargo.lock index 58acf427b2091..634b05ad76ed4 100644 --- a/rust/cubesqlplanner/Cargo.lock +++ b/rust/cubesqlplanner/Cargo.lock @@ -639,7 +639,7 @@ dependencies = [ [[package]] name = "cube-ext" version = "1.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "arrow", "chrono", @@ -758,7 +758,7 @@ dependencies = [ [[package]] name = "datafusion" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -791,7 +791,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "arrow", "ordered-float 2.10.1", @@ -802,7 +802,7 @@ dependencies = [ [[package]] name = "datafusion-data-access" version = "1.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "async-trait", "chrono", @@ -815,7 +815,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -826,7 +826,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=dcf3e4aa26fd112043ef26fa4a78db5dbd443c86#dcf3e4aa26fd112043ef26fa4a78db5dbd443c86" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=11a4ed10b184b2f1b22f7458702ae0c63f011241#11a4ed10b184b2f1b22f7458702ae0c63f011241" dependencies = [ "ahash 0.7.8", "arrow", @@ -2775,7 +2775,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" version = "0.16.0" -source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=6a54d27d3b75a04b9f9cbe309a83078aa54b32fd#6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" +source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=5fe1b77d1a91b80529a0b7af0b89411d3cba5137#5fe1b77d1a91b80529a0b7af0b89411d3cba5137" dependencies = [ "log", ] diff --git a/rust/cubesqlplanner/cubesqlplanner/Cargo.toml b/rust/cubesqlplanner/cubesqlplanner/Cargo.toml index 309e341b5f4fd..e25867ccf2f85 100644 --- a/rust/cubesqlplanner/cubesqlplanner/Cargo.toml +++ b/rust/cubesqlplanner/cubesqlplanner/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "dcf3e4aa26fd112043ef26fa4a78db5dbd443c86", default-features = false, features = ["regex_expressions", "unicode_expressions"] } +datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "11a4ed10b184b2f1b22f7458702ae0c63f011241", default-features = false, features = ["regex_expressions", "unicode_expressions"] } tokio = { version = "^1.35", features = ["full", "rt", "tracing"] } itertools = "0.10.2" cubeclient = { path = "../../cubesql/cubeclient" }