Skip to content

Commit

Permalink
feat(cubesql): CASE WHEN SQL push down (#7029)
Browse files Browse the repository at this point in the history
* feat(cubesql): CASE WHEN SQL push down

* Enable SQL push down for legacy tests
  • Loading branch information
paveltiunov authored Aug 13, 2023
1 parent 0e8a76a commit 80e4a60
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/workflows/rust-cubesql.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ jobs:
env:
CUBESQL_TESTING_CUBE_TOKEN: ${{ secrets.CUBESQL_TESTING_CUBE_TOKEN }}
CUBESQL_TESTING_CUBE_URL: ${{ secrets.CUBESQL_TESTING_CUBE_URL }}
CUBESQL_SQL_PUSH_DOWN: true
run: cd rust/cubesql && cargo test

native_linux:
Expand Down
4 changes: 4 additions & 0 deletions packages/cubejs-schema-compiler/src/adapter/BaseQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -2402,7 +2402,11 @@ class BaseQuery {
select: 'SELECT {{ select_concat | map(attribute=\'aliased\') | join(\', \') }} \n' +
'FROM (\n {{ from }}\n) AS {{ from_alias }} \n' +
'{% if group_by %} GROUP BY {{ group_by | map(attribute=\'index\') | join(\', \') }}{% endif %}',
},
expressions: {
column_aliased: '{{expr}} {{quoted_alias}}',
case: 'CASE {% if expr %}{{ expr }} {% endif %}{% for when, then in when_then %}WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END',
binary: '{{ left }} {{ op }} {{ right }}'
},
quotes: {
identifiers: '"',
Expand Down
88 changes: 85 additions & 3 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl CubeScanWrapperNode {
plan.clone(),
sql_query,
sql_generator.clone(),
(*expr).clone(),
*expr,
)
.await?;
Ok((expr, sql_query))
Expand Down Expand Up @@ -423,7 +423,32 @@ impl CubeScanWrapperNode {
sql_query,
)),
// Expr::ScalarVariable(_, _) => {}
// Expr::BinaryExpr { .. } => {}
Expr::BinaryExpr { left, op, right } => {
let (left, sql_query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
*left,
)
.await?;
let (right, sql_query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
*right,
)
.await?;
let resulting_sql = sql_generator
.get_sql_templates()
.binary_expr(left, op.to_string(), right)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for binary expr: {}",
e
))
})?;
Ok((resulting_sql, sql_query))
}
// Expr::AnyExpr { .. } => {}
// Expr::Like(_) => {}
// Expr::ILike(_) => {}
Expand All @@ -434,7 +459,64 @@ impl CubeScanWrapperNode {
// Expr::Negative(_) => {}
// Expr::GetIndexedField { .. } => {}
// Expr::Between { .. } => {}
// Expr::Case { .. } => {}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
let expr = if let Some(expr) = expr {
let (expr, sql_query_next) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
*expr,
)
.await?;
sql_query = sql_query_next;
Some(expr)
} else {
None
};
let mut when_then_expr_sql = Vec::new();
for (when, then) in when_then_expr {
let (when, sql_query_next) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
*when,
)
.await?;
let (then, sql_query_next) = Self::generate_sql_for_expr(
plan.clone(),
sql_query_next,
sql_generator.clone(),
*then,
)
.await?;
sql_query = sql_query_next;
when_then_expr_sql.push((when, then));
}
let else_expr = if let Some(else_expr) = else_expr {
let (else_expr, sql_query_next) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
*else_expr,
)
.await?;
sql_query = sql_query_next;
Some(else_expr)
} else {
None
};
let resulting_sql = sql_generator
.get_sql_templates()
.case(expr, when_then_expr_sql, else_expr)
.map_err(|e| {
DataFusionError::Internal(format!("Can't generate SQL for case: {}", e))
})?;
Ok((resulting_sql, sql_query))
}
// Expr::Cast { .. } => {}
// Expr::TryCast { .. } => {}
// Expr::Sort { .. } => {}
Expand Down
63 changes: 49 additions & 14 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,8 @@ mod tests {

fn find_cube_scan(&self) -> CubeScanNode;

fn find_cube_scan_wrapper(&self) -> CubeScanWrapperNode;

fn find_cube_scans(&self) -> Vec<CubeScanNode>;

fn find_filter(&self) -> Option<Filter>;
Expand Down Expand Up @@ -1736,6 +1738,20 @@ mod tests {
cube_scans[0].clone()
}

/// Returns a clone of the `CubeScanWrapperNode` sitting at the root of this plan.
///
/// # Panics
///
/// Panics when the root node is not a `LogicalPlan::Extension`, or when the
/// extension node cannot be downcast to `CubeScanWrapperNode`.
fn find_cube_scan_wrapper(&self) -> CubeScanWrapperNode {
    match self {
        LogicalPlan::Extension(Extension { node }) => node
            .as_any()
            .downcast_ref::<CubeScanWrapperNode>()
            .cloned()
            .unwrap_or_else(|| panic!("Root plan node is not cube_scan_wrapper!")),
        _ => panic!("Root plan node is not extension!"),
    }
}

fn find_cube_scans(&self) -> Vec<CubeScanNode> {
find_cube_scans_deep_search(Arc::new(self.clone()), true)
}
Expand Down Expand Up @@ -17929,20 +17945,39 @@ ORDER BY \"COUNT(count)\" DESC"
)
.await;

// let logical_plan = query_plan.as_logical_plan();
// assert_eq!(
// logical_plan.find_cube_scan().request,
// V1LoadRequestQuery {
// measures: Some(vec!["KibanaSampleDataEcommerce.avgPrice".to_string(),]),
// segments: Some(vec![]),
// dimensions: Some(vec![]),
// time_dimensions: None,
// order: None,
// limit: None,
// offset: None,
// filters: None
// }
// );
let logical_plan = query_plan.as_logical_plan();
assert!(logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("COALESCE"));

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
"Physical plan: {}",
displayable(physical_plan.as_ref()).indent()
);
}

#[tokio::test]
async fn test_case_wrapper() {
init_logger();

let query_plan = convert_select_to_query_plan(
"SELECT CASE WHEN customer_gender = 'female' THEN 'f' ELSE 'm' END, MIN(avgPrice) mp FROM (SELECT avgPrice, customer_gender FROM KibanaSampleDataEcommerce LIMIT 1) a GROUP BY 1"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let logical_plan = query_plan.as_logical_plan();
assert!(logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("CASE WHEN"));

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
Expand Down
147 changes: 145 additions & 2 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use crate::{
rewrite::{
agg_fun_expr, aggregate, alias_expr,
analysis::LogicalPlanAnalysis,
column_expr, cube_scan, cube_scan_wrapper, fun_expr_var_arg, literal_expr, projection,
rewrite,
binary_expr, case_expr_var_arg, column_expr, cube_scan, cube_scan_wrapper,
fun_expr_var_arg, literal_expr, projection, rewrite,
rewriter::RewriteRules,
rules::{replacer_pull_up_node, replacer_push_down_node},
scalar_fun_expr_args, scalar_fun_expr_args_empty_tail, transforming_rewrite,
Expand Down Expand Up @@ -339,6 +339,52 @@ impl RewriteRules for WrapperRules {
alias_expr(wrapper_pullup_replacer("?expr", "?alias_to_cube"), "?alias"),
wrapper_pullup_replacer(alias_expr("?expr", "?alias"), "?alias_to_cube"),
),
// Case
rewrite(
"wrapper-push-down-case",
wrapper_pushdown_replacer(
case_expr_var_arg("?when", "?then", "?else"),
"?alias_to_cube",
),
case_expr_var_arg(
wrapper_pushdown_replacer("?when", "?alias_to_cube"),
wrapper_pushdown_replacer("?then", "?alias_to_cube"),
wrapper_pushdown_replacer("?else", "?alias_to_cube"),
),
),
transforming_rewrite(
"wrapper-pull-up-case",
case_expr_var_arg(
wrapper_pullup_replacer("?when", "?alias_to_cube"),
wrapper_pullup_replacer("?then", "?alias_to_cube"),
wrapper_pullup_replacer("?else", "?alias_to_cube"),
),
wrapper_pullup_replacer(
case_expr_var_arg("?when", "?then", "?else"),
"?alias_to_cube",
),
self.transform_case_expr("?alias_to_cube"),
),
// Binary Expr
rewrite(
"wrapper-push-down-binary-expr",
wrapper_pushdown_replacer(binary_expr("?left", "?op", "?right"), "?alias_to_cube"),
binary_expr(
wrapper_pushdown_replacer("?left", "?alias_to_cube"),
"?op",
wrapper_pushdown_replacer("?right", "?alias_to_cube"),
),
),
transforming_rewrite(
"wrapper-pull-up-binary-expr",
binary_expr(
wrapper_pullup_replacer("?left", "?alias_to_cube"),
"?op",
wrapper_pullup_replacer("?right", "?alias_to_cube"),
),
wrapper_pullup_replacer(binary_expr("?left", "?op", "?right"), "?alias_to_cube"),
self.transform_binary_expr("?op", "?alias_to_cube"),
),
// Column
rewrite(
"wrapper-push-down-column",
Expand All @@ -353,6 +399,20 @@ impl RewriteRules for WrapperRules {
),
];

Self::expr_list_pushdown_pullup_rules(&mut rules, "wrapper-case-expr", "CaseExprExpr");

Self::expr_list_pushdown_pullup_rules(
&mut rules,
"wrapper-case-when-expr",
"CaseExprWhenThenExpr",
);

Self::expr_list_pushdown_pullup_rules(
&mut rules,
"wrapper-case-else-expr",
"CaseExprElseExpr",
);

Self::list_pushdown_pullup_rules(
&mut rules,
"wrapper-aggregate-aggr-expr",
Expand Down Expand Up @@ -562,6 +622,63 @@ impl WrapperRules {
}
}

/// Builds the condition closure passed to the `wrapper-pull-up-case` rewrite.
///
/// The returned closure yields `true` as soon as one of the
/// `WrapperPullupReplacerAliasToCube` values bound to `alias_to_cube_var`
/// resolves to a SQL generator whose templates contain the
/// `expressions/case` key, and `false` otherwise.
fn transform_case_expr(
    &self,
    alias_to_cube_var: &'static str,
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
    let alias_to_cube_var = var!(alias_to_cube_var);
    let meta = self.cube_context.meta.clone();
    move |egraph, subst| {
        var_iter!(
            egraph[subst[alias_to_cube_var]],
            WrapperPullupReplacerAliasToCube
        )
        .cloned()
        .any(|alias_to_cube| {
            // No generator for this alias mapping means no support.
            meta.sql_generator_by_alias_to_cube(&alias_to_cube)
                .map_or(false, |sql_generator| {
                    sql_generator
                        .get_sql_templates()
                        .templates
                        .contains_key("expressions/case")
                })
        })
    }
}

/// Builds the condition closure passed to the `wrapper-pull-up-binary-expr` rewrite.
///
/// The returned closure yields `true` as soon as one of the
/// `WrapperPullupReplacerAliasToCube` values bound to `alias_to_cube_var`
/// resolves to a SQL generator whose templates contain the
/// `expressions/binary` key, and `false` otherwise.
///
/// `_operator_var` is accepted but not inspected yet.
fn transform_binary_expr(
    &self,
    _operator_var: &'static str,
    alias_to_cube_var: &'static str,
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
    let alias_to_cube_var = var!(alias_to_cube_var);
    // let operator_var = var!(operator_var);
    let meta = self.cube_context.meta.clone();
    move |egraph, subst| {
        var_iter!(
            egraph[subst[alias_to_cube_var]],
            WrapperPullupReplacerAliasToCube
        )
        .cloned()
        .any(|alias_to_cube| {
            meta.sql_generator_by_alias_to_cube(&alias_to_cube)
                .map_or(false, |sql_generator| {
                    // TODO check supported operators
                    sql_generator
                        .get_sql_templates()
                        .templates
                        .contains_key("expressions/binary")
                })
        })
    }
}

fn list_pushdown_pullup_rules(
rules: &mut Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>,
rule_name: &str,
Expand All @@ -588,4 +705,30 @@ impl WrapperRules {
wrapper_pullup_replacer(substitute_list_node, "?alias_to_cube"),
)]);
}

/// Appends the wrapper replacer rules for the expression list node `list_node`.
///
/// Three groups are added to `rules`, in this order:
/// 1. the rules produced by `replacer_push_down_node` for the pushdown replacer,
/// 2. the rules produced by `replacer_pull_up_node` for the pullup replacer,
/// 3. a single rewrite turning a pushdown replacer around the bare `list_node`
///    into a pullup replacer around it (presumably the empty-list case;
///    confirm against `list_pushdown_pullup_rules`).
fn expr_list_pushdown_pullup_rules(
    rules: &mut Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>,
    rule_name: &str,
    list_node: &str,
) {
    let push_down_rules = replacer_push_down_node(
        rule_name,
        list_node,
        |node| wrapper_pushdown_replacer(node, "?alias_to_cube"),
        false,
    );
    rules.extend(push_down_rules);

    let pull_up_rules = replacer_pull_up_node(rule_name, list_node, list_node, |node| {
        wrapper_pullup_replacer(node, "?alias_to_cube")
    });
    rules.extend(pull_up_rules);

    let turnaround_rule = rewrite(
        rule_name,
        wrapper_pushdown_replacer(list_node, "?alias_to_cube"),
        wrapper_pullup_replacer(list_node, "?alias_to_cube"),
    );
    rules.push(turnaround_rule);
}
}
Loading

0 comments on commit 80e4a60

Please sign in to comment.