From ecaf602f862f06046266da4bf0320a78f0477876 Mon Sep 17 00:00:00 2001 From: Mikhail Cheshkov Date: Wed, 11 Sep 2024 16:25:18 +0300 Subject: [PATCH] refactor(cubesql): Extract CubeScanWrappedSqlNode from CubeScanWrapperNode --- .../cubesql/src/compile/engine/df/scan.rs | 31 ++- .../cubesql/src/compile/engine/df/wrapper.rs | 100 +++++--- rust/cubesql/cubesql/src/compile/mod.rs | 237 ++++-------------- .../cubesql/src/compile/query_engine.rs | 2 + .../cubesql/src/compile/test/test_wrapper.rs | 168 +++---------- .../cubesql/cubesql/src/compile/test/utils.rs | 17 +- 6 files changed, 189 insertions(+), 366 deletions(-) diff --git a/rust/cubesql/cubesql/src/compile/engine/df/scan.rs b/rust/cubesql/cubesql/src/compile/engine/df/scan.rs index 27dd997793e27..928b4c0441bdd 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/scan.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/scan.rs @@ -30,7 +30,7 @@ use std::{ use crate::{ compile::{ - engine::df::wrapper::{CubeScanWrapperNode, SqlQuery}, + engine::df::wrapper::{CubeScanWrappedSqlNode, CubeScanWrapperNode, SqlQuery}, rewrite::WrappedSelectType, test::find_cube_scans_deep_search, }, @@ -386,35 +386,32 @@ impl ExtensionPlanner for CubeScanExtensionPlanner { config_obj: self.config_obj.clone(), })) } else if let Some(wrapper_node) = node.as_any().downcast_ref::() { + return Err(DataFusionError::Internal(format!( + "CubeScanWrapperNode is not executable, SQL should be generated first with QueryEngine::evaluate_wrapped_sql: {:?}", + wrapper_node + ))); + } else if let Some(wrapped_sql_node) = + node.as_any().downcast_ref::() + { // TODO // assert_eq!(logical_inputs.len(), 0, "Inconsistent number of inputs"); // assert_eq!(physical_inputs.len(), 0, "Inconsistent number of inputs"); let scan_node = - find_cube_scans_deep_search(wrapper_node.wrapped_plan.clone(), false) + find_cube_scans_deep_search(wrapped_sql_node.wrapped_plan.clone(), false) .into_iter() .next() .ok_or(DataFusionError::Internal(format!( "No cube scans found in wrapper node: {:?}", - wrapper_node + wrapped_sql_node )))?; - let schema = SchemaRef::new(wrapper_node.schema().as_ref().into()); + let schema = SchemaRef::new(wrapped_sql_node.schema().as_ref().into()); Some(Arc::new(CubeScanExecutionPlan { schema, - member_fields: wrapper_node.member_fields.as_ref().ok_or_else(|| { - DataFusionError::Internal(format!( - "Member fields are not set for wrapper node. Optimization wasn't performed: {:?}", - wrapper_node - )) - })?.clone(), + member_fields: wrapped_sql_node.member_fields.clone(), transport: self.transport.clone(), - request: wrapper_node.request.clone().unwrap_or(scan_node.request.clone()), - wrapped_sql: Some(wrapper_node.wrapped_sql.as_ref().ok_or_else(|| { - DataFusionError::Internal(format!( - "Wrapped SQL is not set for wrapper node. Optimization wasn't performed: {:?}", - wrapper_node - )) - })?.clone()), + request: wrapped_sql_node.request.clone(), + wrapped_sql: Some(wrapped_sql_node.wrapped_sql.clone()), auth_context: scan_node.auth_context.clone(), options: scan_node.options.clone(), meta: self.meta.clone(), diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index b66fa831cab02..8f5f02216262a 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -201,14 +201,75 @@ impl SqlQuery { } } +#[derive(Clone, Debug)] +pub struct CubeScanWrappedSqlNode { + // TODO maybe replace wrapped plan with schema + scan_node + pub wrapped_plan: Arc, + pub wrapped_sql: SqlQuery, + pub request: TransportLoadRequestQuery, + pub member_fields: Vec, +} + +impl CubeScanWrappedSqlNode { + pub fn new( + wrapped_plan: Arc, + sql: SqlQuery, + request: TransportLoadRequestQuery, + member_fields: Vec, + ) -> Self { + Self { + wrapped_plan, + wrapped_sql: sql, + request, + member_fields, + } + } +} + +impl UserDefinedLogicalNode for CubeScanWrappedSqlNode { + fn as_any(&self) -> &dyn Any { + self + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + + fn schema(&self) -> &DFSchemaRef { + self.wrapped_plan.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { + // TODO figure out nice plan for wrapped plan + write!(f, "CubeScanWrappedSql") + } + + fn from_template( + &self, + exprs: &[datafusion::logical_plan::Expr], + inputs: &[datafusion::logical_plan::LogicalPlan], + ) -> std::sync::Arc { + assert_eq!(inputs.len(), 0, "input size inconsistent"); + assert_eq!(exprs.len(), 0, "expression size inconsistent"); + + Arc::new(CubeScanWrappedSqlNode { + wrapped_plan: self.wrapped_plan.clone(), + wrapped_sql: self.wrapped_sql.clone(), + request: self.request.clone(), + member_fields: self.member_fields.clone(), + }) + } +} + #[derive(Debug, Clone)] pub struct CubeScanWrapperNode { pub wrapped_plan: Arc, pub meta: Arc, pub auth_context: AuthContextRef, - pub wrapped_sql: Option, - pub request: Option, - pub member_fields: Option>, pub span_id: Option>, pub config_obj: Arc, } @@ -225,31 +286,10 @@ impl CubeScanWrapperNode { wrapped_plan, meta, auth_context, - wrapped_sql: None, - request: None, - member_fields: None, span_id, config_obj, } } - - pub fn with_sql_and_request( - &self, - sql: SqlQuery, - request: TransportLoadRequestQuery, - member_fields: Vec, - ) -> Self { - Self { - wrapped_plan: self.wrapped_plan.clone(), - meta: self.meta.clone(), - auth_context: self.auth_context.clone(), - wrapped_sql: Some(sql), - request: Some(request), - member_fields: Some(member_fields), - span_id: self.span_id.clone(), - config_obj: self.config_obj.clone(), - } - } } fn expr_name(e: &Expr, schema: &Arc) -> Result { @@ -317,7 +357,7 @@ impl CubeScanWrapperNode { &self, transport: Arc, load_request_meta: Arc, - ) -> result::Result { + ) -> result::Result { let schema = self.schema(); let wrapped_plan = self.wrapped_plan.clone(); let (sql, request, member_fields) = Self::generate_sql_for_node( @@ -361,7 +401,12 @@ impl CubeScanWrapperNode { sql.finalize_query(sql_templates).map_err(|e| CubeError::internal(e.to_string()))?; Ok((sql, request, member_fields)) })?; - Ok(self.with_sql_and_request(sql, request, member_fields)) + Ok(CubeScanWrappedSqlNode::new( + self.wrapped_plan.clone(), + sql, + request, + member_fields, + )) } pub fn set_max_limit_for_node(self, node: Arc) -> Arc { @@ -2226,9 +2271,6 @@ impl UserDefinedLogicalNode for CubeScanWrapperNode { wrapped_plan: self.wrapped_plan.clone(), meta: self.meta.clone(), auth_context: self.auth_context.clone(), - wrapped_sql: self.wrapped_sql.clone(), - request: self.request.clone(), - member_fields: self.member_fields.clone(), span_id: self.span_id.clone(), config_obj: self.config_obj.clone(), }) diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index ee46ecf28abc9..60e3a40ce9089 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -370,11 +370,7 @@ mod tests { ) .await.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LOWER(")); assert!(sql.contains(" IN (")); @@ -385,11 +381,7 @@ mod tests { ) .await.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LOWER(")); assert!(sql.contains(" IN (")); @@ -408,11 +400,7 @@ mod tests { DatabaseProtocol::PostgreSQL, ).await.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LOWER(")); } @@ -2992,17 +2980,15 @@ limit assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("sixteen_charchar_1")); assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("sixteen_charchar_2")); } @@ -7246,11 +7232,7 @@ ORDER BY DatabaseProtocol::PostgreSQL ).await.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!( sql.contains(expected_search_expr), "cast_expr is {}, expected_search_expr is {}", @@ -7582,17 +7564,15 @@ ORDER BY ); assert!(!query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ungrouped")); assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("[\"dim2\",\"asc\"]")); } @@ -7645,9 +7625,8 @@ ORDER BY "source"."str0" ASC ); assert!(!query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ungrouped")); } @@ -11453,11 +11432,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check wrapping for `LOWER(..) <> .. OR .. IS NULL` let re = Regex::new(r"LOWER ?\(.+\) != .+ OR .+ IS NULL").unwrap(); @@ -11490,11 +11465,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check wrapping for `NOT(LOWER(..) IN (..))` let re = Regex::new(r"NOT.+LOWER ?\(.+\).* IN ").unwrap(); @@ -11815,11 +11786,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check wrapping for `NOT(LOWER(..) IN (..)) OR NOT(.. IS NOT NULL)` let re = Regex::new(r"NOT.+LOWER ?\(.+\) IN .+\) OR NOT.+ IS NOT NULL").unwrap(); @@ -12164,11 +12131,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LOWER(")); assert!(sql.contains("GROUP BY ")); @@ -12749,11 +12712,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check wrapping for `NOT(.. IS NULL OR LOWER(..) IN)` let re = Regex::new(r"NOT \(.+ IS NULL OR .*LOWER\(.+ IN ").unwrap(); @@ -12800,11 +12759,7 @@ ORDER BY "source"."str0" ASC let re = Regex::new(r"\(LOWER ?\(.+\) = .+ OR .+LOWER ?\(.+\) = .+\) IN \(TRUE, FALSE\)") .unwrap(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(re.is_match(&sql)); } @@ -13427,9 +13382,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("LEFT")); } @@ -13795,9 +13749,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT")); } @@ -13879,9 +13832,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT")); } @@ -13908,9 +13860,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT")); } @@ -13937,9 +13888,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT")); } @@ -13961,17 +13911,12 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("OVER"), "SQL should contain 'OVER': {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -13998,32 +13943,22 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("long_l_1"), "SQL should contain long_l_1: {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("long_l_1"), "SQL should contain long_l_2: {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -14049,9 +13984,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CURRENT_DATE()")); @@ -14105,11 +14039,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check if contains `CAST(EXTRACT(YEAR FROM ..) || .. || .. || ..)` let re = Regex::new(r"CAST.+EXTRACT.+YEAR FROM(.+ \|\|){3}").unwrap(); @@ -14270,11 +14200,7 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); if Rewriter::sql_push_down_enabled() { - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("EXTRACT(YEAR")); assert!(sql.contains("EXTRACT(MONTH")); @@ -14384,11 +14310,7 @@ ORDER BY "source"."str0" ASC // TODO: split on complex expressions? // CAST(CAST(ta_1.order_date AS Date32) - CAST(CAST(Utf8("1970-01-01") AS Date32) AS Date32) + Int64(3) AS Decimal(38, 10)) if Rewriter::sql_push_down_enabled() { - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("\"limit\":1000")); assert!(sql.contains("% 7")); @@ -14771,11 +14693,7 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); if Rewriter::sql_push_down_enabled() { - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LIMIT 101")); assert!(sql.contains("ORDER BY")); @@ -14889,9 +14807,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("NOT IN (")); } @@ -14960,9 +14877,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("NOT (")); } @@ -14994,9 +14910,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("DATEDIFF(day,")); @@ -15023,11 +14938,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("DATETIME_DIFF(CAST(")); assert!(sql.contains("day)")); @@ -15054,11 +14965,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("DATEDIFF(day,")); assert!(sql.contains("DATE_TRUNC('day',")); @@ -15085,11 +14992,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("CASE WHEN LOWER('day')")); assert!(sql.contains("WHEN 'year' THEN 12 WHEN 'quarter' THEN 3 WHEN 'month' THEN 1 END")); assert!(sql.contains("EXTRACT(EPOCH FROM")); @@ -15124,9 +15027,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("DATEADD(day, 7,")); @@ -15153,11 +15055,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("DATETIME_ADD(CAST(")); assert!(sql.contains("INTERVAL 7 day)")); @@ -15185,11 +15083,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("+ '7 day'::interval")); } @@ -15265,9 +15159,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("DATE(")); } @@ -15304,9 +15197,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT(MONTH FROM ")); } @@ -15347,11 +15239,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("order_date")); assert!(sql.contains("EXTRACT(DAY FROM")) } @@ -15470,11 +15358,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("OFFSET 1\nLIMIT 2")); } @@ -15726,9 +15610,8 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("SELECT DISTINCT ")); @@ -15810,9 +15693,8 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW")); @@ -16354,9 +16236,8 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("LIMIT 250")); @@ -16641,11 +16522,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains(" AS STRING)")); assert!(sql.contains(" AS FLOAT)")); assert!(sql.contains(" AS DOUBLE)")); @@ -16673,11 +16550,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains(" AS STRING)")); assert!(sql.contains(" AS FLOAT64)")); assert!(sql.contains(" AS BIGDECIMAL(38,10))")); @@ -16701,11 +16574,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains(" AS TEXT)")); assert!(sql.contains(" AS REAL)")); assert!(sql.contains(" AS DOUBLE PRECISION)")); @@ -16825,11 +16694,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LIKE ")); assert!(sql.contains("ESCAPE ")); diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs index ac49e217b88fc..daa921e34b148 100644 --- a/rust/cubesql/cubesql/src/compile/query_engine.rs +++ b/rust/cubesql/cubesql/src/compile/query_engine.rs @@ -239,6 +239,8 @@ pub trait QueryEngine { }; log::debug!("Rewrite: {:#?}", rewrite_plan); + // We want to generate SQL early, as a part of planning, and not later (like during execution) + // to catch all SQL generation errors during planning let rewrite_plan = Self::evaluate_wrapped_sql( self.transport_ref().clone(), Arc::new(state.get_load_request_meta()), diff --git a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs index 1eb78f92fa5e6..106f77c959e20 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs @@ -29,9 +29,8 @@ async fn test_simple_wrapper() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("COALESCE")); @@ -53,11 +52,7 @@ async fn test_wrapper_group_by_rollup() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -78,11 +73,7 @@ async fn test_wrapper_group_by_rollup_with_aliases() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -103,11 +94,7 @@ async fn test_wrapper_group_by_rollup_nested() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("ROLLUP(1, 2)")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -128,11 +115,7 @@ async fn test_wrapper_group_by_rollup_nested_from_asterisk() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -153,11 +136,7 @@ async fn test_wrapper_group_by_rollup_nested_with_aliases() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("ROLLUP(1, 2)")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -180,11 +159,7 @@ async fn test_wrapper_group_by_rollup_nested_complex() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("ROLLUP(1), ROLLUP(2), 3, CUBE(4)")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -205,11 +180,7 @@ async fn test_wrapper_group_by_rollup_placeholders() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -230,11 +201,7 @@ async fn test_wrapper_group_by_cube() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Cube")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -255,11 +222,7 @@ async fn test_wrapper_group_by_rollup_complex() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -280,11 +243,7 @@ async fn test_simple_subquery_wrapper_projection_empty_source() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("(SELECT")); assert!(sql.contains("utf8__male__")); @@ -307,11 +266,7 @@ async fn test_simple_subquery_wrapper_filter_empty_source() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("(SELECT")); assert!(sql.contains("utf8__male__")); @@ -334,11 +289,7 @@ async fn test_simple_subquery_wrapper_projection_aggregate_empty_source() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("(SELECT")); assert!(sql.contains("utf8__male__")); @@ -360,11 +311,7 @@ async fn test_simple_subquery_wrapper_filter_in_empty_source() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("IN (SELECT")); assert!(sql.contains("utf8__male__")); @@ -387,11 +334,7 @@ async fn test_simple_subquery_wrapper_filter_and_projection_empty_source() { let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("IN (SELECT")); assert!(sql.contains("(SELECT")); assert!(sql.contains("utf8__male__")); @@ -416,15 +359,13 @@ async fn test_simple_subquery_wrapper_projection() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("(SELECT")); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("\\\\\\\"limit\\\\\\\":1")); @@ -447,9 +388,8 @@ async fn test_simple_subquery_wrapper_projection_aggregate() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("(SELECT")); @@ -472,15 +412,13 @@ async fn test_simple_subquery_wrapper_filter_equal() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("(SELECT")); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("\\\\\\\"limit\\\\\\\":1")); @@ -503,9 +441,8 @@ async fn test_simple_subquery_wrapper_filter_in() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("IN (SELECT")); @@ -529,9 +466,8 @@ async fn test_simple_subquery_wrapper_filter_and_projection() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("IN (SELECT")); @@ -554,9 +490,8 @@ async fn test_case_wrapper() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -592,9 +527,8 @@ async fn test_case_wrapper_distinct() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -621,9 +555,8 @@ async fn test_case_wrapper_alias_with_order() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ORDER BY \"case_when_a_cust\"")); @@ -650,9 +583,8 @@ async fn test_case_wrapper_ungrouped() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -684,9 +616,8 @@ async fn test_case_wrapper_non_strict_match() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -719,9 +650,8 @@ async fn test_case_wrapper_ungrouped_sorted() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ORDER BY")); } @@ -748,9 +678,8 @@ async fn test_case_wrapper_ungrouped_sorted_aliased() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql // TODO test without depend on column name .contains("ORDER BY \"case_when")); @@ -772,25 +701,19 @@ async fn test_case_wrapper_with_internal_limit() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("1123"), "SQL contains 1123: {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -818,19 +741,14 @@ async fn test_case_wrapper_with_system_fields() { assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains( "\\\"cube_name\\\":\\\"KibanaSampleDataEcommerce\\\",\\\"alias\\\":\\\"user\\\"" ), r#"SQL contains `\"cube_name\":\"KibanaSampleDataEcommerce\",\"alias\":\"user\"` {}"#, - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -856,25 +774,19 @@ async fn test_case_wrapper_with_limit() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("1123"), "SQL contains 1123: {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -900,9 +812,8 @@ async fn test_case_wrapper_with_null() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -957,9 +868,8 @@ async fn test_case_wrapper_escaping() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql // Expect 6 backslashes as output is JSON and it's escaped one more time .contains("\\\\\\\\\\\\`")); diff --git a/rust/cubesql/cubesql/src/compile/test/utils.rs b/rust/cubesql/cubesql/src/compile/test/utils.rs index 5193918dc97b4..e22772a655b61 100644 --- a/rust/cubesql/cubesql/src/compile/test/utils.rs +++ b/rust/cubesql/cubesql/src/compile/test/utils.rs @@ -3,14 +3,17 @@ use std::sync::Arc; use datafusion::logical_plan::{plan::Extension, Filter, LogicalPlan, PlanVisitor}; use crate::{ - compile::engine::df::{scan::CubeScanNode, wrapper::CubeScanWrapperNode}, + compile::engine::df::{ + scan::CubeScanNode, + wrapper::{CubeScanWrappedSqlNode, CubeScanWrapperNode}, + }, CubeError, }; pub trait LogicalPlanTestUtils { fn find_cube_scan(&self) -> CubeScanNode; - fn find_cube_scan_wrapper(&self) -> CubeScanWrapperNode; + fn find_cube_scan_wrapped_sql(&self) -> CubeScanWrappedSqlNode; fn find_cube_scans(&self) -> Vec; @@ -27,13 +30,13 @@ impl LogicalPlanTestUtils for LogicalPlan { cube_scans[0].clone() } - fn find_cube_scan_wrapper(&self) -> CubeScanWrapperNode { + fn find_cube_scan_wrapped_sql(&self) -> CubeScanWrappedSqlNode { match self { LogicalPlan::Extension(Extension { node }) => { - if let Some(wrapper_node) = node.as_any().downcast_ref::() { + if let Some(wrapper_node) = node.as_any().downcast_ref::() { wrapper_node.clone() } else { - panic!("Root plan node is not cube_scan_wrapper!"); + panic!("Root plan node is not cube_scan_wrapped_sql!"); } } _ => panic!("Root plan node is not extension!"), @@ -66,6 +69,10 @@ pub fn find_cube_scans_deep_search( ext.node.as_any().downcast_ref::() { wrapper_node.wrapped_plan.accept(self)?; + } else if let Some(wrapper_node) = + ext.node.as_any().downcast_ref::() + { + wrapper_node.wrapped_plan.accept(self)?; } } Ok(true)