diff --git a/crates/corro-agent/src/agent.rs b/crates/corro-agent/src/agent.rs index fcba5ca7..fb09da4c 100644 --- a/crates/corro-agent/src/agent.rs +++ b/crates/corro-agent/src/agent.rs @@ -64,7 +64,7 @@ use tokio::{ net::TcpListener, sync::{ mpsc::{channel, Receiver, Sender}, - watch, + watch, Semaphore, }, task::block_in_place, time::{error::Elapsed, sleep, timeout}, @@ -119,7 +119,9 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age info!("Actor ID: {}", actor_id); - let pool = SplitPool::create(&conf.db.path, tripwire.clone()).await?; + let write_sema = Arc::new(Semaphore::new(1)); + + let pool = SplitPool::create(&conf.db.path, write_sema.clone(), tripwire.clone()).await?; let schema = { let mut conn = pool.write_priority().await?; @@ -319,6 +321,7 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age tx_empty, tx_changes, tx_foca, + write_sema, schema: RwLock::new(schema), tripwire, }); diff --git a/crates/corro-pg/src/lib.rs b/crates/corro-pg/src/lib.rs index 9e6d1430..a7b9e29f 100644 --- a/crates/corro-pg/src/lib.rs +++ b/crates/corro-pg/src/lib.rs @@ -57,7 +57,10 @@ use sqlparser::ast::Statement as PgStatement; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, ReadBuf}, net::{TcpListener, TcpStream}, - sync::mpsc::{channel, Sender}, + sync::{ + mpsc::{channel, Sender}, + AcquireError, OwnedSemaphorePermit, + }, task::block_in_place, }; use tokio_util::{codec::Framed, sync::CancellationToken}; @@ -333,7 +336,91 @@ fn parse_query(sql: &str) -> Result, ParseError> { Ok(cmds) } +#[derive(Debug, Default)] enum OpenTx { + Started { + kind: OpenTxKind, + write_permit: Option, + }, + #[default] + Ended, +} + +impl OpenTx { + fn implicit() -> Self { + Self::Started { + kind: OpenTxKind::Implicit, + write_permit: None, + } + } + fn explicit() -> Self { + Self::Started { + kind: OpenTxKind::Explicit, + write_permit: None, + } + } + + fn is_writing(&self) -> bool { + matches!( + self, + OpenTx::Started { + write_permit: Some(_), + .. + } + ) + } + + fn set_write_permit(&mut self, permit: OwnedSemaphorePermit) { + match self { + OpenTx::Started { write_permit, .. } => *write_permit = Some(permit), + OpenTx::Ended => { + // do nothing, maybe bomb? + } + } + } + + fn is_implicit(&self) -> bool { + matches!( + self, + OpenTx::Started { + kind: OpenTxKind::Implicit, + .. + } + ) + } + fn is_explicit(&self) -> bool { + matches!( + self, + OpenTx::Started { + kind: OpenTxKind::Explicit, + .. + } + ) + } + fn is_ended(&self) -> bool { + matches!(self, OpenTx::Ended) + } + + fn start_implicit(&mut self) { + *self = Self::implicit() + } + + fn start_explicit(&mut self) { + *self = Self::explicit() + } + + fn end(&mut self) -> Option { + let permit = match self { + OpenTx::Started { write_permit, .. } => write_permit.take(), + OpenTx::Ended => None, + }; + *self = OpenTx::Ended; + permit + } +} + +#[derive(Debug)] +enum OpenTxKind { Implicit, Explicit, } @@ -615,7 +702,7 @@ pub async fn start( let mut discard_until_sync = false; - let mut open_tx = None; + let mut open_tx = OpenTx::default(); 'outer: while let Some(msg) = front_rx.blocking_recv() { debug!("msg: {msg:?}"); @@ -1476,9 +1563,9 @@ pub async fn start( } // automatically commit an implicit tx - if matches!(open_tx, Some(OpenTx::Implicit)) { + if open_tx.is_implicit() { trace!("committing IMPLICIT tx"); - open_tx = None; + let _permit = open_tx.end(); if let Err(e) = handle_commit(&agent, &conn) { back_tx.blocking_send( @@ -1683,33 +1770,31 @@ pub async fn start( fn send_ready( agent: &Agent, conn: &Connection, - open_tx: &mut Option, + open_tx: &mut OpenTx, discard_until_sync: bool, back_tx: &Sender, ) -> Result<(), BoxError> { - let ready_status = match open_tx { - Some(OpenTx::Implicit) => { - if discard_until_sync { - // an error occured, rollback implicit tx! - warn!("receive Sync message w/ an error to send, rolling back implicit tx"); - conn.execute_batch("ROLLBACK")?; - } else { - // no error, commit implicit tx - warn!("receive Sync message, committing implicit tx"); - handle_commit(agent, conn)?; - } - *open_tx = None; - - READY_STATUS_IDLE + let ready_status = if open_tx.is_implicit() { + let _permit = open_tx.end(); // do this first, in case of failure + if discard_until_sync { + // an error occured, rollback implicit tx! + warn!("receive Sync message w/ an error to send, rolling back implicit tx"); + conn.execute_batch("ROLLBACK")?; + } else { + // no error, commit implicit tx + warn!("receive Sync message, committing implicit tx"); + handle_commit(agent, conn)?; } - Some(OpenTx::Explicit) => { - if discard_until_sync { - READY_STATUS_FAILED_TRANSACTION_BLOCK - } else { - READY_STATUS_TRANSACTION_BLOCK - } + + READY_STATUS_IDLE + } else if open_tx.is_explicit() { + if discard_until_sync { + READY_STATUS_FAILED_TRANSACTION_BLOCK + } else { + READY_STATUS_TRANSACTION_BLOCK } - None => READY_STATUS_IDLE, + } else { + READY_STATUS_IDLE }; back_tx.blocking_send( @@ -1735,6 +1820,8 @@ enum QueryError { PgWire(#[from] PgWireError), #[error("backend response channel is closed")] BackendResponseSendFailed, + #[error("could not acquire write permit")] + PermitAcquire(#[from] AcquireError), } #[derive(Debug, thiserror::Error)] @@ -1756,6 +1843,9 @@ impl TryFrom for PgWireBackendMessage { QueryError::PgWire(e) => { ErrorInfo::new("ERROR".to_owned(), "XX000".to_owned(), e.to_string()).into() } + e @ QueryError::PermitAcquire(_) => { + ErrorInfo::new("FATAL".to_owned(), "XX000".to_owned(), e.to_string()).into() + } QueryError::BackendResponseSendFailed => return Err(ChannelClosed), })) } @@ -1768,7 +1858,7 @@ fn handle_execute<'conn>( prepped: &mut Statement<'conn>, result_formats: &[FieldFormat], cmd: &ParsedCmd, - open_tx: &mut Option, + open_tx: &mut OpenTx, max_rows: usize, back_tx: &Sender, ) -> Result<(), QueryError> { @@ -1794,24 +1884,29 @@ fn handle_execute<'conn>( // we need to know because we'll commit it right away let mut opened_implicit_tx = false; - if open_tx.is_none() { + if open_tx.is_ended() { if !cmd.is_begin() && !prepped.readonly() { // NOT in a tx and statement mutates DB... conn.execute_batch("BEGIN")?; - *open_tx = Some(OpenTx::Implicit); + + open_tx.start_implicit(); opened_implicit_tx = true; } else if cmd.is_begin() { conn.execute_batch("BEGIN")?; - *open_tx = Some(OpenTx::Explicit); + open_tx.start_explicit(); } } let mut count = 0; if cmd.is_commit() { + let _permit = open_tx.end(); handle_commit(agent, conn)?; - *open_tx = None; } else { + if !open_tx.is_writing() && !prepped.readonly() { + trace!("statement writes, acquiring permit..."); + open_tx.set_write_permit(agent.write_permit_blocking()?); + } let mut rows = prepped.raw_query(); loop { if count >= max_rows { @@ -1936,8 +2031,8 @@ fn handle_execute<'conn>( } if opened_implicit_tx { + let _permit = open_tx.end(); handle_commit(agent, conn)?; - *open_tx = None; } } @@ -1962,7 +2057,7 @@ fn handle_query( agent: &Agent, conn: &Connection, cmd: &ParsedCmd, - open_tx: &mut Option, + open_tx: &mut OpenTx, back_tx: &Sender, send_row_desc: bool, ) -> Result<(), QueryError> { @@ -1993,29 +2088,32 @@ fn handle_query( } // need to start an implicit transaction - if open_tx.is_none() && !cmd.is_begin() { + if open_tx.is_ended() && !cmd.is_begin() { conn.execute_batch("BEGIN")?; trace!("started IMPLICIT tx"); - *open_tx = Some(OpenTx::Implicit); - } - - // close the current implement tx first - if matches!(open_tx, Some(OpenTx::Implicit)) && cmd.is_begin() { + open_tx.start_implicit(); + } else if open_tx.is_implicit() && cmd.is_begin() { trace!("committing IMPLICIT tx"); - *open_tx = None; + let _permit = open_tx.end(); handle_commit(agent, conn)?; trace!("committed IMPLICIT tx"); } - let count = if cmd.is_commit() { + let count = if cmd.is_begin() { + conn.execute_batch("BEGIN")?; + open_tx.start_explicit(); + 0 + } else if cmd.is_commit() { + let _permit = open_tx.end(); handle_commit(agent, conn)?; - *open_tx = None; + 0 + } else if cmd.is_rollback() { + let _permit = open_tx.end(); + conn.execute_batch("ROLLBACK")?; 0 } else { - let mut prepped = if cmd.is_begin() { - conn.prepare("BEGIN")? - } else if cmd.is_pg() { + let mut prepped = if cmd.is_pg() { return Err(QueryError::NotSqlite); } else { conn.prepare(&cmd.to_string())? @@ -2049,6 +2147,11 @@ fn handle_query( let schema = Arc::new(fields); + if !open_tx.is_writing() && !prepped.readonly() { + trace!("query statement writes, acquiring permit..."); + open_tx.set_write_permit(agent.write_permit_blocking()?); + } + let mut rows = prepped.raw_query(); let ncols = schema.len(); @@ -2097,11 +2200,7 @@ fn handle_query( if cmd.is_begin() { trace!("setting EXPLICIT tx"); // explicit tx - *open_tx = Some(OpenTx::Explicit) - } else if cmd.is_rollback() || cmd.is_commit() { - trace!("clearing current open tx"); - // if this was a rollback, remove the current open tx - *open_tx = None; + open_tx.start_explicit(); } Ok(()) @@ -2846,6 +2945,8 @@ fn parameter_types<'a>(schema: &'a Schema, cmd: &ParsedCmd) -> Vec<(SqliteType, #[cfg(test)] mod tests { + use std::time::{Duration, Instant}; + use chrono::{DateTime, Utc}; use corro_tests::launch_test_agent; use spawn::wait_for_all_pending_handles; @@ -2882,6 +2983,8 @@ mod tests { ) .await?; + let sema = ta.agent.write_sema().clone(); + let server = start( ta.agent.clone(), PgConfig { @@ -2904,6 +3007,8 @@ mod tests { println!("client is ready!"); tokio::spawn(client_conn); + let _permit = sema.acquire().await; + println!("before prepare"); let stmt = client.prepare("SELECT 1").await?; println!( @@ -2913,7 +3018,10 @@ mod tests { ); println!("before query"); - let rows = client.query(&stmt, &[]).await?; + // add a timeout because the semaphore shouldn't block anything here + // it will fail if the semaphore prevents this query. + let rows = tokio::time::timeout(Duration::from_millis(100), client.query(&stmt, &[])) + .await??; println!("rows count: {}", rows.len()); for row in rows { @@ -2921,10 +3029,26 @@ mod tests { } println!("before execute"); - let affected = client - .execute("INSERT INTO tests VALUES (1,2)", &[]) - .await?; - println!("after execute, affected: {affected}"); + let start = Instant::now(); + let (affected_res, sema_elapsed) = tokio::join!( + async { + let affected = client + .execute("INSERT INTO tests VALUES (1,2)", &[]) + .await?; + Ok::<_, tokio_postgres::Error>((affected, start.elapsed())) + }, + async move { + tokio::time::sleep(Duration::from_secs(1)).await; + drop(_permit); + start.elapsed() + } + ); + + let (affected, exec_elapsed) = affected_res?; + + println!("after execute, affected: {affected}, sema elapsed: {sema_elapsed:?}, exec elapsed: {exec_elapsed:?}"); + + assert!(exec_elapsed > sema_elapsed); let row = client.query_one("SELECT * FROM crsql_changes", &[]).await?; println!("CHANGE ROW: {row:?}"); diff --git a/crates/corro-types/src/agent.rs b/crates/corro-types/src/agent.rs index bdec6d3f..ba9681fb 100644 --- a/crates/corro-types/src/agent.rs +++ b/crates/corro-types/src/agent.rs @@ -18,11 +18,12 @@ use indexmap::IndexMap; use metrics::{gauge, histogram}; use parking_lot::RwLock; use rangemap::RangeInclusiveSet; -use rusqlite::{Connection, InterruptHandle, Transaction}; +use rusqlite::{Connection, Transaction}; use serde::{Deserialize, Serialize}; use tokio::sync::{ - watch, OwnedRwLockWriteGuard as OwnedTokioRwLockWriteGuard, RwLock as TokioRwLock, - RwLockReadGuard as TokioRwLockReadGuard, RwLockWriteGuard as TokioRwLockWriteGuard, + watch, AcquireError, OwnedRwLockWriteGuard as OwnedTokioRwLockWriteGuard, OwnedSemaphorePermit, + RwLock as TokioRwLock, RwLockReadGuard as TokioRwLockReadGuard, + RwLockWriteGuard as TokioRwLockWriteGuard, }; use tokio::{ runtime::Handle, @@ -32,7 +33,7 @@ use tokio::{ }, }; use tokio_util::sync::{CancellationToken, DropGuard}; -use tracing::{debug, error, info, Instrument}; +use tracing::{debug, error, info}; use tripwire::Tripwire; use crate::{ @@ -65,6 +66,8 @@ pub struct AgentConfig { pub tx_changes: Sender<(ChangeV1, ChangeSource)>, pub tx_foca: Sender, + pub write_sema: Arc, + pub schema: RwLock, pub tripwire: Tripwire, } @@ -84,6 +87,7 @@ pub struct AgentInner { tx_empty: Sender<(ActorId, RangeInclusive)>, tx_changes: Sender<(ChangeV1, ChangeSource)>, tx_foca: Sender, + write_sema: Arc, schema: RwLock, limits: Limits, } @@ -110,6 +114,7 @@ impl Agent { tx_empty: config.tx_empty, tx_changes: config.tx_changes, tx_foca: config.tx_foca, + write_sema: config.write_sema, schema: config.schema, limits: Limits { sync: Arc::new(Semaphore::new(3)), @@ -157,6 +162,18 @@ impl Agent { &self.0.tx_foca } + pub fn write_sema(&self) -> &Arc { + &self.0.write_sema + } + + pub async fn write_permit(&self) -> Result { + self.0.write_sema.clone().acquire_owned().await + } + + pub fn write_permit_blocking(&self) -> Result { + Handle::current().block_on(self.0.write_sema.clone().acquire_owned()) + } + pub fn bookie(&self) -> &Bookie { &self.0.bookie } @@ -317,6 +334,7 @@ pub struct SplitPool(Arc); #[derive(Debug)] struct SplitPoolInner { path: PathBuf, + write_sema: Arc, read: SqlitePool, write: SqlitePool, @@ -334,6 +352,8 @@ pub enum PoolError { QueueClosed, #[error("callback is closed")] CallbackClosed, + #[error("could not acquire write permit")] + Permit(#[from] AcquireError), } #[derive(Debug, thiserror::Error)] @@ -357,6 +377,7 @@ pub enum SplitPoolCreateError { impl SplitPool { pub async fn create>( path: P, + write_sema: Arc, tripwire: Tripwire, ) -> Result { let rw_pool = sqlite_pool::Config::new(path.as_ref()) @@ -373,13 +394,20 @@ impl SplitPool { Ok(Self::new( path.as_ref().to_owned(), + write_sema, ro_pool, rw_pool, tripwire, )) } - fn new(path: PathBuf, read: SqlitePool, write: SqlitePool, mut tripwire: Tripwire) -> Self { + fn new( + path: PathBuf, + write_sema: Arc, + read: SqlitePool, + write: SqlitePool, + mut tripwire: Tripwire, + ) -> Self { let (priority_tx, mut priority_rx) = channel(256); let (normal_tx, mut normal_rx) = channel(512); let (low_tx, mut low_rx) = channel(1024); @@ -413,6 +441,7 @@ impl SplitPool { Self(Arc::new(SplitPoolInner { path, + write_sema, read, write, priority_tx, @@ -502,48 +531,21 @@ impl SplitPool { histogram!("corro.sqlite.pool.queue.seconds", start.elapsed().as_secs_f64(), "queue" => queue); let conn = self.0.write.get().await?; - tokio::spawn( - timeout_wait( - token.clone(), - conn.get_interrupt_handle(), - Duration::from_secs(30), - queue, - ) - .in_current_span(), + let start = Instant::now(); + let _permit = self.0.write_sema.clone().acquire_owned().await?; + histogram!( + "corro.sqlite.write_permit.acquisition.seconds", + start.elapsed().as_secs_f64() ); Ok(WriteConn { conn, _drop_guard: token.drop_guard(), + _permit, }) } } -async fn timeout_wait( - token: CancellationToken, - _handle: InterruptHandle, - _timeout: Duration, - queue: &'static str, -) { - let start = Instant::now(); - token.cancelled().await; - histogram!("corro.sqlite.pool.execution.seconds", start.elapsed().as_secs_f64(), "queue" => queue); - // tokio::select! { - // biased; - // _ = token.cancelled() => { - // trace!("conn dropped before timeout"); - // histogram!("corro.sqlite.pool.execution.seconds", start.elapsed().as_secs_f64(), "queue" => queue); - // return; - // }, - // _ = tokio::time::sleep(timeout) => { - // warn!("conn execution timed out, interrupting!"); - // } - // } - // handle.interrupt(); - // increment_tracker!("corro.sqlite.pool.execution.timeout"); - // FIXME: do we need to cancel the token? -} - async fn wait_conn_drop(tx: oneshot::Sender) { let cancel = CancellationToken::new(); @@ -558,6 +560,7 @@ async fn wait_conn_drop(tx: oneshot::Sender) { pub struct WriteConn { conn: sqlite_pool::Connection, _drop_guard: DropGuard, + _permit: OwnedSemaphorePermit, } impl Deref for WriteConn { diff --git a/crates/corro-types/src/pubsub.rs b/crates/corro-types/src/pubsub.rs index 09afb88e..079be5ef 100644 --- a/crates/corro-types/src/pubsub.rs +++ b/crates/corro-types/src/pubsub.rs @@ -2016,6 +2016,7 @@ mod tests { use corro_api_types::row_to_change; use rusqlite::params; use spawn::wait_for_all_pending_handles; + use tokio::sync::Semaphore; use crate::{ agent::migrate, @@ -2041,7 +2042,8 @@ mod tests { let subscriptions_path: Utf8PathBuf = tmpdir.path().join("subs").display().to_string().into(); - let pool = SplitPool::create(db_path, tripwire.clone()).await?; + let pool = + SplitPool::create(db_path, Arc::new(Semaphore::new(1)), tripwire.clone()).await?; { let mut conn = pool.write_priority().await?; setup_conn(&mut conn)?; @@ -2162,7 +2164,9 @@ mod tests { let subscriptions_path: Utf8PathBuf = tmpdir.path().join("subs").display().to_string().into(); - let pool = SplitPool::create(&db_path, tripwire.clone()).await.unwrap(); + let pool = SplitPool::create(&db_path, Arc::new(Semaphore::new(1)), tripwire.clone()) + .await + .unwrap(); let mut conn = pool.write_priority().await.unwrap(); {