Skip to content

Commit

Permalink
chore(cubesql): Support PERCENTILE_CONT planning
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Sep 11, 2024
1 parent db2256d commit ce2065a
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 42 deletions.
14 changes: 7 additions & 7 deletions packages/cubejs-backend-native/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions rust/cubesql/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions rust/cubesql/cubesql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ homepage = "https://cube.dev"

[dependencies]
arc-swap = "1"
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "dcf3e4aa26fd112043ef26fa4a78db5dbd443c86", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "11a4ed10b184b2f1b22f7458702ae0c63f011241", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
anyhow = "1.0"
thiserror = "1.0.50"
cubeclient = { path = "../cubeclient" }
pg-srv = { path = "../pg-srv" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" }
base64 = "0.13.0"
tokio = { version = "^1.35", features = ["full", "rt", "tracing"] }
serde = { version = "^1.0", features = ["derive"] }
Expand Down
34 changes: 25 additions & 9 deletions rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,31 @@ pub fn rewrite(expr: &Expr, map: &HashMap<Column, Option<Expr>>) -> Result<Optio
fun,
args,
distinct,
} => args
.iter()
.map(|arg| rewrite(arg, map))
.collect::<Result<Option<Vec<_>>>>()?
.map(|args| Expr::AggregateFunction {
fun: fun.clone(),
args,
distinct: distinct.clone(),
}),
within_group,

Check warning on line 245 in rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs#L245

Added line #L245 was not covered by tests
} => {
let args = args
.iter()
.map(|arg| rewrite(arg, map))
.collect::<Result<Option<Vec<_>>>>()?;
let within_group = match within_group.as_ref() {
Some(within_group) => within_group
.iter()
.map(|expr| rewrite(expr, map))
.collect::<Result<Option<Vec<_>>>>()?
.map(|within_group| Some(within_group)),
None => Some(None),

Check warning on line 257 in rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs#L247-L257

Added lines #L247 - L257 were not covered by tests
};
if let (Some(args), Some(within_group)) = (args, within_group) {
Some(Expr::AggregateFunction {
fun: fun.clone(),
args,
distinct: distinct.clone(),
within_group,
})

Check warning on line 265 in rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs#L259-L265

Added lines #L259 - L265 were not covered by tests
} else {
None

Check warning on line 267 in rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/optimizers/utils.rs#L267

Added line #L267 was not covered by tests
}
}
Expr::WindowFunction {
fun,
args,
Expand Down
19 changes: 18 additions & 1 deletion rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1827,8 +1827,10 @@ impl CubeScanWrapperNode {
fun,
args,
distinct,
within_group,
} => {
let mut sql_args = Vec::new();
let mut sql_within_group = Vec::new();
for arg in args {
if let AggregateFunction::Count = fun {
if !distinct {
Expand All @@ -1850,10 +1852,25 @@ impl CubeScanWrapperNode {
sql_query = query;
sql_args.push(sql);
}
if let Some(within_group) = within_group {
for expr in within_group {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
expr,
ungrouped_scan_node.clone(),
subqueries.clone(),
)
.await?;
sql_query = query;
sql_within_group.push(sql);

Check warning on line 1867 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1856-L1867

Added lines #L1856 - L1867 were not covered by tests
}
}
Ok((
sql_generator
.get_sql_templates()
.aggregate_function(fun, sql_args, distinct)
.aggregate_function(fun, sql_args, distinct, sql_within_group)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for aggregate function: {}",
Expand Down
23 changes: 22 additions & 1 deletion rust/cubesql/cubesql/src/compile/rewrite/converter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ impl LogicalPlanToLanguageConverter {
fun,
args,
distinct,
within_group,
} => {
let fun = add_expr_data_node!(graph, fun, AggregateFunctionExprFun);
let args = add_expr_list_node!(
Expand All @@ -434,8 +435,18 @@ impl LogicalPlanToLanguageConverter {
flat_list
);
let distinct = add_expr_data_node!(graph, distinct, AggregateFunctionExprDistinct);
let within_group = add_expr_list_node!(
graph,

Check warning on line 439 in rust/cubesql/cubesql/src/compile/rewrite/converter.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/converter.rs#L439

Added line #L439 was not covered by tests
within_group.as_ref().unwrap_or(&vec![]),
query_params,
AggregateFunctionExprWithinGroup,
flat_list
);
graph.add(LogicalPlanLanguage::AggregateFunctionExpr([
fun, args, distinct,
fun,
args,
distinct,
within_group,
]))
}
Expr::WindowFunction {
Expand Down Expand Up @@ -1145,10 +1156,20 @@ pub fn node_to_expr(
let args =
match_expr_list_node!(node_by_id, to_expr, params[1], AggregateFunctionExprArgs);
let distinct = match_data_node!(node_by_id, params[2], AggregateFunctionExprDistinct);
let within_group = match_expr_list_node!(
node_by_id,
to_expr,
params[3],
AggregateFunctionExprWithinGroup
);
Expr::AggregateFunction {
fun,
args,
distinct,
within_group: match within_group.len() {
0 => None,
_ => Some(within_group),

Check warning on line 1171 in rust/cubesql/cubesql/src/compile/rewrite/converter.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/converter.rs#L1171

Added line #L1171 was not covered by tests
},
}
}
LogicalPlanLanguage::WindowFunctionExpr(params) => {
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/compile/rewrite/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ macro_rules! variant_field_struct {
AggregateFunction::ApproxMedian => "ApproxMedian",
AggregateFunction::BoolAnd => "BoolAnd",
AggregateFunction::BoolOr => "BoolOr",
AggregateFunction::PercentileCont => "PercentileCont",
}
);
};
Expand Down
8 changes: 6 additions & 2 deletions rust/cubesql/cubesql/src/compile/rewrite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ crate::plan_to_language! {
fun: AggregateFunction,
args: Vec<Expr>,
distinct: bool,
within_group: Vec<Expr>,
},
WindowFunctionExpr {
fun: WindowFunction,
Expand Down Expand Up @@ -1325,12 +1326,15 @@ fn agg_fun_expr(fun_name: impl Display, args: Vec<impl Display>, distinct: impl
} else {
"AggregateFunctionExprFun:"
};
// TODO
let within_group = "AggregateFunctionExprWithinGroup";
format!(
"(AggregateFunctionExpr {}{} {} {})",
"(AggregateFunctionExpr {}{} {} {} {})",
prefix,
fun_name,
list_expr("AggregateFunctionExprArgs", args),
distinct
distinct,
within_group,
)
}

Expand Down
8 changes: 7 additions & 1 deletion rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5191,9 +5191,14 @@ impl OldSplitRules {
vec![column_expr, tail],
),
);
let within_group = egraph.add(
LogicalPlanLanguage::AggregateFunctionExprWithinGroup(
vec![],
),
);

Check warning on line 5198 in rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs#L5194-L5198

Added lines #L5194 - L5198 were not covered by tests
let aggr_expr = egraph.add(
LogicalPlanLanguage::AggregateFunctionExpr(
[measure_fun, args, measure_distinct],
[measure_fun, args, measure_distinct, within_group],

Check warning on line 5201 in rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs#L5201

Added line #L5201 was not covered by tests
),
);
let alias = egraph.add(
Expand Down Expand Up @@ -5521,6 +5526,7 @@ impl OldSplitRules {
fun: utils::reaggragate_fun(&agg_type)?,
args: vec![expr],
distinct: false,
within_group: None,

Check warning on line 5529 in rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/old_split.rs#L5529

Added line #L5529 was not covered by tests
};
let expr_name =
aggr_expr.name(&DFSchema::empty()).ok()?;
Expand Down
10 changes: 7 additions & 3 deletions rust/cubesql/cubesql/src/sql/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use pg_srv::{
};
use sqlparser::ast::{
self, ArrayAgg, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value,
WithinGroup,
};
use std::{collections::HashMap, error::Error};

Expand Down Expand Up @@ -177,9 +178,6 @@ trait Visitor<'ast, E: Error> {
}
}
}
for order_expr in list_agg.within_group.iter_mut() {
self.visit_expr(&mut order_expr.expr)?;
}
}
Expr::GroupingSets(vec) | Expr::Cube(vec) | Expr::Rollup(vec) => {
for v in vec.iter_mut() {
Expand Down Expand Up @@ -229,6 +227,12 @@ trait Visitor<'ast, E: Error> {
self.visit_expr(limit)?;
}
}
Expr::WithinGroup(WithinGroup { expr, order_by }) => {
self.visit_expr(expr)?;
for order_by_expr in order_by {
self.visit_expr(&mut order_by_expr.expr)?;

Check warning on line 233 in rust/cubesql/cubesql/src/sql/statement.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/sql/statement.rs#L230-L233

Added lines #L230 - L233 were not covered by tests
}
}
};

Ok(())
Expand Down
4 changes: 3 additions & 1 deletion rust/cubesql/cubesql/src/transport/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,12 +539,14 @@ impl SqlTemplates {
aggregate_function: AggregateFunction,
args: Vec<String>,
distinct: bool,
within_group: Vec<String>,
) -> Result<String, CubeError> {
let function = self.aggregate_function_name(aggregate_function, distinct);
let args_concat = args.join(", ");
let within_group_concat = within_group.join(", ");
self.render_template(
&format!("functions/{}", function),
context! { args_concat => args_concat, args => args, distinct => distinct },
context! { args_concat => args_concat, args => args, distinct => distinct, within_group_concat => within_group_concat },
)
}

Expand Down
Loading

0 comments on commit ce2065a

Please sign in to comment.