diff --git a/rust/cubesql/cubesql/src/sql/postgres/extended.rs b/rust/cubesql/cubesql/src/sql/postgres/extended.rs index f039609e25faa..61636b2518b56 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/extended.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/extended.rs @@ -58,6 +58,13 @@ pub enum PreparedStatement { description: Option, span_id: Option>, }, + Error { + /// Prepared statement can be declared from SQL or protocol (Parser) + from_sql: bool, + sql: String, + created: DateTime, + span_id: Option>, + }, } impl PreparedStatement { @@ -65,6 +72,7 @@ impl PreparedStatement { match self { PreparedStatement::Empty { created, .. } => created, PreparedStatement::Query { created, .. } => created, + PreparedStatement::Error { created, .. } => created, } } @@ -73,6 +81,7 @@ impl PreparedStatement { match self { PreparedStatement::Empty { .. } => "".to_string(), PreparedStatement::Query { query, .. } => query.to_string(), + PreparedStatement::Error { sql, .. } => sql.clone(), } } @@ -80,6 +89,7 @@ impl PreparedStatement { match self { PreparedStatement::Empty { from_sql, .. } => from_sql.clone(), PreparedStatement::Query { from_sql, .. } => from_sql.clone(), + PreparedStatement::Error { from_sql, .. } => from_sql.clone(), } } @@ -87,6 +97,7 @@ impl PreparedStatement { match self { PreparedStatement::Empty { .. } => None, PreparedStatement::Query { parameters, .. } => Some(¶meters.parameters), + PreparedStatement::Error { .. } => None, } } @@ -103,6 +114,10 @@ impl PreparedStatement { Ok(statement) } + PreparedStatement::Error { .. } => Err(CubeError::internal( + "It's not possible to bind errored prepared statements (it's a bug)".to_string(), + ) + .into()), } } @@ -110,6 +125,7 @@ impl PreparedStatement { match self { PreparedStatement::Empty { span_id, .. } => span_id.clone(), PreparedStatement::Query { span_id, .. } => span_id.clone(), + PreparedStatement::Error { span_id, .. } => span_id.clone(), } } } diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs index 8de9ac405c5ff..184ce0ec880b1 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs @@ -374,99 +374,132 @@ impl AsyncPostgresShim { protocol::FrontendMessage::Terminate => return Ok(()), // Extended protocol::FrontendMessage::Parse(body) => { - if tracked_error.is_none() { - doing_extended_query_message = true; - let mut qtrace = Qtrace::new(&body.query); - let span_id = Self::new_span_id(body.query.clone()); - if let Some(qtrace) = &qtrace { - debug!("Assigned query UUID: {}", qtrace.uuid()) - } - if let Some(auth_context) = self.session.state.auth_context() { - self.session - .session_manager - .server - .transport - .log_load_state( - span_id.clone(), - auth_context, - self.session.state.get_load_request_meta(), - "Load Request".to_string(), - serde_json::json!({ - "query": span_id.as_ref().unwrap().query_key.clone(), - }), - ) - .await?; - } - let result = self - .parse(body, &mut qtrace, span_id.clone()) - .await - .map_err(|e| e.with_span_id(span_id)); - if let Err(err) = &result { - if let Some(qtrace) = &mut qtrace { - qtrace.set_query_error_message(&err.to_string()) - } - }; - if let Some(qtrace) = &qtrace { - qtrace.save_json() - } - result - } else { + if tracked_error.is_some() { continue; } + doing_extended_query_message = true; + let mut qtrace = Qtrace::new(&body.query); + let span_id = Self::new_span_id(body.query.clone()); + if let Some(qtrace) = &qtrace { + debug!("Assigned query UUID: {}", qtrace.uuid()) + } + if let Some(auth_context) = self.session.state.auth_context() { + self.session + .session_manager + .server + .transport + .log_load_state( + span_id.clone(), + auth_context, + self.session.state.get_load_request_meta(), + "Load Request".to_string(), + serde_json::json!({ + "query": span_id.as_ref().unwrap().query_key.clone(), + // Hide query by default until Execute + "isDataQuery": false, + }), + ) + .await?; + } + let result = self + .parse(body, &mut qtrace, span_id.clone()) + .await + .map_err(|e| e.with_span_id(span_id)); + if let Err(err) = &result { + if let Some(qtrace) = &mut qtrace { + qtrace.set_query_error_message(&err.to_string()) + } + }; + if let Some(qtrace) = &qtrace { + qtrace.save_json() + } + result } protocol::FrontendMessage::Bind(body) => { if tracked_error.is_none() { doing_extended_query_message = true; - let span_id = { - let statements_guard = self.session.state.statements.read().await; - statements_guard - .get(&body.statement) - .and_then(|s| s.span_id()) - }; - self.bind(body, span_id).await - } else { - continue; } + let span_id = { + let statements_guard = self.session.state.statements.read().await; + statements_guard + .get(&body.statement) + .and_then(|s| s.span_id()) + }; + self.bind(body, span_id).await } protocol::FrontendMessage::Execute(body) => { - if tracked_error.is_none() { - doing_extended_query_message = true; - let span_id = if let Some(portal) = self.portals.get(&body.portal) { - portal.span_id() - } else { - None - }; - let result = self - .execute(body) - .await - .map_err(|e| e.with_span_id(span_id.clone())); - if result.is_ok() { - if let Some(auth_context) = self.session.state.auth_context() { - if let Some(span_id) = span_id { - self.session - .session_manager - .server - .transport - .log_load_state( - Some(span_id.clone()), - auth_context, - self.session.state.get_load_request_meta(), - "Load Request Success".to_string(), - serde_json::json!({ - "query": span_id.query_key.clone(), - "apiType": "sql", - "duration": span_id.duration(), - "isDataQuery": span_id.is_data_query().await - }), - ) - .await?; - } + let span_id = self + .portals + .get(&body.portal) + .and_then(|portal| portal.span_id()); + if tracked_error.is_some() { + if let Some(auth_context) = self.session.state.auth_context() { + if let Some(span_id) = span_id { + // If there was an error, always show the query + self.session + .session_manager + .server + .transport + .log_load_state( + Some(span_id.clone()), + auth_context, + self.session.state.get_load_request_meta(), + "Data Query Status".to_string(), + serde_json::json!({ + "isDataQuery": true + }), + ) + .await?; } } - result - } else { continue; } + doing_extended_query_message = true; + let result = self + .execute(body) + .await + .map_err(|e| e.with_span_id(span_id.clone())); + if let Some(auth_context) = self.session.state.auth_context() { + if let Some(span_id) = span_id { + // Always indicate whether this is a data query + // Errors are always visible ("data queries") + if result.is_err() { + self.session + .session_manager + .server + .transport + .log_load_state( + Some(span_id.clone()), + auth_context.clone(), + self.session.state.get_load_request_meta(), + "Data Query Status".to_string(), + serde_json::json!({ + "isDataQuery": true, + }), + ) + .await?; + } else { + self.session + .session_manager + .server + .transport + .log_load_state( + Some(span_id.clone()), + auth_context, + self.session.state.get_load_request_meta(), + "Load Request Success".to_string(), + serde_json::json!({ + "query": span_id.query_key.clone(), + "apiType": "sql", + "duration": span_id.duration(), + "isDataQuery": span_id.is_data_query().await + }), + ) + .await?; + } + } + } + result } protocol::FrontendMessage::Close(body) => { if tracked_error.is_none() { @@ -881,6 +914,10 @@ impl AsyncPostgresShim { self.write(packet.clone()).await } }, + PreparedStatement::Error { .. } => Err(CubeError::internal( + "Describe called on errored prepared statement (it's a bug)".to_string(), + ) + .into()), }, } } @@ -1039,6 +1076,11 @@ impl AsyncPostgresShim { Portal::new(plan, format, PortalFrom::Extended, span_id) } + PreparedStatement::Error { .. } => { + drop(statements_guard); + + Portal::new_empty(format, PortalFrom::Extended, span_id) + } }; self.portals.insert(body.portal, portal); @@ -1064,12 +1106,26 @@ impl AsyncPostgresShim { }, ); } else { - let query = parse_sql_to_statement(&parse.query, DatabaseProtocol::PostgreSQL, qtrace)?; - if let Some(qtrace) = qtrace { - qtrace.push_statement(&query); + match parse_sql_to_statement(&parse.query, DatabaseProtocol::PostgreSQL, qtrace) { + Ok(query) => { + if let Some(qtrace) = qtrace { + qtrace.push_statement(&query); + } + self.prepare_statement(parse.name, Ok(query), false, qtrace, span_id.clone()) + .await?; + } + Err(err) => { + self.prepare_statement( + parse.name, + Err(parse.query.to_string()), + false, + qtrace, + span_id.clone(), + ) + .await?; + Err(err)?; + } } - self.prepare_statement(parse.name, query, false, qtrace, span_id.clone()) - .await?; } self.write(protocol::ParseComplete::new()).await?; @@ -1080,7 +1136,7 @@ impl AsyncPostgresShim { pub async fn prepare_statement( &mut self, name: String, - query: Statement, + query: Result, from_sql: bool, qtrace: &mut Option, span_id: Option>, @@ -1106,49 +1162,78 @@ impl AsyncPostgresShim { )); } - let stmt_finder = PostgresStatementParamsFinder::new(); - let parameters: Vec = stmt_finder - .find(&query)? - .into_iter() - .map(|param| param.coltype.to_pg_tid()) - .collect(); + let (pstmt, result) = match query { + Ok(query) => { + let stmt_finder = PostgresStatementParamsFinder::new(); + let parameters: Vec = stmt_finder + .find(&query)? + .into_iter() + .map(|param| param.coltype.to_pg_tid()) + .collect(); - let meta = self - .session - .server - .compiler_cache - .meta(self.auth_context()?, self.session.state.protocol.clone()) - .await?; - - let stmt_replacer = StatementPlaceholderReplacer::new(); - let hacked_query = stmt_replacer.replace(query.clone())?; + let meta = self + .session + .server + .compiler_cache + .meta(self.auth_context()?, self.session.state.protocol.clone()) + .await?; - let plan = convert_statement_to_cube_query( - hacked_query, - meta, - self.session.clone(), - qtrace, - span_id.clone(), - ) - .await?; + let stmt_replacer = StatementPlaceholderReplacer::new(); + let hacked_query = stmt_replacer.replace(query.clone())?; - let description = if let Some(description) = plan.to_row_description(Format::Text)? { - if description.len() > 0 { - Some(description) - } else { - None + let plan = convert_statement_to_cube_query( + hacked_query, + meta, + self.session.clone(), + qtrace, + span_id.clone(), + ) + .await; + + match plan { + Ok(plan) => { + let description = + plan.to_row_description(Format::Text)? + .and_then(|description| { + if description.len() > 0 { + Some(description) + } else { + None + } + }); + + ( + PreparedStatement::Query { + from_sql, + created: chrono::offset::Utc::now(), + query, + parameters: protocol::ParameterDescription::new(parameters), + description, + span_id, + }, + Ok(()), + ) + } + Err(err) => ( + PreparedStatement::Error { + from_sql, + sql: query.to_string(), + created: chrono::offset::Utc::now(), + span_id, + }, + Err(err.into()), + ), + } } - } else { - None - }; - - let pstmt = PreparedStatement::Query { - from_sql, - created: chrono::offset::Utc::now(), - query, - parameters: protocol::ParameterDescription::new(parameters), - description, - span_id, + Err(sql) => ( + PreparedStatement::Error { + from_sql, + sql, + created: chrono::offset::Utc::now(), + span_id, + }, + Ok(()), + ), }; self.session .state @@ -1157,7 +1242,7 @@ impl AsyncPostgresShim { .await .insert(name, pstmt); - Ok(()) + result } pub fn end_transaction(&mut self) -> Result { @@ -1623,7 +1708,7 @@ impl AsyncPostgresShim { _ => *statement, }; - self.prepare_statement(name.value, statement, true, qtrace, span_id.clone()) + self.prepare_statement(name.value, Ok(statement), true, qtrace, span_id.clone()) .await?; let plan = QueryPlan::MetaOk(StatusFlags::empty(), CommandCompletion::Prepare);