Skip to content

Commit

Permalink
fix(cubesql): Don't clone AST on pre-planning step (#8644)
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr authored Aug 27, 2024
1 parent 4366299 commit 03277b0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 45 deletions.
2 changes: 1 addition & 1 deletion rust/cubesql/cubesql/src/compile/query_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ impl QueryEngine for SqlQueryEngine {
}

fn sanitize_statement(&self, stmt: &Self::AstStatementType) -> Self::AstStatementType {
SensitiveDataSanitizer::new().replace(&stmt)
SensitiveDataSanitizer::new().replace(stmt.clone())
}
}

Expand Down
21 changes: 12 additions & 9 deletions rust/cubesql/cubesql/src/compile/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ impl QueryRouter {
("query".to_string(), stmt.to_string()),
(
"sanitizedQuery".to_string(),
SensitiveDataSanitizer::new().replace(stmt).to_string(),
SensitiveDataSanitizer::new()
.replace(stmt.clone())
.to_string(),
),
]));
let msg = err.message();
Expand Down Expand Up @@ -712,25 +714,26 @@ impl QueryRouter {
}
}

pub fn rewrite_statement(stmt: &ast::Statement) -> ast::Statement {
pub fn rewrite_statement(stmt: ast::Statement) -> ast::Statement {
let stmt = CastReplacer::new().replace(stmt);
let stmt = ToTimestampReplacer::new().replace(&stmt);
let stmt = UdfWildcardArgReplacer::new().replace(&stmt);
let stmt = DateTokenNormalizeReplacer::new().replace(&stmt);
let stmt = RedshiftDatePartReplacer::new().replace(&stmt);
let stmt = ApproximateCountDistinctVisitor::new().replace(&stmt);
let stmt = ToTimestampReplacer::new().replace(stmt);
let stmt = UdfWildcardArgReplacer::new().replace(stmt);
let stmt = DateTokenNormalizeReplacer::new().replace(stmt);
let stmt = RedshiftDatePartReplacer::new().replace(stmt);
let stmt = ApproximateCountDistinctVisitor::new().replace(stmt);

stmt
}

pub async fn convert_statement_to_cube_query(
stmt: &ast::Statement,
stmt: ast::Statement,
meta: Arc<MetaContext>,
session: Arc<Session>,
qtrace: &mut Option<Qtrace>,
span_id: Option<Arc<SpanId>>,
) -> CompilationResult<QueryPlan> {
let stmt = rewrite_statement(stmt);

if let Some(qtrace) = qtrace {
qtrace.set_visitor_replaced_statement(&stmt);
}
Expand All @@ -745,5 +748,5 @@ pub async fn convert_sql_to_cube_query(
session: Arc<Session>,
) -> CompilationResult<QueryPlan> {
let stmt = parse_sql_to_statement(&query, session.state.protocol.clone(), &mut None)?;
convert_statement_to_cube_query(&stmt, meta, session, &mut None, None).await
convert_statement_to_cube_query(stmt, meta, session, &mut None, None).await
}
8 changes: 4 additions & 4 deletions rust/cubesql/cubesql/src/compile/test/rewrite_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ pub async fn create_test_postgresql_cube_context(

pub fn query_to_logical_plan(query: String, context: &CubeContext) -> LogicalPlan {
let stmt = parse_sql_to_statement(&query, DatabaseProtocol::PostgreSQL, &mut None).unwrap();
let stmt = rewrite_statement(&stmt);
let stmt = rewrite_statement(stmt);
let df_query_planner = SqlToRel::new_with_options(context, true);

return df_query_planner
.statement_to_plan(Statement::Statement(Box::new(stmt.clone())))
.unwrap();
df_query_planner
.statement_to_plan(Statement::Statement(Box::new(stmt)))
.unwrap()
}

pub fn rewrite_runner(plan: LogicalPlan, context: Arc<CubeContext>) -> CubeRunner {
Expand Down
12 changes: 6 additions & 6 deletions rust/cubesql/cubesql/src/sql/postgres/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ impl AsyncPostgresShim {
.await?;

let plan = convert_statement_to_cube_query(
&prepared_statement,
prepared_statement,
meta,
self.session.clone(),
&mut None,
Expand Down Expand Up @@ -1132,10 +1132,10 @@ impl AsyncPostgresShim {
.await?;

let stmt_replacer = StatementPlaceholderReplacer::new();
let hacked_query = stmt_replacer.replace(&query)?;
let hacked_query = stmt_replacer.replace(query.clone())?;

let plan = convert_statement_to_cube_query(
&hacked_query,
hacked_query,
meta,
self.session.clone(),
qtrace,
Expand Down Expand Up @@ -1393,7 +1393,7 @@ impl AsyncPostgresShim {
})?;

let plan = convert_statement_to_cube_query(
&cursor.query,
cursor.query.clone(),
meta,
self.session.clone(),
qtrace,
Expand Down Expand Up @@ -1475,7 +1475,7 @@ impl AsyncPostgresShim {
let select_stmt = Statement::Query(query);
// It's just a verification that we can compile that query.
let _ = convert_statement_to_cube_query(
&select_stmt,
select_stmt.clone(),
meta.clone(),
self.session.clone(),
&mut None,
Expand Down Expand Up @@ -1648,7 +1648,7 @@ impl AsyncPostgresShim {
}
other => {
let plan = convert_statement_to_cube_query(
&other,
other,
meta.clone(),
self.session.clone(),
qtrace,
Expand Down
62 changes: 37 additions & 25 deletions rust/cubesql/cubesql/src/sql/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,8 @@ impl StatementPlaceholderReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> Result<ast::Statement, ConnectionError> {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> Result<ast::Statement, ConnectionError> {
let mut result = stmt;

self.visit_statement(&mut result)?;

Expand Down Expand Up @@ -671,8 +671,8 @@ impl CastReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -813,8 +813,8 @@ impl DateTokenNormalizeReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -874,8 +874,8 @@ impl RedshiftDatePartReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -930,8 +930,8 @@ impl ToTimestampReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand All @@ -957,8 +957,8 @@ impl UdfWildcardArgReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -1046,8 +1046,8 @@ impl ApproximateCountDistinctVisitor {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -1075,8 +1075,8 @@ impl SensitiveDataSanitizer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -1113,10 +1113,13 @@ mod tests {
use sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

fn run_cast_replacer(input: &str, output: &str) -> Result<(), CubeError> {
let stmts = Parser::parse_sql(&PostgreSqlDialect {}, &input).unwrap();
let stmt = Parser::parse_sql(&PostgreSqlDialect {}, &input)
.unwrap()
.pop()
.expect("must contain at least one statement");

let replacer = CastReplacer::new();
let res = replacer.replace(&stmts[0]);
let res = replacer.replace(stmt);

assert_eq!(res.to_string(), output);

Expand Down Expand Up @@ -1144,7 +1147,7 @@ mod tests {
let stmts = Parser::parse_sql(&PostgreSqlDialect {}, &input).unwrap();

let replacer = RedshiftDatePartReplacer::new();
let res = replacer.replace(&stmts[0]);
let res = replacer.replace(stmts[0].clone());

assert_eq!(res.to_string(), output);

Expand Down Expand Up @@ -1176,10 +1179,13 @@ mod tests {
output: &str,
values: Vec<BindValue>,
) -> Result<(), ConnectionError> {
let stmts = Parser::parse_sql(&PostgreSqlDialect {}, &input).unwrap();
let stmt = Parser::parse_sql(&PostgreSqlDialect {}, &input)
.unwrap()
.pop()
.expect("must contain at least one statement");

let binder = PostgresStatementParamsBinder::new(values);
let mut res = stmts[0].clone();
let mut res = stmt;
binder.bind(&mut res)?;

assert_eq!(res.to_string(), output);
Expand Down Expand Up @@ -1353,10 +1359,13 @@ mod tests {
}

fn assert_placeholder_replacer(input: &str, output: &str) -> Result<(), CubeError> {
let stmts = Parser::parse_sql(&PostgreSqlDialect {}, &input).unwrap();
let stmt = Parser::parse_sql(&PostgreSqlDialect {}, &input)
.unwrap()
.pop()
.expect("must contain at least one statement");

let binder = StatementPlaceholderReplacer::new();
let result = binder.replace(&stmts[0]).unwrap();
let result = binder.replace(stmt).unwrap();

assert_eq!(result.to_string(), output);

Expand All @@ -1377,10 +1386,13 @@ mod tests {
}

fn assert_sensitive_data_sanitizer(input: &str, output: &str) -> Result<(), CubeError> {
let stmts = Parser::parse_sql(&PostgreSqlDialect {}, &input).unwrap();
let stmt = Parser::parse_sql(&PostgreSqlDialect {}, &input)
.unwrap()
.pop()
.expect("must contain at least one statement");

let binder = SensitiveDataSanitizer::new();
let result = binder.replace(&stmts[0]);
let result = binder.replace(stmt);

assert_eq!(result.to_string(), output);

Expand Down

0 comments on commit 03277b0

Please sign in to comment.