diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 1c868682..c2945949 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -6,6 +6,7 @@ use std::{ fmt, future::poll_fn, net::SocketAddr, + ops::Deref, str::{FromStr, Utf8Error}, sync::Arc, }; @@ -14,7 +15,7 @@ use bytes::Buf; use chrono::NaiveDateTime; use compact_str::CompactString; use corro_types::{ - agent::{Agent, ChangeError}, + agent::{Agent, ChangeError, CurrentVersion, KnownDbVersion}, broadcast::broadcast_changes, change::{insert_local_changes, InsertChangesInfo}, config::PgConfig, @@ -333,42 +334,41 @@ fn parse_query(sql: &str) -> Result, ParseError> { } #[derive(Debug, Default)] -enum TxState { - Started { - kind: OpenTxKind, - write_permit: Option, +enum TxState<'conn> { + Implicit { + tx: TxGuard<'conn>, + }, + Explicit { + tx: TxGuard<'conn>, + failed: bool, }, #[default] Ended, } -impl TxState { - fn implicit() -> Self { - Self::Started { - kind: OpenTxKind::Implicit, - write_permit: None, - } +impl<'conn> TxState<'conn> { + fn implicit(conn: &'conn Connection) -> rusqlite::Result { + Ok(Self::Implicit { + tx: TxGuard::start(conn)?, + }) } - fn explicit() -> Self { - Self::Started { - kind: OpenTxKind::Explicit, - write_permit: None, - } + fn explicit(conn: &'conn Connection) -> rusqlite::Result { + Ok(Self::Explicit { + tx: TxGuard::start(conn)?, + failed: false, + }) } fn is_writing(&self) -> bool { - matches!( - self, - TxState::Started { - write_permit: Some(_), - .. - } - ) + match self { + TxState::Implicit { tx } | TxState::Explicit { tx, .. } => tx.has_write_permit(), + TxState::Ended => false, + } } fn set_write_permit(&mut self, permit: OwnedSemaphorePermit) { match self { - TxState::Started { write_permit, .. } => *write_permit = Some(permit), + TxState::Implicit { tx } | TxState::Explicit { tx, .. } => tx.set_write_permit(permit), TxState::Ended => { // do nothing, maybe bomb? } @@ -376,49 +376,45 @@ impl TxState { } fn is_implicit(&self) -> bool { - matches!( - self, - TxState::Started { - kind: OpenTxKind::Implicit, - .. - } - ) + matches!(self, TxState::Implicit { .. }) } fn is_explicit(&self) -> bool { - matches!( - self, - TxState::Started { - kind: OpenTxKind::Explicit, - .. - } - ) + matches!(self, TxState::Explicit { .. }) } fn is_ended(&self) -> bool { matches!(self, TxState::Ended) } - fn start_implicit(&mut self) { - *self = Self::implicit() + fn start_implicit(&mut self, conn: &'conn Connection) -> rusqlite::Result<()> { + *self = Self::implicit(conn)?; + Ok(()) } - fn start_explicit(&mut self) { - *self = Self::explicit() + fn start_explicit(&mut self, conn: &'conn Connection) -> rusqlite::Result<()> { + *self = Self::explicit(conn)?; + Ok(()) } - fn end(&mut self) -> Option { - let permit = match self { - TxState::Started { write_permit, .. } => write_permit.take(), - TxState::Ended => None, - }; - *self = TxState::Ended; - permit + fn set_failed(&mut self) { + if let TxState::Explicit { failed, .. } = self { + *failed = true; + } } -} -#[derive(Debug)] -enum OpenTxKind { - Implicit, - Explicit, + fn is_failed(&self) -> bool { + match self { + TxState::Explicit { failed, .. } => *failed, + _ => false, + } + } + + fn take_tx(&mut self) -> Option> { + let prev = std::mem::take(self); + match prev { + TxState::Implicit { tx } | TxState::Explicit { tx, .. } => Some(tx), + TxState::Ended => None, + } + } } async fn peek_for_sslrequest( @@ -559,9 +555,11 @@ pub async fn start( let cancel = cancel.clone(); async move { // cancel stuff if this loop breaks - let _drop_guard = cancel.drop_guard(); + let _drop_guard = cancel.clone().drop_guard(); - while let Some(decode_res) = stream.next().await { + while let Outcome::Completed(Some(decode_res)) = + stream.next().preemptible(cancel.cancelled()).await + { let msg = match decode_res { Ok(msg) => msg, Err(PgWireError::IoError(io_error)) => { @@ -588,9 +586,9 @@ pub async fn start( break; } }; - front_tx.send(msg).await?; } + debug!("frontend stream is done"); Ok::<_, BoxError>(()) @@ -600,8 +598,11 @@ pub async fn start( tokio::spawn({ let cancel = cancel.clone(); async move { - let _drop_guard = cancel.drop_guard(); - while let Some(back) = back_rx.recv().await { + let _drop_guard = cancel.clone().drop_guard(); + + while let Outcome::Completed(Some(back)) = + back_rx.recv().preemptible(cancel.cancelled()).await + { match back { BackendResponse::Message { message, flush } => { if let PgWireBackendMessage::ErrorResponse(e) = &message { @@ -754,6 +755,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } }; @@ -782,6 +784,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } @@ -805,6 +808,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } }; @@ -858,6 +862,7 @@ pub async fn start( back_tx .blocking_send((e.into(), true).into())?; discard_until_sync = true; + session.tx_state.set_failed(); continue 'outer; } }; @@ -911,6 +916,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); } Some(Prepared::Empty) => { back_tx.blocking_send( @@ -972,6 +978,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); } Some(Portal::Empty { .. }) => { back_tx.blocking_send( @@ -1046,6 +1053,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } } @@ -1076,6 +1084,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } Some(Prepared::Empty) => { @@ -1110,6 +1119,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } }; @@ -1151,6 +1161,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } }; @@ -1188,6 +1199,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue 'outer; } continue; @@ -1215,6 +1227,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue 'outer; } Some(param_type) => { @@ -1364,6 +1377,7 @@ pub async fn start( ).into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue 'outer; } } @@ -1399,7 +1413,7 @@ pub async fn start( )?; } PgWireFrontendMessage::Sync(_) => { - send_ready(&mut session, &conn, discard_until_sync, &back_tx)?; + send_ready(&mut session, discard_until_sync, &back_tx)?; // reset this discard_until_sync = false; @@ -1442,6 +1456,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } }; @@ -1469,8 +1484,9 @@ pub async fn start( })?; discard_until_sync = true; + session.tx_state.set_failed(); - send_ready(&mut session, &conn, discard_until_sync, &back_tx)?; + send_ready(&mut session, discard_until_sync, &back_tx)?; continue; } } @@ -1492,12 +1508,8 @@ pub async fn start( ) .into(), )?; - send_ready( - &mut session, - &conn, - discard_until_sync, - &back_tx, - )?; + session.tx_state.set_failed(); + send_ready(&mut session, discard_until_sync, &back_tx)?; continue; } }; @@ -1513,7 +1525,7 @@ pub async fn start( .into(), )?; - send_ready(&mut session, &conn, discard_until_sync, &back_tx)?; + send_ready(&mut session, discard_until_sync, &back_tx)?; continue; } @@ -1525,48 +1537,41 @@ pub async fn start( message: e.try_into()?, flush: true, })?; - send_ready( - &mut session, - &conn, - discard_until_sync, - &back_tx, - )?; + session.tx_state.set_failed(); + send_ready(&mut session, discard_until_sync, &back_tx)?; continue 'outer; } } // automatically commit an implicit tx if session.tx_state.is_implicit() { - trace!("committing IMPLICIT tx"); - let _permit = session.tx_state.end(); - - if let Err(e) = session.handle_commit(&conn) { - back_tx.blocking_send( - ( - PgWireBackendMessage::ErrorResponse( - ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - e.to_string(), - ) + if let Some(tx) = session.tx_state.take_tx() { + trace!("committing IMPLICIT tx"); + if let Err(e) = session.handle_commit(tx) { + back_tx.blocking_send( + ( + PgWireBackendMessage::ErrorResponse( + ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + e.to_string(), + ) + .into(), + ), + true, + ) .into(), - ), - true, - ) - .into(), - )?; - send_ready( - &mut session, - &conn, - discard_until_sync, - &back_tx, - )?; - continue; + )?; + discard_until_sync = true; + session.tx_state.set_failed(); + send_ready(&mut session, discard_until_sync, &back_tx)?; + continue; + } + trace!("committed IMPLICIT tx"); } - trace!("committed IMPLICIT tx"); } - send_ready(&mut session, &conn, discard_until_sync, &back_tx)?; + send_ready(&mut session, discard_until_sync, &back_tx)?; } PgWireFrontendMessage::Terminate(_) => { break; @@ -1637,6 +1642,7 @@ pub async fn start( .into(), )?; discard_until_sync = true; + session.tx_state.set_failed(); continue; } } @@ -1733,19 +1739,23 @@ pub async fn start( Ok(PgServer { local_addr }) } -struct Session { +struct Session<'conn> { agent: Agent, - tx_state: TxState, + tx_state: TxState<'conn>, } -impl Session { +impl<'conn> Session<'conn> { fn handle_query( &mut self, - conn: &Connection, + conn: &'conn Connection, cmd: &ParsedCmd, back_tx: &Sender, send_row_desc: bool, ) -> Result<(), QueryError> { + if self.tx_state.is_failed() && !cmd.is_rollback() { + return Err(QueryError::AbortedTx); + } + if cmd.is_show() { back_tx .blocking_send( @@ -1774,32 +1784,38 @@ impl Session { // need to start an implicit transaction if self.tx_state.is_ended() && !cmd.is_begin() { - conn.execute_batch("BEGIN")?; trace!("started IMPLICIT tx"); - self.tx_state.start_implicit(); + self.tx_state.start_implicit(conn)?; } else if self.tx_state.is_implicit() && cmd.is_begin() { - trace!("committing IMPLICIT tx"); - let _permit = self.tx_state.end(); - - self.handle_commit(conn)?; - trace!("committed IMPLICIT tx"); + // this starts a new transaction, commits the previous implicit one + if let Some(tx) = self.tx_state.take_tx() { + trace!("committing IMPLICIT tx"); + self.handle_commit(tx)?; + trace!("committed IMPLICIT tx"); + } } let tag = cmd.tag(); let mut changes = 0usize; + // prevent nested transactions! + if cmd.is_begin() && !self.tx_state.is_ended() { + return Err(QueryError::NestedTransaction); + } + let count = if cmd.is_begin() { - conn.execute_batch("BEGIN")?; - self.tx_state.start_explicit(); + self.tx_state.start_explicit(conn)?; 0 } else if cmd.is_commit() { - let _permit = self.tx_state.end(); - self.handle_commit(conn)?; + if let Some(tx) = self.tx_state.take_tx() { + self.handle_commit(tx)?; + } 0 } else if cmd.is_rollback() { - let _permit = self.tx_state.end(); - conn.execute_batch("ROLLBACK")?; + if let Some(tx) = self.tx_state.take_tx() { + tx.rollback()?; + } 0 } else { let mut prepped = if cmd.is_pg() { @@ -1894,17 +1910,11 @@ impl Session { ) .map_err(|_| QueryError::BackendResponseSendFailed)?; - if cmd.is_begin() { - trace!("setting EXPLICIT tx"); - // explicit tx - self.tx_state.start_explicit(); - } - Ok(()) } #[allow(clippy::too_many_arguments)] - fn handle_execute<'conn>( + fn handle_execute( &mut self, conn: &'conn Connection, prepped: &mut Statement<'conn>, @@ -1938,13 +1948,10 @@ impl Session { if self.tx_state.is_ended() { if !cmd.is_begin() && !prepped.readonly() { // NOT in a tx and statement mutates DB... - conn.execute_batch("BEGIN")?; - - self.tx_state.start_implicit(); + self.tx_state.start_implicit(conn)?; opened_implicit_tx = true; } else if cmd.is_begin() { - conn.execute_batch("BEGIN")?; - self.tx_state.start_explicit(); + self.tx_state.start_explicit(conn)?; } } @@ -1954,8 +1961,9 @@ impl Session { let mut changes = 0usize; if cmd.is_commit() { - let _permit = self.tx_state.end(); - self.handle_commit(conn)?; + if let Some(tx) = self.tx_state.take_tx() { + self.handle_commit(tx)?; + } } else { if !self.tx_state.is_writing() && !prepped.readonly() { trace!("statement writes, acquiring permit..."); @@ -2100,8 +2108,9 @@ impl Session { } if opened_implicit_tx { - let _permit = self.tx_state.end(); - self.handle_commit(conn)?; + if let Some(tx) = self.tx_state.take_tx() { + self.handle_commit(tx)?; + } } } @@ -2123,9 +2132,8 @@ impl Session { Ok(()) } - fn handle_commit(&self, conn: &Connection) -> Result<(), ChangeError> { + fn handle_commit(&self, tx: TxGuard) -> Result<(), ChangeError> { trace!("HANDLE COMMIT"); - let mut book_writer = self .agent .booked() @@ -2133,8 +2141,8 @@ impl Session { let actor_id = self.agent.actor_id(); - let insert_info = insert_local_changes(&self.agent, conn, &mut book_writer)?; - conn.execute_batch("COMMIT") + let insert_info = insert_local_changes(&self.agent, &tx, &mut book_writer)?; + tx.execute_batch("COMMIT") .map_err(|source| ChangeError::Rusqlite { source, actor_id: Some(actor_id), @@ -2164,27 +2172,92 @@ impl Session { } } +#[derive(Debug)] +struct TxGuard<'conn> { + conn: &'conn Connection, + write_permit: Option, + ended: bool, +} + +impl<'conn> TxGuard<'conn> { + fn start(conn: &'conn Connection) -> rusqlite::Result { + conn.execute_batch("BEGIN")?; + Ok(Self { + conn, + write_permit: None, + ended: false, + }) + } + + fn set_write_permit(&mut self, permit: OwnedSemaphorePermit) { + self.write_permit = Some(permit); + } + + fn has_write_permit(&self) -> bool { + self.write_permit.is_some() + } + + fn commit(mut self) -> rusqlite::Result<()> { + self.commit_() + } + + fn commit_(&mut self) -> rusqlite::Result<()> { + self.conn.execute_batch("COMMIT")?; + self.ended = true; + Ok(()) + } + + fn rollback(mut self) -> rusqlite::Result<()> { + self.rollback_() + } + + fn rollback_(&mut self) -> rusqlite::Result<()> { + self.conn.execute_batch("ROLLBACK")?; + self.ended = true; + Ok(()) + } +} + +impl<'conn> Drop for TxGuard<'conn> { + fn drop(&mut self) { + if self.ended { + return; + } + // default rollback if not commited! + _ = self.rollback_(); + } +} + +impl<'conn> Deref for TxGuard<'conn> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.conn + } +} + fn send_ready( session: &mut Session, - conn: &Connection, discard_until_sync: bool, back_tx: &Sender, ) -> Result<(), BoxError> { let ready_status = if session.tx_state.is_implicit() { - let _permit = session.tx_state.end(); // do this first, in case of failure - if discard_until_sync { - // an error occured, rollback implicit tx! - warn!("receive Sync message w/ an error to send, rolling back implicit tx"); - conn.execute_batch("ROLLBACK")?; - } else { - // no error, commit implicit tx - warn!("receive Sync message, committing implicit tx"); - session.handle_commit(conn)?; + if let Some(tx) = session.tx_state.take_tx() { + if discard_until_sync { + // an error occured, rollback implicit tx! + warn!("receive Sync message w/ an error to send, rolling back implicit tx"); + tx.rollback()?; + } else { + // no error, commit implicit tx + warn!("receive Sync message, committing implicit tx"); + session.handle_commit(tx)?; + } } READY_STATUS_IDLE } else if session.tx_state.is_explicit() { - if discard_until_sync { + if discard_until_sync || session.tx_state.is_failed() { READY_STATUS_FAILED_TRANSACTION_BLOCK } else { READY_STATUS_TRANSACTION_BLOCK @@ -2220,6 +2293,10 @@ enum QueryError { PermitAcquire(#[from] AcquireError), #[error(transparent)] Change(#[from] ChangeError), + #[error("nested transactions are not supported, use savepoints for a similar functionality")] + NestedTransaction, + #[error("current transaction is aborted, commands ignored until end of transaction block")] + AbortedTx, } #[derive(Debug, thiserror::Error)] @@ -2248,6 +2325,12 @@ impl TryFrom for PgWireBackendMessage { QueryError::Change(e) => { ErrorInfo::new("ERROR".to_owned(), "XX000".to_owned(), e.to_string()).into() } + QueryError::NestedTransaction => { + ErrorInfo::new("ERROR".to_owned(), "XX000".to_owned(), value.to_string()).into() + } + QueryError::AbortedTx => { + ErrorInfo::new("ERROR".to_owned(), "25P02".to_owned(), value.to_string()).into() + } })) } } @@ -3255,6 +3338,29 @@ mod tests { assert_eq!(future, updated_at); } + { + let (mut client, client_conn) = tokio_postgres::connect(&conn_str, NoTls).await?; + tokio::spawn(client_conn); + + let tx = client.transaction().await?; + let res = tx + .batch_execute("INSERT INTO nonexistenttable VALUES (1)") + .await; + assert!(res.is_err()); + + let res = tx + .batch_execute("INSERT INTO nonexistenttable VALUES (1)") + .await; + assert_eq!( + res.err().unwrap().code().unwrap(), + &tokio_postgres::error::SqlState::IN_FAILED_SQL_TRANSACTION + ); + + tx.rollback().await?; + + assert!(client.query_one("SELECT 1", &[]).await.is_ok()); + } + tripwire_tx.send(()).await.ok(); tripwire_worker.await; wait_for_all_pending_handles().await;