From b356cbcac5fa7d5b16821c26b0452e95b308f9b3 Mon Sep 17 00:00:00 2001 From: Lucian Buzzo Date: Sun, 15 Oct 2023 13:58:28 +0100 Subject: [PATCH] feat: add support for nested transaction rollbacks via savepoints in sql This is my first OSS contribution for a Rust project, so I'm sure I've made some stupid mistakes, but I think it should mostly work :) This change adds a mutable depth counter, that can track how many levels deep a transaction is, and uses savepoints to implement correct rollback behaviour. Previously, once a nested transaction was complete, it would be saved with `COMMIT`, meaning that even if the outer transaction was rolled back, the operations in the inner transaction would persist. With this change, if the outer transaction gets rolled back, then all inner transactions will also be rolled back. Different flavours of SQL servers have different syntax for handling savepoints, so I've had to add new methods to the `Queryable` trait for getting the commit and rollback statements. These are both parameterized by the current depth. I've additionally had to modify the `begin_statement` method to accept a depth parameter, as it will need to conditionally create a savepoint. When opening a transaction via the transaction server, you can now pass the prior transaction ID to re-use the existing transaction, incrementing the depth. Signed-off-by: Lucian Buzzo --- quaint/src/connector/mssql/native/mod.rs | 51 +++++++- quaint/src/connector/mysql/native/mod.rs | 36 +++++- quaint/src/connector/postgres/native/mod.rs | 37 +++++- quaint/src/connector/queryable.rs | 39 +++++- quaint/src/connector/sqlite/native/mod.rs | 41 +++++- quaint/src/connector/transaction.rs | 79 ++++++++++-- quaint/src/pooled.rs | 5 +- quaint/src/pooled/manager.rs | 15 ++- quaint/src/single.rs | 21 ++- quaint/src/tests/query.rs | 16 ++- quaint/src/tests/query/error.rs | 2 +- .../tests/new/interactive_tx.rs | 120 ++++++++++++++---- .../query-engine-tests/tests/new/metrics.rs | 4 +- .../tests/new/regressions/prisma_13405.rs | 2 +- .../tests/new/regressions/prisma_15607.rs | 2 +- .../query-tests-setup/src/runner/mod.rs | 3 +- .../src/interface/transaction.rs | 17 ++- .../query-connector/src/interface.rs | 5 +- .../src/database/transaction.rs | 18 +-- query-engine/core/src/executor/mod.rs | 5 +- .../interactive_transactions/actor_manager.rs | 34 +++-- .../src/interactive_transactions/actors.rs | 68 ++++++++-- .../src/interactive_transactions/messages.rs | 8 +- .../core/src/interactive_transactions/mod.rs | 4 +- query-engine/driver-adapters/src/proxy.rs | 9 ++ query-engine/driver-adapters/src/queryable.rs | 15 ++- .../driver-adapters/src/transaction.rs | 56 +++++++- query-engine/query-engine/src/server/mod.rs | 6 +- 28 files changed, 596 insertions(+), 122 deletions(-) diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index d22aa7a15dd..354d327acee 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -17,7 +17,7 @@ use futures::lock::Mutex; use std::{ convert::TryFrom, future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{Arc, atomic::{AtomicBool, Ordering}}, time::Duration, }; use tiberius::*; @@ -44,11 +44,13 @@ impl TransactionCapable for Mssql { .or(self.url.query_params.transaction_isolation_level) .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + let opts = TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + ); - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) + Ok(Box::new(DefaultTransaction::new(self, opts).await?)) } } @@ -59,6 +61,7 @@ pub struct Mssql { url: MssqlUrl, socket_timeout: Option, is_healthy: AtomicBool, + transaction_depth: Arc>, } impl Mssql { @@ -90,6 +93,7 @@ impl Mssql { url, socket_timeout, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }; if let Some(isolation) = this.url.transaction_isolation_level() { @@ -229,8 +233,41 @@ impl Queryable for Mssql { Ok(()) } - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN TRAN".to_string() + }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + let ret = if depth > 1 { + " ".to_string() + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index fdcc3a6276d..90483deacff 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -21,7 +21,7 @@ use mysql_async::{ }; use std::{ future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{Arc, atomic::{AtomicBool, Ordering}}, time::Duration, }; use tokio::sync::Mutex; @@ -74,6 +74,7 @@ pub struct Mysql { socket_timeout: Option, is_healthy: AtomicBool, statement_cache: Mutex>, + transaction_depth: Arc>, } impl Mysql { @@ -87,6 +88,7 @@ impl Mysql { statement_cache: Mutex::new(url.cache()), url, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -294,4 +296,36 @@ impl Queryable for Mysql { fn requires_isolation_first(&self) -> bool { true } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 30f34e7002b..7f820648e62 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -24,7 +24,7 @@ use std::{ fmt::{Debug, Display}, fs, future::Future, - sync::atomic::{AtomicBool, Ordering}, + sync::{Arc, atomic::{AtomicBool, Ordering}}, time::Duration, }; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; @@ -50,6 +50,7 @@ pub struct PostgreSql { socket_timeout: Option, statement_cache: Mutex>, is_healthy: AtomicBool, + transaction_depth: Arc>, } #[derive(Debug)] @@ -243,6 +244,7 @@ impl PostgreSql { pg_bouncer: url.query_params.pg_bouncer, statement_cache: Mutex::new(url.cache()), is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -523,6 +525,39 @@ impl Queryable for PostgreSql { fn requires_isolation_first(&self) -> bool { false } + + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } /// Sorted list of CockroachDB's reserved keywords. diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 09dbc7abba4..10e551af4ba 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -87,8 +87,35 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self) -> &'static str { - "BEGIN" + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; } /// Sets the transaction isolation level to given value. @@ -117,10 +144,14 @@ macro_rules! impl_default_TransactionCapable { &'a self, isolation: Option, ) -> crate::Result> { - let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); + let opts = crate::connector::TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + ); Ok(Box::new( - crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + crate::connector::DefaultTransaction::new(self, opts).await?, )) } } diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 3bf0c46a7db..9dc20c2f41b 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -16,7 +16,7 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; -use std::convert::TryFrom; +use std::{sync::Arc, convert::TryFrom}; use tokio::sync::Mutex; /// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. @@ -26,6 +26,7 @@ pub use rusqlite; /// A connector interface for the SQLite database pub struct Sqlite { pub(crate) client: Mutex, + transaction_depth: Arc>, } impl TryFrom<&str> for Sqlite { @@ -43,7 +44,10 @@ impl TryFrom<&str> for Sqlite { let client = Mutex::new(conn); - Ok(Sqlite { client }) + Ok(Sqlite { + client, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } } @@ -58,6 +62,7 @@ impl Sqlite { Ok(Sqlite { client: Mutex::new(client), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -154,6 +159,38 @@ impl Queryable for Sqlite { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { savepoint_stmt } else { "BEGIN".to_string() }; + + return ret; + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret; + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret; + } } #[cfg(test)] diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index b7e91e97f6a..7330b8f8f24 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -4,18 +4,22 @@ use crate::{ error::{Error, ErrorKind}, }; use async_trait::async_trait; +use futures::lock::Mutex; use metrics::{decrement_gauge, increment_gauge}; -use std::{fmt, str::FromStr}; +use std::{sync::Arc, fmt, str::FromStr}; extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + + /// Commit the changes to the database and consume the transaction. + async fn commit(&mut self) -> crate::Result; /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -27,6 +31,9 @@ pub(crate) struct TransactionOptions { /// Whether or not to put the isolation level `SET` before or after the `BEGIN`. pub(crate) isolation_first: bool, + + /// The depth of the transaction, used to determine the nested transaction statements. + pub depth: Arc>, } /// A default representation of an SQL database transaction. If not commited, a @@ -36,15 +43,18 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + pub depth: Arc>, } impl<'a> DefaultTransaction<'a> { pub(crate) async fn new( inner: &'a dyn Queryable, - begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { inner }; + let mut this = Self { + inner, + depth: tx_opts.depth, + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -52,7 +62,7 @@ impl<'a> DefaultTransaction<'a> { } } - inner.raw_cmd(begin_stmt).await?; + this.begin().await?; if !tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -62,27 +72,63 @@ impl<'a> DefaultTransaction<'a> { inner.server_reset_query(&this).await?; - increment_gauge!("prisma_client_queries_active", 1.0); Ok(this) } } #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { + async fn begin(&mut self) -> crate::Result<()> { + increment_gauge!("prisma_client_queries_active", 1.0); + + let mut depth_guard = self.depth.lock().await; + + // Modify the depth value through the MutexGuard + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_statement = self.inner.begin_statement(st_depth).await; + + self.inner.raw_cmd(&begin_statement).await?; + + Ok(()) + } + /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()> { + async fn commit(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("COMMIT").await?; - Ok(()) + let mut depth_guard = self.depth.lock().await; + + let st_depth = *depth_guard; + + let commit_statement = self.inner.commit_statement(st_depth).await; + + self.inner.raw_cmd(&commit_statement).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + Ok(*depth_guard) } /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()> { + async fn rollback(&mut self) -> crate::Result { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("ROLLBACK").await?; - Ok(()) + let mut depth_guard = self.depth.lock().await; + + let st_depth = *depth_guard; + + let rollback_statement = self.inner.rollback_statement(st_depth).await; + + self.inner.raw_cmd(&rollback_statement).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + Ok(*depth_guard) } fn as_queryable(&self) -> &dyn Queryable { @@ -190,10 +236,15 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub fn new(isolation_level: Option, isolation_first: bool) -> Self { + pub fn new( + isolation_level: Option, + isolation_first: bool, + depth: Arc>, + ) -> Self { Self { isolation_level, isolation_first, + depth, } } } diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 4c415292337..458a3412ece 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -500,7 +500,10 @@ impl Quaint { } }; - Ok(PooledConnection { inner }) + Ok(PooledConnection { + inner, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } /// Info about the connection and underlying database. diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 73441b7609b..087ea01e5ce 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -10,12 +10,15 @@ use crate::{ error::Error, }; use async_trait::async_trait; +use futures::lock::Mutex; use mobc::{Connection as MobcPooled, Manager}; +use std::sync::Arc; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). pub struct PooledConnection { pub(crate) inner: MobcPooled, + pub transaction_depth: Arc>, } impl_default_TransactionCapable!(PooledConnection); @@ -62,8 +65,16 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 1a4dbdf52a6..a2608945c44 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -5,6 +5,7 @@ use crate::{ connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, }; use async_trait::async_trait; +use futures::lock::Mutex; use std::{fmt, sync::Arc}; #[cfg(feature = "sqlite-native")] @@ -15,6 +16,7 @@ use std::convert::TryFrom; pub struct Quaint { inner: Arc, connection_info: Arc, + transaction_depth: Arc>, } impl fmt::Debug for Quaint { @@ -162,7 +164,11 @@ impl Quaint { let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); Self::log_start(&connection_info); - Ok(Self { inner, connection_info }) + Ok(Self { + inner, + connection_info, + transaction_depth: Arc::new(Mutex::new(0)), + }) } #[cfg(feature = "sqlite-native")] @@ -175,6 +181,7 @@ impl Quaint { connection_info: Arc::new(ConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), }), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -229,8 +236,16 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 06bebe1a960..cf471fbf733 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -64,7 +64,7 @@ async fn select_star_from(api: &mut dyn TestApi) -> crate::Result<()> { async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("value int").await?; - let tx = api.conn().start_transaction(None).await?; + let mut tx = api.conn().start_transaction(None).await?; let insert = Insert::single_into(&table).value("value", 10); let rows_affected = tx.execute(insert.into()).await?; @@ -75,6 +75,20 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); + // Check that nested transactions are also rolled back, even at multiple levels deep + let mut tx_inner = api.conn().start_transaction(None).await?; + let inner_insert1 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; + assert_eq!(1, inner_rows_affected1); + + let mut tx_inner2 = api.conn().start_transaction(None).await?; + let inner_insert2 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; + assert_eq!(1, inner_rows_affected2); + tx_inner2.commit().await?; + + tx_inner.commit().await?; + tx.rollback().await?; let select = Select::from_table(&table).column("value"); diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 69c57332b6d..67334858576 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -456,7 +456,7 @@ async fn concurrent_transaction_conflict(api: &mut dyn TestApi) -> crate::Result let conn1 = api.create_additional_connection().await?; let conn2 = api.create_additional_connection().await?; - let tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; + let mut tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; let tx2 = conn2.start_transaction(Some(IsolationLevel::Serializable)).await?; tx1.query(Select::from_table(&table).into()).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index 33908a9e079..bfd4a57b95a 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -8,7 +8,7 @@ mod interactive_tx { #[connector_test] async fn basic_commit_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -35,7 +35,7 @@ mod interactive_tx { #[connector_test] async fn basic_rollback_workflow(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -63,7 +63,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one second. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -108,7 +108,7 @@ mod interactive_tx { #[connector_test] async fn no_auto_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -135,7 +135,7 @@ mod interactive_tx { #[connector_test(only(Postgres))] async fn raw_queries(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -164,7 +164,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_success(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -190,7 +190,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_rollback(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); let queries = vec![ @@ -216,7 +216,7 @@ mod interactive_tx { #[connector_test] async fn batch_queries_failure(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); // One dup key, will cause failure of the batch. @@ -259,7 +259,7 @@ mod interactive_tx { #[connector_test] async fn tx_expiration_failure_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one seconds. - let tx_id = runner.start_tx(5000, 1000, None).await?; + let tx_id = runner.start_tx(5000, 1000, None, None).await?; runner.set_active_tx(tx_id.clone()); // Row is created @@ -328,10 +328,10 @@ mod interactive_tx { #[connector_test(exclude(Sqlite))] async fn multiple_tx(mut runner: Runner) -> TestResult<()> { // First transaction. - let tx_id_a = runner.start_tx(2000, 2000, None).await?; + let tx_id_a = runner.start_tx(2000, 2000, None, None).await?; // Second transaction. - let tx_id_b = runner.start_tx(2000, 2000, None).await?; + let tx_id_b = runner.start_tx(2000, 2000, None, None).await?; // Execute on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -379,10 +379,10 @@ mod interactive_tx { ); // First transaction. - let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_a = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Second transaction. - let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into())).await?; + let tx_id_b = runner.start_tx(5000, 5000, Some("Serializable".into()), None).await?; // Read on first transaction. runner.set_active_tx(tx_id_a.clone()); @@ -421,7 +421,7 @@ mod interactive_tx { #[connector_test] async fn double_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -456,9 +456,82 @@ mod interactive_tx { Ok(()) } + + #[connector_test(only(Postgres))] + async fn nested_commit_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + let res = runner.commit_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + Ok(()) + } + + #[connector_test(only(Postgres))] + async fn nested_commit_rollback_workflow(mut runner: Runner) -> TestResult<()> { + // Start the outer transaction + let outer_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(outer_tx_id.clone()); + + // Start the inner transaction + let inner_tx_id = runner.start_tx(5000, 5000, None, Some(outer_tx_id.clone())).await?; + runner.set_active_tx(inner_tx_id.clone()); + + // Perform operations in the inner transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 1 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":1}}}"### + ); + + let res = runner.commit_tx(inner_tx_id).await?; + assert!(res.is_ok()); + + // Perform operations in the outer transaction and commit + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneTestModel(data: { id: 2 }) { id }}"#), + @r###"{"data":{"createOneTestModel":{"id":2}}}"### + ); + + // Now rollback the outer transaction + let res = runner.rollback_tx(outer_tx_id).await?; + assert!(res.is_ok()); + + // Assert that no records were written to the DB + let result_tx_id = runner.start_tx(5000, 5000, None, None).await?; + runner.set_active_tx(result_tx_id.clone()); + insta::assert_snapshot!( + run_query!(&runner, r#"query { findManyTestModel { id field }}"#), + @r###"{"data":{"findManyTestModel":[]}}"### + ); + let _ = runner.commit_tx(result_tx_id).await?; + + Ok(()) + } + #[connector_test] async fn double_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -495,7 +568,7 @@ mod interactive_tx { #[connector_test] async fn commit_after_rollback(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -532,7 +605,7 @@ mod interactive_tx { #[connector_test] async fn rollback_after_commit(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -552,6 +625,7 @@ mod interactive_tx { let error = res.err().unwrap(); let known_err = error.as_known().unwrap(); + println!("Error: {:?}", known_err); assert_eq!(known_err.error_code, Cow::Borrowed("P2028")); assert!(known_err @@ -575,7 +649,7 @@ mod itx_isolation { // All (SQL) connectors support serializable. #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; + let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned()), None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -597,7 +671,7 @@ mod itx_isolation { #[connector_test(exclude(MongoDb, Vitess("planetscale.js")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; + let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned()), None).await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -608,13 +682,13 @@ mod itx_isolation { #[connector_test(only(Postgres))] async fn spacing_doesnt_matter(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Repeatable Read".to_owned())).await?; + let tx_id = runner.start_tx(5000, 5000, Some("Repeatable Read".to_owned()), None).await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; assert!(res.is_ok()); - let tx_id = runner.start_tx(5000, 5000, Some("RepeatableRead".to_owned())).await?; + let tx_id = runner.start_tx(5000, 5000, Some("RepeatableRead".to_owned()), None).await?; runner.set_active_tx(tx_id.clone()); let res = runner.commit_tx(tx_id).await?; @@ -625,7 +699,7 @@ mod itx_isolation { #[connector_test(exclude(MongoDb))] async fn invalid_isolation(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("test".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected invalid isolation level string to throw an error, but it succeeded instead."), @@ -638,7 +712,7 @@ mod itx_isolation { // Mongo doesn't support isolation levels. #[connector_test(only(MongoDb))] async fn mongo_failure(runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await; + let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned()), None).await; match tx_id { Ok(_) => panic!("Expected mongo to throw an unsupported error, but it succeeded instead."), diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index cd270bb334c..05406ddcddc 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -50,7 +50,7 @@ mod metrics { #[connector_test] async fn metrics_tx_do_not_go_negative(mut runner: Runner) -> TestResult<()> { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( @@ -67,7 +67,7 @@ mod metrics { let active_transactions = get_gauge(&json, PRISMA_CLIENT_QUERIES_ACTIVE); assert_eq!(active_transactions, 0.0); - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); insta::assert_snapshot!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs index a9b6c439576..49ea6597ff6 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_13405.rs @@ -90,7 +90,7 @@ mod mongodb { } async fn start_itx(runner: &mut Runner) -> TestResult { - let tx_id = runner.start_tx(5000, 5000, None).await?; + let tx_id = runner.start_tx(5000, 5000, None, None).await?; runner.set_active_tx(tx_id.clone()); Ok(tx_id) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs index 3ab34b12010..ebd8accfb35 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs @@ -82,7 +82,7 @@ impl Actor { response_sender.send(Response::Query(result)).await.unwrap(); } Message::BeginTransaction => { - let response = with_logs(runner.start_tx(10000, 10000, None), log_tx.clone()).await; + let response = with_logs(runner.start_tx(10000, 10000, None, None), log_tx.clone()).await; response_sender.send(Response::Tx(response)).await.unwrap(); } Message::RollbackTransaction(tx_id) => { diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs index 03e2dce5c5e..750a9ca9976 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs @@ -365,8 +365,9 @@ impl Runner { max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option, + new_tx_id: Option, ) -> TestResult { - let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level); + let tx_opts = TransactionOptions::new(max_acquisition_millis, valid_for_millis, isolation_level, new_tx_id); match &self.executor { RunnerExecutor::Builtin(executor) => { let id = executor diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 1de0bb8c750..090618bbea0 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -40,26 +40,35 @@ impl<'conn> MongoDbTransaction<'conn> { #[async_trait] impl<'conn> Transaction for MongoDbTransaction<'conn> { - async fn commit(&mut self) -> connector_interface::Result<()> { + async fn begin(&mut self) -> connector_interface::Result<()> { + Ok(()) + } + + async fn commit(&mut self) -> connector_interface::Result { decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); + println!("Committing transaction"); + utils::commit_with_retry(&mut self.connection.session) .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + Ok(0) } - async fn rollback(&mut self) -> connector_interface::Result<()> { + async fn rollback(&mut self) -> connector_interface::Result { decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); + println!("Rolling back transaction"); + self.connection .session .abort_transaction() .await .map_err(|err| MongoError::from(err).into_connector_error())?; - Ok(()) + println!("Transaction rolled back"); + Ok(0) } fn as_connection_like(&mut self) -> &mut dyn ConnectionLike { diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index 942edd1868f..f0bb64a9684 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -30,8 +30,9 @@ pub trait Connection: ConnectionLike { #[async_trait] pub trait Transaction: ConnectionLike { - async fn commit(&mut self) -> crate::Result<()>; - async fn rollback(&mut self) -> crate::Result<()>; + async fn begin(&mut self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result; + async fn rollback(&mut self) -> crate::Result; /// Explicit upcast of self reference. Rusts current vtable layout doesn't allow for an upcast if /// `trait A`, `trait B: A`, so that `Box as Box` works. This is a simple, explicit workaround. diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index 7fa9aaf3b5b..a8c6bf8e8d1 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -37,21 +37,23 @@ impl<'tx> ConnectionLike for SqlConnectorTransaction<'tx> {} #[async_trait] impl<'tx> Transaction for SqlConnectorTransaction<'tx> { - async fn commit(&mut self) -> connector::Result<()> { + async fn begin(&mut self) -> connector::Result<()> { catch(self.connection_info.clone(), async move { - self.inner.commit().await.map_err(SqlError::from) + self.inner.begin().await.map_err(SqlError::from) }) .await } - async fn rollback(&mut self) -> connector::Result<()> { + async fn commit(&mut self) -> connector::Result { catch(self.connection_info.clone(), async move { - let res = self.inner.rollback().await.map_err(SqlError::from); + self.inner.commit().await.map_err(SqlError::from) + }) + .await + } - match res { - Err(SqlError::TransactionAlreadyClosed(_)) | Err(SqlError::RollbackWithoutBegin) => Ok(()), - _ => res, - } + async fn rollback(&mut self) -> connector::Result { + catch(self.connection_info.clone(), async move { + self.inner.rollback().await.map_err(SqlError::from) }) .await } diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index ba2784d3c71..ffb67559a53 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -73,17 +73,16 @@ pub struct TransactionOptions { /// An optional pre-defined transaction id. Some value might be provided in case we want to generate /// a new id at the beginning of the transaction - #[serde(skip)] pub new_tx_id: Option, } impl TransactionOptions { - pub fn new(max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option) -> Self { + pub fn new(max_acquisition_millis: u64, valid_for_millis: u64, isolation_level: Option, new_tx_id: Option) -> Self { Self { max_acquisition_millis, valid_for_millis, isolation_level, - new_tx_id: None, + new_tx_id, } } diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index 105733be416..4d22759550a 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -73,19 +73,27 @@ impl TransactionActorManager { timeout: Duration, engine_protocol: EngineProtocol, ) -> crate::Result<()> { - let client = spawn_itx_actor( - query_schema.clone(), - tx_id.clone(), - conn, - isolation_level, - timeout, - CHANNEL_SIZE, - self.send_done.clone(), - engine_protocol, - ) - .await?; - - self.clients.write().await.insert(tx_id, client); + // Only create a client if there is no client for this transaction yet. + // otherwise, begin a new transaction/savepoint for the existing client. + if !self.clients.read().await.contains_key(&tx_id) { + let client = spawn_itx_actor( + query_schema.clone(), + tx_id.clone(), + conn, + isolation_level, + timeout, + CHANNEL_SIZE, + self.send_done.clone(), + engine_protocol, + ) + .await?; + + self.clients.write().await.insert(tx_id, client); + } else { + let client = self.get_client(&tx_id, "begin").await?; + client.begin().await?; + } + Ok(()) } diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 104ffc26812..4f54a100911 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -65,15 +65,40 @@ impl<'a> ITXServer<'a> { let _ = op.respond_to.send(TxOpResponse::Batch(result)); RunState::Continue } + TxOpRequestMsg::Begin => { + let resp = self.begin().await; + let _ = op.respond_to.send(TxOpResponse::Begin(resp)); + RunState::Continue + } TxOpRequestMsg::Commit => { let resp = self.commit().await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; + let _ = op.respond_to.send(TxOpResponse::Committed(resp)); - RunState::Finished + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } TxOpRequestMsg::Rollback => { let resp = self.rollback(false).await; + let resp_value = match &resp { + Ok(val) => *val, + Err(_) => 0, + }; let _ = op.respond_to.send(TxOpResponse::RolledBack(resp)); - RunState::Finished + + + if resp_value > 0 { + RunState::Continue + } else { + RunState::Finished + } } } } @@ -117,32 +142,46 @@ impl<'a> ITXServer<'a> { .await } - pub(crate) async fn commit(&mut self) -> crate::Result<()> { + pub(crate) async fn begin(&mut self) -> crate::Result<()> { if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - trace!("[{}] committing.", self.id.to_string()); - open_tx.commit().await?; - self.cached_tx = CachedTx::Committed; + trace!("[{}] beginning.", self.id.to_string()); + open_tx.begin().await?; } Ok(()) } - pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { + pub(crate) async fn commit(&mut self) -> crate::Result { + if let CachedTx::Open(_) = self.cached_tx { + let open_tx = self.cached_tx.as_open()?; + trace!("[{}] committing.", self.id.to_string()); + let depth = open_tx.commit().await?; + if depth == 0 { + self.cached_tx = CachedTx::Committed; + } + return Ok(depth); + } + + Ok(0) + } + + pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result { debug!("[{}] rolling back, was timed out = {was_timeout}", self.name()); if let CachedTx::Open(_) = self.cached_tx { let open_tx = self.cached_tx.as_open()?; - open_tx.rollback().await?; + let depth = open_tx.rollback().await?; if was_timeout { trace!("[{}] Expired Rolling back", self.id.to_string()); self.cached_tx = CachedTx::Expired; - } else { + } else if depth == 0 { self.cached_tx = CachedTx::RolledBack; trace!("[{}] Rolling back", self.id.to_string()); } + return Ok(depth); } - Ok(()) + Ok(0) } pub(crate) fn name(&self) -> String { @@ -157,7 +196,12 @@ pub struct ITXClient { } impl ITXClient { - pub(crate) async fn commit(&self) -> crate::Result<()> { + pub async fn begin(&self) -> crate::Result<()> { + self.send_and_receive(TxOpRequestMsg::Begin).await?; + Ok(()) + } + + pub(crate) async fn commit(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Commit).await?; if let TxOpResponse::Committed(resp) = msg { @@ -168,7 +212,7 @@ impl ITXClient { } } - pub(crate) async fn rollback(&self) -> crate::Result<()> { + pub(crate) async fn rollback(&self) -> crate::Result { let msg = self.send_and_receive(TxOpRequestMsg::Rollback).await?; if let TxOpResponse::RolledBack(resp) = msg { diff --git a/query-engine/core/src/interactive_transactions/messages.rs b/query-engine/core/src/interactive_transactions/messages.rs index 0dba2c096a8..8f64a2fb712 100644 --- a/query-engine/core/src/interactive_transactions/messages.rs +++ b/query-engine/core/src/interactive_transactions/messages.rs @@ -6,6 +6,7 @@ use tokio::sync::oneshot; pub enum TxOpRequestMsg { Commit, Rollback, + Begin, Single(Operation, Option), Batch(Vec, Option), } @@ -18,6 +19,7 @@ pub struct TxOpRequest { impl Display for TxOpRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.msg { + TxOpRequestMsg::Begin => write!(f, "Begin"), TxOpRequestMsg::Commit => write!(f, "Commit"), TxOpRequestMsg::Rollback => write!(f, "Rollback"), TxOpRequestMsg::Single(..) => write!(f, "Single"), @@ -28,8 +30,9 @@ impl Display for TxOpRequest { #[derive(Debug)] pub enum TxOpResponse { - Committed(crate::Result<()>), - RolledBack(crate::Result<()>), + Begin(crate::Result<()>), + Committed(crate::Result), + RolledBack(crate::Result), Single(crate::Result), Batch(crate::Result>>), } @@ -37,6 +40,7 @@ pub enum TxOpResponse { impl Display for TxOpResponse { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Self::Begin(..) => write!(f, "Begin"), Self::Committed(..) => write!(f, "Committed"), Self::RolledBack(..) => write!(f, "RolledBack"), Self::Single(..) => write!(f, "Single"), diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index ce125e8fa17..5c99ebd9f8d 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,6 +1,6 @@ use crate::CoreError; use connector::Transaction; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::fmt::Display; use tokio::time::{Duration, Instant}; @@ -38,7 +38,7 @@ pub(crate) use messages::*; /// the TransactionActorManager can reply with a helpful error message which explains that no operation can be performed on a closed transaction /// rather than an error message stating that the transaction does not exist. -#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)] pub struct TxId(String); const MINIMUM_TX_ID_LENGTH: usize = 24; diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 19693453988..0c27eca991d 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -46,6 +46,9 @@ pub(crate) struct TransactionProxy { /// transaction options options: TransactionOptions, + /// being trnsaction + pub begin: AsyncJsFunction<(), ()>, + /// commit transaction commit: AsyncJsFunction<(), ()>, @@ -579,10 +582,12 @@ pub struct TransactionOptions { impl TransactionProxy { pub fn new(js_transaction: &JsObject) -> napi::Result { let commit = js_transaction.get_named_property("commit")?; + let begin = js_transaction.get_named_property("begin")?; let rollback = js_transaction.get_named_property("rollback")?; let options = js_transaction.get_named_property("options")?; Ok(Self { + begin, commit, rollback, options, @@ -594,6 +599,10 @@ impl TransactionProxy { &self.options } + pub async fn begin(&self) -> quaint::Result<()> { + self.begin.call(()).await + } + /// Commits the transaction via the driver adapter. /// /// ## Cancellation safety diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index ab154eccc13..7e1603a9d9a 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -3,6 +3,7 @@ use crate::{ proxy::{CommonProxy, DriverProxy, Query}, }; use async_trait::async_trait; +use futures::lock::Mutex; use napi::JsObject; use psl::datamodel_connector::Flavour; use quaint::{ @@ -11,6 +12,7 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable as QuaintQueryable, ResultSet, TransactionCapable}, visitor::{self, Visitor}, }; +use std::sync::Arc; use tracing::{info_span, Instrument}; /// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the @@ -193,6 +195,7 @@ impl JsBaseQueryable { pub struct JsQueryable { inner: JsBaseQueryable, driver_proxy: DriverProxy, + pub transaction_depth: Arc>, } impl std::fmt::Display for JsQueryable { @@ -270,14 +273,19 @@ impl TransactionCapable for JsQueryable { } } - let begin_stmt = tx.begin_statement(); + let mut depth_guard = self.transaction_depth.lock().await; + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_stmt = tx.begin_statement(st_depth).await; let tx_opts = tx.options(); if tx_opts.use_phantom_query { - let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); + let begin_stmt = JsBaseQueryable::phantom_query_message(&begin_stmt); tx.raw_phantom_cmd(begin_stmt.as_str()).await?; } else { - tx.raw_cmd(begin_stmt).await?; + tx.raw_cmd(&begin_stmt).await?; } if !isolation_first { @@ -299,5 +307,6 @@ pub fn from_napi(driver: JsObject) -> JsQueryable { JsQueryable { inner: JsBaseQueryable::new(common), driver_proxy, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), } } diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index d35a9019c6b..ac26158eba7 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use futures::lock::Mutex; use metrics::decrement_gauge; use napi::{bindgen_prelude::FromNapiValue, JsObject}; use quaint::{ @@ -6,6 +7,7 @@ use quaint::{ prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; +use std::sync::Arc; use crate::{ proxy::{CommonProxy, TransactionOptions, TransactionProxy}, @@ -18,11 +20,20 @@ use crate::{ pub(crate) struct JsTransaction { tx_proxy: TransactionProxy, inner: JsBaseQueryable, + pub depth: Arc>, + pub commit_stmt: String, + pub rollback_stmt: String, } impl JsTransaction { pub(crate) fn new(inner: JsBaseQueryable, tx_proxy: TransactionProxy) -> Self { - Self { inner, tx_proxy } + Self { + inner, + tx_proxy, + commit_stmt: "COMMIT".to_string(), + rollback_stmt: "ROLLBACK".to_string(), + depth: Arc::new(futures::lock::Mutex::new(0)), + } } pub fn options(&self) -> &TransactionOptions { @@ -37,11 +48,31 @@ impl JsTransaction { #[async_trait] impl QuaintTransaction for JsTransaction { - async fn commit(&self) -> quaint::Result<()> { + async fn begin(&mut self) -> quaint::Result<()> { + // increment of this gauge is done in DriverProxy::startTransaction + decrement_gauge!("prisma_client_queries_active", 1.0); + + let mut depth_guard = self.depth.lock().await; + let commit_stmt = "BEGIN"; + + if self.options().use_phantom_query { + let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); + self.raw_phantom_cmd(commit_stmt.as_str()).await?; + } else { + self.inner.raw_cmd(commit_stmt).await?; + } + + // Modify the depth value through the MutexGuard + *depth_guard += 1; + + self.tx_proxy.begin().await + } + async fn commit(&mut self) -> quaint::Result { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let commit_stmt = "COMMIT"; + let mut depth_guard = self.depth.lock().await; + let commit_stmt = &self.commit_stmt; if self.options().use_phantom_query { let commit_stmt = JsBaseQueryable::phantom_query_message(commit_stmt); @@ -50,14 +81,20 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(commit_stmt).await?; } - self.tx_proxy.commit().await + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + let _ = self.tx_proxy.commit().await; + + Ok(*depth_guard) } - async fn rollback(&self) -> quaint::Result<()> { + async fn rollback(&mut self) -> quaint::Result { // increment of this gauge is done in DriverProxy::startTransaction decrement_gauge!("prisma_client_queries_active", 1.0); - let rollback_stmt = "ROLLBACK"; + let mut depth_guard = self.depth.lock().await; + let rollback_stmt = &self.rollback_stmt; if self.options().use_phantom_query { let rollback_stmt = JsBaseQueryable::phantom_query_message(rollback_stmt); @@ -66,7 +103,12 @@ impl QuaintTransaction for JsTransaction { self.inner.raw_cmd(rollback_stmt).await?; } - self.tx_proxy.rollback().await + // Modify the depth value through the MutexGuard + *depth_guard -= 1; + + let _ = self.tx_proxy.rollback().await; + + Ok(*depth_guard) } fn as_queryable(&self) -> &dyn Queryable { diff --git a/query-engine/query-engine/src/server/mod.rs b/query-engine/query-engine/src/server/mod.rs index f3583df310d..ba1f4d4f13b 100644 --- a/query-engine/query-engine/src/server/mod.rs +++ b/query-engine/query-engine/src/server/mod.rs @@ -282,7 +282,11 @@ async fn transaction_start_handler(cx: Arc, req: Request) - let body_start = req.into_body(); let full_body = hyper::body::to_bytes(body_start).await?; let mut tx_opts: TransactionOptions = serde_json::from_slice(full_body.as_ref()).unwrap(); - let tx_id = tx_opts.with_new_transaction_id(); + let tx_id = if tx_opts.new_tx_id.is_none() { + tx_opts.with_new_transaction_id() + } else { + tx_opts.new_tx_id.clone().unwrap() + }; // This is the span we use to instrument the execution of a transaction. This span will be open // during the tx execution, and held in the ITXServer for that transaction (see ITXServer])