diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index 53b1eaf6ea5b..4b061dd6e592 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -48,6 +48,7 @@ use query::parser::QueryStatement; use query::QueryEngineRef; use session::context::{Channel, QueryContextRef}; use session::table_name::table_idents_to_full_name; +use set::set_query_timeout; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument}; use sql::statements::set_variables::SetVariables; @@ -338,6 +339,31 @@ impl StatementExecutor { "DATESTYLE" => set_datestyle(set_var.value, query_ctx)?, "CLIENT_ENCODING" => validate_client_encoding(set_var)?, + // TODO: write sqlness test for query timeout variables + // once the proper channel is configured in the test infra. + // The current sqlness test channel is default to Unknown. + "MAX_EXECUTION_TIME" => match query_ctx.channel() { + Channel::Mysql => set_query_timeout(set_var.value, query_ctx)?, + Channel::Postgres => { + query_ctx.set_warning(format!("Unsupported set variable {}", var_name)) + } + _ => { + return NotSupportedSnafu { + feat: format!("Unsupported set variable {}", var_name), + } + .fail() + } + }, + "STATEMENT_TIMEOUT" => { + if query_ctx.channel() == Channel::Postgres { + set_query_timeout(set_var.value, query_ctx)? + } else { + return NotSupportedSnafu { + feat: format!("Unsupported set variable {}", var_name), + } + .fail(); + } + } _ => { // for postgres, we give unknown SET statements a warning with // success, this is prevent the SET call becoming a blocker diff --git a/src/operator/src/statement/set.rs b/src/operator/src/statement/set.rs index 6436f136d9c5..11f81407b082 100644 --- a/src/operator/src/statement/set.rs +++ b/src/operator/src/statement/set.rs @@ -12,7 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::time::Duration; + use common_time::Timezone; +use lazy_static::lazy_static; +use regex::Regex; +use session::context::Channel::Postgres; use session::context::QueryContextRef; use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle}; use snafu::{ensure, OptionExt, ResultExt}; @@ -21,6 +26,15 @@ use sql::statements::set_variables::SetVariables; use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result}; +lazy_static! { + // Regex rules: + // The string must start with a number (one or more digits). + // The number must be followed by one of the valid time units (ms, s, min, h, d). + // The string must end immediately after the unit, meaning there can be no extra + // characters or spaces after the valid time specification. + static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap(); +} + pub fn set_timezone(exprs: Vec, ctx: QueryContextRef) -> Result<()> { let tz_expr = exprs.first().context(NotSupportedSnafu { feat: "No timezone find in set variable statement", @@ -177,3 +191,87 @@ pub fn set_datestyle(exprs: Vec, ctx: QueryContextRef) -> Result<()> { .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order)); Ok(()) } + +pub fn set_query_timeout(exprs: Vec, ctx: QueryContextRef) -> Result<()> { + let timeout_expr = exprs.first().context(NotSupportedSnafu { + feat: "No timeout value find in set query timeout statement", + })?; + match timeout_expr { + Expr::Value(Value::Number(timeout, _)) => { + match timeout.parse::() { + Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)), + Err(_) => { + return NotSupportedSnafu { + feat: format!("Invalid timeout expr {} in set variable statement", timeout), + } + .fail() + } + } + Ok(()) + } + // postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms'; + Expr::Value(Value::SingleQuotedString(timeout)) + | Expr::Value(Value::DoubleQuotedString(timeout)) => { + if ctx.channel() != Postgres { + return NotSupportedSnafu { + feat: format!("Invalid timeout expr {} in set variable statement", timeout), + } + .fail(); + } + let timeout = parse_pg_query_timeout_input(timeout)?; + ctx.set_query_timeout(Duration::from_millis(timeout)); + Ok(()) + } + expr => NotSupportedSnafu { + feat: format!( + "Unsupported timeout expr {} in set variable statement", + expr + ), + } + .fail(), + } +} + +// support time units in ms, s, min, h, d for postgres protocol. +// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days). +fn parse_pg_query_timeout_input(input: &str) -> Result { + if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) { + let value = captures[1].parse::().expect("regex failed"); + let unit = &captures[2]; + + match unit { + "ms" => Ok(value), + "s" => Ok(value * 1000), + "min" => Ok(value * 60 * 1000), + "h" => Ok(value * 60 * 60 * 1000), + "d" => Ok(value * 24 * 60 * 60 * 1000), + _ => unreachable!("regex failed"), + } + } else { + NotSupportedSnafu { + feat: format!( + "Unsupported timeout expr {} in set variable statement", + input + ), + } + .fail() + } +} + +#[cfg(test)] +mod test { + use crate::statement::set::parse_pg_query_timeout_input; + + #[test] + fn test_parse_pg_query_timeout_input() { + assert!(parse_pg_query_timeout_input("").is_err()); + assert!(parse_pg_query_timeout_input(" 50 ms").is_err()); + assert!(parse_pg_query_timeout_input("5s 1ms").is_err()); + assert!(parse_pg_query_timeout_input("3a").is_err()); + assert!(parse_pg_query_timeout_input("1.5min").is_err()); + + assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap()); + assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap()); + assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap()); + } +} diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index 5679cd5dc43d..172961d50a1f 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -48,7 +48,7 @@ use datatypes::vectors::StringVector; use object_store::ObjectStore; use once_cell::sync::Lazy; use regex::Regex; -use session::context::QueryContextRef; +use session::context::{Channel, QueryContextRef}; pub use show_create_table::create_table_stmt; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::Ident; @@ -651,6 +651,23 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result< let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style(); format!("{}, {}", style, order) } + "MAX_EXECUTION_TIME" => { + if query_ctx.channel() == Channel::Mysql { + query_ctx.query_timeout_as_millis().to_string() + } else { + return UnsupportedVariableSnafu { name: variable }.fail(); + } + } + "STATEMENT_TIMEOUT" => { + // Add time units to postgres query timeout display. + if query_ctx.channel() == Channel::Postgres { + let mut timeout = query_ctx.query_timeout_as_millis().to_string(); + timeout.push_str("ms"); + timeout + } else { + return UnsupportedVariableSnafu { name: variable }.fail(); + } + } _ => return UnsupportedVariableSnafu { name: variable }.fail(), }; let schema = Arc::new(Schema::new(vec![ColumnSchema::new( diff --git a/src/session/src/context.rs b/src/session/src/context.rs index f85a8ceea313..cab351176b21 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; +use std::time::Duration; use api::v1::region::RegionRequestHeader; use arc_swap::ArcSwap; @@ -282,6 +283,22 @@ impl QueryContext { pub fn set_warning(&self, msg: String) { self.mutable_query_context_data.write().unwrap().warning = Some(msg); } + + pub fn query_timeout(&self) -> Option { + self.mutable_session_data.read().unwrap().query_timeout + } + + pub fn query_timeout_as_millis(&self) -> u128 { + let timeout = self.mutable_session_data.read().unwrap().query_timeout; + if let Some(t) = timeout { + return t.as_millis(); + } + 0 + } + + pub fn set_query_timeout(&self, timeout: Duration) { + self.mutable_session_data.write().unwrap().query_timeout = Some(timeout); + } } impl QueryContextBuilder { diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 33bd140c7057..5ddaae7eb579 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -18,6 +18,7 @@ pub mod table_name; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; +use std::time::Duration; use auth::UserInfoRef; use common_catalog::build_db_string; @@ -45,6 +46,7 @@ pub(crate) struct MutableInner { schema: String, user_info: UserInfoRef, timezone: Timezone, + query_timeout: Option, } impl Default for MutableInner { @@ -53,6 +55,7 @@ impl Default for MutableInner { schema: DEFAULT_SCHEMA_NAME.into(), user_info: auth::userinfo_by_name(None), timezone: get_timezone(None).clone(), + query_timeout: None, } } } diff --git a/src/sql/src/parsers/set_var_parser.rs b/src/sql/src/parsers/set_var_parser.rs index e2a7db9d08a2..8a66269803cc 100644 --- a/src/sql/src/parsers/set_var_parser.rs +++ b/src/sql/src/parsers/set_var_parser.rs @@ -58,47 +58,83 @@ mod tests { use crate::dialect::GreptimeDbDialect; use crate::parser::ParseOptions; - fn assert_mysql_parse_result(sql: &str) { + fn assert_mysql_parse_result(sql: &str, indent_str: &str, expr: Expr) { let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); let mut stmts = result.unwrap(); assert_eq!( stmts.pop().unwrap(), Statement::SetVariables(SetVariables { - variable: ObjectName(vec![Ident::new("time_zone")]), - value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))] + variable: ObjectName(vec![Ident::new(indent_str)]), + value: vec![expr] }) ); } - fn assert_pg_parse_result(sql: &str) { + fn assert_pg_parse_result(sql: &str, indent: &str, expr: Expr) { let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); let mut stmts = result.unwrap(); assert_eq!( stmts.pop().unwrap(), Statement::SetVariables(SetVariables { - variable: ObjectName(vec![Ident::new("TIMEZONE")]), - value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))], + variable: ObjectName(vec![Ident::new(indent)]), + value: vec![expr], }) ); } #[test] pub fn test_set_timezone() { + let expected_utc_expr = Expr::Value(Value::SingleQuotedString("UTC".to_string())); // mysql style let sql = "SET time_zone = 'UTC'"; - assert_mysql_parse_result(sql); + assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone()); // session or local style let sql = "SET LOCAL time_zone = 'UTC'"; - assert_mysql_parse_result(sql); + assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone()); let sql = "SET SESSION time_zone = 'UTC'"; - assert_mysql_parse_result(sql); + assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone()); // postgresql style let sql = "SET TIMEZONE TO 'UTC'"; - assert_pg_parse_result(sql); + assert_pg_parse_result(sql, "TIMEZONE", expected_utc_expr.clone()); let sql = "SET TIMEZONE 'UTC'"; - assert_pg_parse_result(sql); + assert_pg_parse_result(sql, "TIMEZONE", expected_utc_expr); + } + + #[test] + pub fn test_set_query_timeout() { + let expected_query_timeout_expr = Expr::Value(Value::Number("5000".to_string(), false)); + // mysql style + let sql = "SET MAX_EXECUTION_TIME = 5000"; + assert_mysql_parse_result( + sql, + "MAX_EXECUTION_TIME", + expected_query_timeout_expr.clone(), + ); + // session or local style + let sql = "SET LOCAL MAX_EXECUTION_TIME = 5000"; + assert_mysql_parse_result( + sql, + "MAX_EXECUTION_TIME", + expected_query_timeout_expr.clone(), + ); + let sql = "SET SESSION MAX_EXECUTION_TIME = 5000"; + assert_mysql_parse_result( + sql, + "MAX_EXECUTION_TIME", + expected_query_timeout_expr.clone(), + ); + + // postgresql style + let sql = "SET STATEMENT_TIMEOUT = 5000"; + assert_pg_parse_result( + sql, + "STATEMENT_TIMEOUT", + expected_query_timeout_expr.clone(), + ); + let sql = "SET STATEMENT_TIMEOUT TO 5000"; + assert_pg_parse_result(sql, "STATEMENT_TIMEOUT", expected_query_timeout_expr); } }