diff --git a/common/db/src/lib.rs b/common/db/src/lib.rs index 72ff43674..95a041b5f 100644 --- a/common/db/src/lib.rs +++ b/common/db/src/lib.rs @@ -30,13 +30,53 @@ pub trait Get { /// is undefined. The transaction may block, deadlock, panic, overwrite one of the two values /// randomly, or any other action, at time of write or at time of commit. #[must_use] -pub trait DbTxn: Send + Get { +pub trait DbTxn: Sized + Send + Get { /// Write a value to this key. fn put(&mut self, key: impl AsRef<[u8]>, value: impl AsRef<[u8]>); /// Delete the value from this key. fn del(&mut self, key: impl AsRef<[u8]>); /// Commit this transaction. fn commit(self); + /// Close this transaction. + /// + /// This is equivalent to `Drop` on transactions which can be dropped. This is explicit and works + /// with transactions which can't be dropped. + fn close(self) { + drop(self); + } +} + +// Credit for the idea goes to https://jack.wrenn.fyi/blog/undroppable +pub struct Undroppable(Option); +impl Drop for Undroppable { + fn drop(&mut self) { + // Use an assertion at compile time to prevent this code from compiling if generated + #[allow(clippy::assertions_on_constants)] + const { + assert!(false, "Undroppable DbTxn was dropped. Ensure all code paths call commit or close"); + } + } +} +impl Get for Undroppable { + fn get(&self, key: impl AsRef<[u8]>) -> Option> { + self.0.as_ref().unwrap().get(key) + } +} +impl DbTxn for Undroppable { + fn put(&mut self, key: impl AsRef<[u8]>, value: impl AsRef<[u8]>) { + self.0.as_mut().unwrap().put(key, value); + } + fn del(&mut self, key: impl AsRef<[u8]>) { + self.0.as_mut().unwrap().del(key); + } + fn commit(mut self) { + self.0.take().unwrap().commit(); + let _ = core::mem::ManuallyDrop::new(self); + } + fn close(mut self) { + drop(self.0.take().unwrap()); + let _ = core::mem::ManuallyDrop::new(self); + } } /// A database supporting atomic transaction. @@ -51,6 +91,10 @@ pub trait Db: 'static + Send + Sync + Clone + Get { let dst_len = u8::try_from(item_dst.len()).unwrap(); [[db_len].as_ref(), db_dst, [dst_len].as_ref(), item_dst, key.as_ref()].concat() } - /// Open a new transaction. - fn txn(&mut self) -> Self::Transaction<'_>; + /// Open a new transaction which may be dropped. + fn unsafe_txn(&mut self) -> Self::Transaction<'_>; + /// Open a new transaction which must be committed or closed. + fn txn(&mut self) -> Undroppable> { + Undroppable(Some(self.unsafe_txn())) + } } diff --git a/common/db/src/mem.rs b/common/db/src/mem.rs index d24aa109f..8ff2d272c 100644 --- a/common/db/src/mem.rs +++ b/common/db/src/mem.rs @@ -74,7 +74,7 @@ impl Get for MemDb { } impl Db for MemDb { type Transaction<'a> = MemDbTxn<'a>; - fn txn(&mut self) -> MemDbTxn<'_> { + fn unsafe_txn(&mut self) -> MemDbTxn<'_> { MemDbTxn(self, HashMap::new(), HashSet::new()) } } diff --git a/common/db/src/parity_db.rs b/common/db/src/parity_db.rs index 9ae345f6f..f4cbb486f 100644 --- a/common/db/src/parity_db.rs +++ b/common/db/src/parity_db.rs @@ -37,7 +37,7 @@ impl Get for Arc { } impl Db for Arc { type Transaction<'a> = Transaction<'a>; - fn txn(&mut self) -> Self::Transaction<'_> { + fn unsafe_txn(&mut self) -> Self::Transaction<'_> { Transaction(self, vec![]) } } diff --git a/common/db/src/rocks.rs b/common/db/src/rocks.rs index 1d42d902e..d6329eb5a 100644 --- a/common/db/src/rocks.rs +++ b/common/db/src/rocks.rs @@ -39,7 +39,7 @@ impl Get for Arc> { } impl Db for Arc> { type Transaction<'a> = Transaction<'a, T>; - fn txn(&mut self) -> Self::Transaction<'_> { + fn unsafe_txn(&mut self) -> Self::Transaction<'_> { let mut opts = WriteOptions::default(); opts.set_sync(true); Transaction(self.transaction_opt(&opts, &Default::default()), &**self) diff --git a/coordinator/cosign/src/delay.rs b/coordinator/cosign/src/delay.rs index 3439135b4..f10f2c349 100644 --- a/coordinator/cosign/src/delay.rs +++ b/coordinator/cosign/src/delay.rs @@ -24,6 +24,15 @@ pub(crate) struct CosignDelayTask { pub(crate) db: D, } +struct AwaitUndroppable(Option>>); +impl Drop for AwaitUndroppable { + fn drop(&mut self) { + if let Some(mut txn) = self.0.take() { + (unsafe { core::mem::ManuallyDrop::take(&mut txn) }).close(); + } + } +} + impl ContinuallyRan for CosignDelayTask { type Error = DoesNotError; @@ -35,14 +44,18 @@ impl ContinuallyRan for CosignDelayTask { // Receive the next block to mark as cosigned let Some((block_number, time_evaluated)) = CosignedBlocks::try_recv(&mut txn) else { + txn.close(); break; }; + // Calculate when we should mark it as valid let time_valid = SystemTime::UNIX_EPOCH + Duration::from_secs(time_evaluated) + ACKNOWLEDGEMENT_DELAY; // Sleep until then + let mut txn = AwaitUndroppable(Some(core::mem::ManuallyDrop::new(txn))); tokio::time::sleep(SystemTime::now().duration_since(time_valid).unwrap_or(Duration::ZERO)) .await; + let mut txn = core::mem::ManuallyDrop::into_inner(txn.0.take().unwrap()); // Set the cosigned block LatestCosignedBlockNumber::set(&mut txn, &block_number); diff --git a/coordinator/cosign/src/evaluator.rs b/coordinator/cosign/src/evaluator.rs index 4216d5a7c..0e70ad67c 100644 --- a/coordinator/cosign/src/evaluator.rs +++ b/coordinator/cosign/src/evaluator.rs @@ -87,7 +87,7 @@ impl ContinuallyRan for CosignEvaluatorTask ContinuallyRan for CosignIntendTask { self.serai.latest_finalized_block().await.map_err(|e| format!("{e:?}"))?.number(); for block_number in start_block_number ..= latest_block_number { - let mut txn = self.db.txn(); + let mut txn = self.db.unsafe_txn(); let (block, mut has_events) = block_has_events_justifying_a_cosign(&self.serai, block_number) diff --git a/coordinator/cosign/src/lib.rs b/coordinator/cosign/src/lib.rs index dae8647b3..e13f86975 100644 --- a/coordinator/cosign/src/lib.rs +++ b/coordinator/cosign/src/lib.rs @@ -424,7 +424,7 @@ impl Cosigning { // Since we verified this cosign's signature, and have a chain sufficiently long, handle the // cosign - let mut txn = self.db.txn(); + let mut txn = self.db.unsafe_txn(); if !faulty { // If this is for a future global session, we don't acknowledge this cosign at this time @@ -480,3 +480,30 @@ impl Cosigning { res } } + +mod tests { + use super::*; + + struct RNC; + impl RequestNotableCosigns for RNC { + /// The error type which may be encountered when requesting notable cosigns. + type Error = (); + + /// Request the notable cosigns for this global session. + fn request_notable_cosigns( + &self, + global_session: [u8; 32], + ) -> impl Send + Future> { + async move { Ok(()) } + } + } + #[tokio::test] + async fn test() { + let db: serai_db::MemDb = serai_db::MemDb::new(); + let serai = unsafe { core::mem::transmute(0u64) }; + let request = RNC; + let tasks = vec![]; + let _ = Cosigning::spawn(db, serai, request, tasks); + core::future::pending().await + } +}