diff --git a/sqlx-core/src/sqlite/connection/worker.rs b/sqlx-core/src/sqlite/connection/worker.rs index b737fbf386..dd38617d22 100644 --- a/sqlx-core/src/sqlite/connection/worker.rs +++ b/sqlx-core/src/sqlite/connection/worker.rs @@ -32,6 +32,16 @@ pub(crate) struct ConnectionWorker { pub(crate) handle_raw: ConnectionHandleRaw, /// Mutex for locking access to the database. pub(crate) shared: Arc, + + // Tracks `shared.conn.transaction_depth` to help provide cancel-safety. Updated only when + // `begin()` / `commit()` / `rollback()` successfully complete. + // + // - If `transaction_depth == shared.conn.transaction_depth` then no cancellation occurred + // - If `transaction_depth == shared.conn.transaction_depth - 1` then a `begin()` was cancelled + // - If `transaction_depth == shared.conn.transaction_depth + 1` then a `commit()` or + // `rollback()` was cancelled + // - No other cases are possible (would indicate a logic bug) + transaction_depth: usize, } pub(crate) struct WorkerSharedState { @@ -52,15 +62,19 @@ enum Command { query: Box, arguments: Option>, persistent: bool, + transaction_depth: usize, tx: flume::Sender, Error>>, }, Begin { + transaction_depth: usize, tx: oneshot::Sender>, }, Commit { + transaction_depth: usize, tx: oneshot::Sender>, }, Rollback { + transaction_depth: usize, tx: Option>>, }, CreateCollation { @@ -110,6 +124,7 @@ impl ConnectionWorker { command_tx, handle_raw: conn.handle.to_raw(), shared: Arc::clone(&shared), + transaction_depth: 0, })) .is_err() { @@ -135,8 +150,15 @@ impl ConnectionWorker { query, arguments, persistent, + transaction_depth, tx, } => { + if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth) + { + tx.send(Err(error)).ok(); + continue; + } + let iter = match execute::iter(&mut conn, &query, arguments, persistent) { Ok(iter) => iter, @@ -154,7 +176,16 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } - Command::Begin { tx } => { + Command::Begin { + transaction_depth, + tx, + } => { + if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth) + { + tx.send(Err(error)).ok(); + continue; + } + let depth = conn.transaction_depth; let res = conn.handle @@ -165,9 +196,17 @@ impl ConnectionWorker { tx.send(res).ok(); } - Command::Commit { tx } => { - let depth = conn.transaction_depth; + Command::Commit { + transaction_depth, + tx, + } => { + if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth) + { + tx.send(Err(error)).ok(); + continue; + } + let depth = conn.transaction_depth; let res = if depth > 0 { conn.handle .exec(commit_ansi_transaction_sql(depth)) @@ -180,9 +219,26 @@ impl ConnectionWorker { tx.send(res).ok(); } - Command::Rollback { tx } => { - let depth = conn.transaction_depth; + Command::Rollback { + transaction_depth, + tx, + } => { + match handle_cancelled_begin_or_commit_or_rollback( + &mut conn, + transaction_depth, + ) { + Ok(true) => (), + Ok(false) => continue, + Err(error) => { + if let Some(tx) = tx { + tx.send(Err(error)).ok(); + } + continue; + } + } + + let depth = conn.transaction_depth; let res = if depth > 0 { conn.handle .exec(rollback_ansi_transaction_sql(depth)) @@ -259,6 +315,7 @@ impl ConnectionWorker { query: query.into(), arguments: args.map(SqliteArguments::into_static), persistent, + transaction_depth: self.transaction_depth, tx, }) .await @@ -268,21 +325,56 @@ impl ConnectionWorker { } pub(crate) async fn begin(&mut self) -> Result<(), Error> { - self.oneshot_cmd(|tx| Command::Begin { tx }).await? + let transaction_depth = self.transaction_depth; + + self.oneshot_cmd(|tx| Command::Begin { + transaction_depth, + tx, + }) + .await??; + + self.transaction_depth += 1; + + Ok(()) } pub(crate) async fn commit(&mut self) -> Result<(), Error> { - self.oneshot_cmd(|tx| Command::Commit { tx }).await? + let transaction_depth = self.transaction_depth; + + self.oneshot_cmd(|tx| Command::Commit { + transaction_depth, + tx, + }) + .await??; + + self.transaction_depth -= 1; + + Ok(()) } pub(crate) async fn rollback(&mut self) -> Result<(), Error> { - self.oneshot_cmd(|tx| Command::Rollback { tx: Some(tx) }) - .await? + let transaction_depth = self.transaction_depth; + + self.oneshot_cmd(|tx| Command::Rollback { + transaction_depth, + tx: Some(tx), + }) + .await??; + + self.transaction_depth -= 1; + + Ok(()) } pub(crate) fn start_rollback(&mut self) -> Result<(), Error> { + let transaction_depth = self.transaction_depth; + self.transaction_depth -= 1; + self.command_tx - .send(Command::Rollback { tx: None }) + .send(Command::Rollback { + transaction_depth, + tx: None, + }) .map_err(|_| Error::WorkerCrashed) } @@ -387,3 +479,58 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result Result<(), Error> { + if expected_transaction_depth != conn.transaction_depth { + if expected_transaction_depth == conn.transaction_depth - 1 { + let depth = conn.transaction_depth; + conn.handle.exec(rollback_ansi_transaction_sql(depth))?; + conn.transaction_depth -= 1; + } else { + // This would indicate cancelled `commit` or `rollback`, but that can only happen when + // handling a `Rollback` command because `commit()` / `rollback()` take the + // transaction by value and so when they are cancelled the transaction is immediately + // dropped which sends a `Rollback`. + unreachable!() + } + } + + Ok(()) +} + +// Same as `handle_cancelled_begin` but additionally handles cancelled `commit()` and `rollback()` +// as well. If `commit()` / `rollback()` is cancelled, it might happen that the corresponding +// `Commit` / `Rollback` command is still sent to the worker thread but the transaction's `open` +// flag is not set to `false` which causes another `Rollback` to be sent when the transaction +// is dropped. This function detects that case and indicates to ignore the superfluous `Rollback`. +// +// Use only when handling a `Rollback` command. +fn handle_cancelled_begin_or_commit_or_rollback( + conn: &mut ConnectionState, + expected_transaction_depth: usize, +) -> Result { + if expected_transaction_depth != conn.transaction_depth { + if expected_transaction_depth == conn.transaction_depth - 1 { + let depth = conn.transaction_depth; + conn.handle.exec(rollback_ansi_transaction_sql(depth))?; + conn.transaction_depth -= 1; + + Ok(true) + } else if expected_transaction_depth == conn.transaction_depth + 1 { + Ok(false) + } else { + unreachable!() + } + } else { + Ok(true) + } +}