diff --git a/src/backend/mysql/query.rs b/src/backend/mysql/query.rs index 58126014..3244838d 100644 --- a/src/backend/mysql/query.rs +++ b/src/backend/mysql/query.rs @@ -140,6 +140,17 @@ impl QueryBuilder for MysqlQueryBuilder { fn prepare_returning(&self, _returning: &Option, _sql: &mut dyn SqlWriter) {} + fn prepare_exception_statement(&self, exception: &ExceptionStatement, sql: &mut dyn SqlWriter) { + let mut quoted_exception_message = String::new(); + self.write_string_quoted(&exception.message, &mut quoted_exception_message); + write!( + sql, + "SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = {}", + quoted_exception_message + ) + .unwrap(); + } + fn random_function(&self) -> &str { "RAND" } diff --git a/src/backend/postgres/query.rs b/src/backend/postgres/query.rs index f0650110..bf6662a2 100644 --- a/src/backend/postgres/query.rs +++ b/src/backend/postgres/query.rs @@ -153,6 +153,12 @@ impl QueryBuilder for PostgresQueryBuilder { sql.push_param(value.clone(), self as _); } + fn prepare_exception_statement(&self, exception: &ExceptionStatement, sql: &mut dyn SqlWriter) { + let mut quoted_exception_message = String::new(); + self.write_string_quoted(&exception.message, &mut quoted_exception_message); + write!(sql, "RAISE EXCEPTION {}", quoted_exception_message).unwrap(); + } + fn write_string_quoted(&self, string: &str, buffer: &mut String) { let escaped = self.escape_string(string); let string = if escaped.find('\\').is_some() { diff --git a/src/backend/query_builder.rs b/src/backend/query_builder.rs index 41cac862..abcdf2ae 100644 --- a/src/backend/query_builder.rs +++ b/src/backend/query_builder.rs @@ -387,6 +387,9 @@ pub trait QueryBuilder: SimpleExpr::Constant(val) => { self.prepare_constant(val, sql); } + SimpleExpr::Exception(val) => { + self.prepare_exception_statement(val, sql); + } } } @@ -982,6 +985,15 @@ pub trait QueryBuilder: } } + // Translate [`Exception`] into SQL statement. + fn prepare_exception_statement( + &self, + _exception: &ExceptionStatement, + _sql: &mut dyn SqlWriter, + ) { + panic!("Exception handling not implemented for this backend"); + } + /// Convert a SQL value into syntax-specific string fn value_to_string(&self, v: &Value) -> String { self.value_to_string_common(v) diff --git a/src/backend/sqlite/query.rs b/src/backend/sqlite/query.rs index 0a062294..ceb0b6d4 100644 --- a/src/backend/sqlite/query.rs +++ b/src/backend/sqlite/query.rs @@ -84,6 +84,12 @@ impl QueryBuilder for SqliteQueryBuilder { "MIN" } + fn prepare_exception_statement(&self, exception: &ExceptionStatement, sql: &mut dyn SqlWriter) { + let mut quoted_exception_message = String::new(); + self.write_string_quoted(&exception.message, &mut quoted_exception_message); + write!(sql, "SELECT RAISE(ABORT, {})", quoted_exception_message).unwrap(); + } + fn char_length_function(&self) -> &str { "LENGTH" } diff --git a/src/exception.rs b/src/exception.rs new file mode 100644 index 00000000..2a4af51a --- /dev/null +++ b/src/exception.rs @@ -0,0 +1,46 @@ +//! Custom SQL exceptions and errors +use inherent::inherent; + +use crate::backend::SchemaBuilder; + +/// SQL Exceptions +#[derive(Debug, Clone, PartialEq)] +pub struct ExceptionStatement { + pub(crate) message: String, +} + +impl ExceptionStatement { + pub fn new(message: String) -> Self { + Self { message } + } +} + +pub trait ExceptionStatementBuilder { + /// Build corresponding SQL statement for certain database backend and return SQL string + fn build(&self, schema_builder: T) -> String; + + /// Build corresponding SQL statement for certain database backend and return SQL string + fn build_any(&self, schema_builder: &dyn SchemaBuilder) -> String; + + /// Build corresponding SQL statement for certain database backend and return SQL string + fn to_string(&self, schema_builder: T) -> String { + self.build(schema_builder) + } +} + +#[inherent] +impl ExceptionStatementBuilder for ExceptionStatement { + pub fn build(&self, schema_builder: T) -> String { + let mut sql = String::with_capacity(256); + schema_builder.prepare_exception_statement(self, &mut sql); + sql + } + + pub fn build_any(&self, schema_builder: &dyn SchemaBuilder) -> String { + let mut sql = String::with_capacity(256); + schema_builder.prepare_exception_statement(self, &mut sql); + sql + } + + pub fn to_string(&self, schema_builder: T) -> String; +} diff --git a/src/expr.rs b/src/expr.rs index b6894c94..391fca1c 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -4,7 +4,7 @@ //! //! [`SimpleExpr`] is the expression common among select fields, where clauses and many other places. -use crate::{func::*, query::*, types::*, value::*}; +use crate::{exception::ExceptionStatement, func::*, query::*, types::*, value::*}; /// Helper to build a [`SimpleExpr`]. #[derive(Debug, Clone)] @@ -35,6 +35,7 @@ pub enum SimpleExpr { AsEnum(DynIden, Box), Case(Box), Constant(Value), + Exception(ExceptionStatement), } /// "Operator" methods for building complex expressions. diff --git a/src/lib.rs b/src/lib.rs index 15e1a189..4f7de509 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -814,6 +814,7 @@ pub mod backend; pub mod error; +pub mod exception; pub mod expr; pub mod extension; pub mod foreign_key; @@ -832,6 +833,7 @@ pub mod value; pub mod tests_cfg; pub use backend::*; +pub use exception::*; pub use expr::*; pub use foreign_key::*; pub use func::*; diff --git a/tests/mysql/exception.rs b/tests/mysql/exception.rs new file mode 100644 index 00000000..2446330c --- /dev/null +++ b/tests/mysql/exception.rs @@ -0,0 +1,20 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn signal_sqlstate() { + let message = "Some error occurred"; + assert_eq!( + ExceptionStatement::new(message.to_string()).to_string(MysqlQueryBuilder), + format!("SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = '{message}'") + ); +} + +#[test] +fn escapes_message() { + let unescaped_message = "Does this 'break'?"; + assert_eq!( + ExceptionStatement::new(unescaped_message.to_string()).to_string(MysqlQueryBuilder), + format!("SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Does this \\'break\\'?'") + ); +} diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index d717774f..e4acd836 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -1,5 +1,6 @@ use sea_query::{extension::mysql::*, tests_cfg::*, *}; +mod exception; mod foreign_key; mod index; mod query; diff --git a/tests/postgres/exception.rs b/tests/postgres/exception.rs new file mode 100644 index 00000000..372cac0b --- /dev/null +++ b/tests/postgres/exception.rs @@ -0,0 +1,20 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn raise_exception() { + let message = "Some error occurred"; + assert_eq!( + ExceptionStatement::new(message.to_string()).to_string(PostgresQueryBuilder), + format!("RAISE EXCEPTION '{message}'") + ); +} + +#[test] +fn escapes_message() { + let unescaped_message = "Does this 'break'?"; + assert_eq!( + ExceptionStatement::new(unescaped_message.to_string()).to_string(PostgresQueryBuilder), + format!("RAISE EXCEPTION E'Does this \\'break\\'?'") + ); +} diff --git a/tests/postgres/mod.rs b/tests/postgres/mod.rs index 82b85df3..d65872c9 100644 --- a/tests/postgres/mod.rs +++ b/tests/postgres/mod.rs @@ -1,5 +1,6 @@ use sea_query::{tests_cfg::*, *}; +mod exception; mod foreign_key; mod index; mod query; diff --git a/tests/sqlite/exception.rs b/tests/sqlite/exception.rs new file mode 100644 index 00000000..81d17507 --- /dev/null +++ b/tests/sqlite/exception.rs @@ -0,0 +1,21 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn select_raise_abort() { + let message = "Some error occurred here"; + assert_eq!( + ExceptionStatement::new(message.to_string()).to_string(SqliteQueryBuilder), + format!("SELECT RAISE(ABORT, '{}')", message) + ); +} + +#[test] +fn escapes_message() { + let unescaped_message = "Does this 'break'?"; + let escaped_message = "Does this ''break''?"; + assert_eq!( + ExceptionStatement::new(unescaped_message.to_string()).to_string(SqliteQueryBuilder), + format!("SELECT RAISE(ABORT, '{}')", escaped_message) + ); +} diff --git a/tests/sqlite/mod.rs b/tests/sqlite/mod.rs index fc7388cd..7206e8fe 100644 --- a/tests/sqlite/mod.rs +++ b/tests/sqlite/mod.rs @@ -1,5 +1,6 @@ use sea_query::{tests_cfg::*, *}; +mod exception; mod foreign_key; mod index; mod query;