From 98f8071b1926fff011531b7a819e8e4e89791ffb Mon Sep 17 00:00:00 2001 From: danda Date: Tue, 24 Sep 2024 15:10:11 -0700 Subject: [PATCH 1/3] feat: wallet unconfirmed balance. wallet subscribes to mempool via tokio broadcast to track owned utxos. Changes: * add spent_utxos, unspent_utxos lists to WalletState struct * add WalletState::handle_mempool_event() * add test confirmed_and_unconfirmed_balance() * add Mempool::event_channel (tokio broadcast channel) * Mempool faillible mutation methods return Result() * Mempool mutable methods broadcast MempoolEvent * add tests::shared::mine_block_to_wallet() * lib.rs: spawn wallet task for listening to mempool events and dispatch to WalletState for handling. * add locks::tokio::AtomicRw::try_lock_guard_mut() --- Cargo.lock | 4 +- src/lib.rs | 55 ++++-- src/locks/tokio/atomic_rw.rs | 7 + src/main_loop.rs | 10 +- src/mine_loop.rs | 6 +- src/models/state/mempool.rs | 181 ++++++++++------- src/models/state/mod.rs | 2 +- .../state/wallet/rusty_wallet_database.rs | 3 + src/models/state/wallet/wallet_state.rs | 185 +++++++++++++++++- src/peer_loop.rs | 2 +- src/tests/shared.rs | 24 +++ 11 files changed, 389 insertions(+), 90 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a5206dc5..18efdb44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -700,9 +700,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.12" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" dependencies = [ "crossbeam-utils", ] diff --git a/src/lib.rs b/src/lib.rs index 64f43b93..4f2eaf5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,16 +90,6 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { DataDirectory::create_dir_if_not_exists(&data_dir.root_dir_path()).await?; info!("Data directory is {}", data_dir); - // Get wallet object, create various wallet secret files - let wallet_dir = data_dir.wallet_directory_path(); - DataDirectory::create_dir_if_not_exists(&wallet_dir).await?; - let (wallet_secret, _) = - WalletSecret::read_from_file_or_create(&data_dir.wallet_directory_path())?; - info!("Now getting wallet state. This may take a while if the database needs pruning."); - let wallet_state = - WalletState::new_from_wallet_secret(&data_dir, wallet_secret, &cli_args).await; - info!("Got wallet state."); - // Connect to or create databases for block index, peers, mutator set, block sync let block_index_db = ArchivalState::initialize_block_index_database(&data_dir).await?; info!("Got block index database"); @@ -111,7 +101,7 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { info!("Got archival mutator set"); let archival_state = ArchivalState::new( - data_dir, + data_dir.clone(), block_index_db, archival_mutator_set, cli_args.network, @@ -149,6 +139,17 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { }; let blockchain_state = BlockchainState::Archival(blockchain_archival_state); let mempool = Mempool::new(cli_args.max_mempool_size, latest_block.hash()); + + // Get wallet object, create various wallet secret files + let wallet_dir = data_dir.wallet_directory_path(); + DataDirectory::create_dir_if_not_exists(&wallet_dir).await?; + let (wallet_secret, _) = + WalletSecret::read_from_file_or_create(&data_dir.wallet_directory_path())?; + info!("Now getting wallet state. This may take a while if the database needs pruning."); + let wallet_state = + WalletState::new_from_wallet_secret(&data_dir, wallet_secret, &cli_args).await; + info!("Got wallet state."); + let mut global_state_lock = GlobalStateLock::new( wallet_state, blockchain_state, @@ -176,8 +177,11 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { .await?; info!("UTXO restoration check complete"); - // Connect to peers, and provide each peer task with a thread-safe copy of the state let mut task_join_handles = vec![]; + + task_join_handles.push(spawn_wallet_task(global_state_lock.clone()).await?); + + // Connect to peers, and provide each peer task with a thread-safe copy of the state for peer_address in global_state_lock.cli().peers.clone() { let peer_state_var = global_state_lock.clone(); // bump arc refcount let main_to_peer_broadcast_rx_clone: broadcast::Receiver = @@ -284,6 +288,33 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { .await } +pub(crate) async fn spawn_wallet_task( + mut global_state_lock: GlobalStateLock, +) -> Result> { + let mut mempool_subscriber = global_state_lock.lock_guard().await.mempool.subscribe(); + + let wallet_join_handle = tokio::task::Builder::new() + .name("wallet_mempool_listener") + .spawn(async move { + let mut events: std::collections::VecDeque<_> = Default::default(); + + while let Ok(event) = mempool_subscriber.recv().await { + events.push_back(event); + + if let Ok(mut gs) = global_state_lock.try_lock_guard_mut() { + while let Some(e) = events.pop_front() { + gs.wallet_state + .handle_mempool_event(e) + .await + .expect("Wallet should handle mempool event without error"); + } + } + } + })?; + + Ok(wallet_join_handle) +} + /// Time a fn call. Duration is returned as a float in seconds. pub fn time_fn_call(f: impl FnOnce() -> O) -> (O, f64) { let start = Instant::now(); diff --git a/src/locks/tokio/atomic_rw.rs b/src/locks/tokio/atomic_rw.rs index c02c1710..b372119a 100644 --- a/src/locks/tokio/atomic_rw.rs +++ b/src/locks/tokio/atomic_rw.rs @@ -6,6 +6,7 @@ use futures::future::BoxFuture; use tokio::sync::RwLock; use tokio::sync::RwLockReadGuard; use tokio::sync::RwLockWriteGuard; +use tokio::sync::TryLockError; use super::LockAcquisition; use super::LockCallbackFn; @@ -240,6 +241,12 @@ impl AtomicRw { AtomicRwWriteGuard::new(guard, &self.lock_callback_info) } + pub fn try_lock_guard_mut(&mut self) -> Result, TryLockError> { + self.try_acquire_write_cb(); + let guard = self.inner.try_write()?; + Ok(AtomicRwWriteGuard::new(guard, &self.lock_callback_info)) + } + /// Immutably access the data of type `T` in a closure and possibly return a result of type `R` /// /// # Examples diff --git a/src/main_loop.rs b/src/main_loop.rs index 050c9e99..8e64760c 100644 --- a/src/main_loop.rs +++ b/src/main_loop.rs @@ -529,7 +529,7 @@ impl MainLoopHandler { // Insert into mempool global_state_mut .mempool - .insert(&pt2m_transaction.transaction); + .insert(pt2m_transaction.transaction.to_owned())?; // send notification to peers let transaction_notification: TransactionNotification = @@ -970,7 +970,7 @@ impl MainLoopHandler { // Handle mempool cleanup, i.e. removing stale/too old txs from mempool _ = &mut mempool_cleanup_timer => { debug!("Timer: mempool-cleaner job"); - self.global_state_lock.lock_mut(|s| s.mempool.prune_stale_transactions()).await; + self.global_state_lock.lock_guard_mut().await.mempool.prune_stale_transactions()?; // Reset the timer to run this branch again in P seconds mempool_cleanup_timer.as_mut().reset(tokio::time::Instant::now() + mempool_cleanup_timer_interval); @@ -1026,8 +1026,10 @@ impl MainLoopHandler { // insert transaction into mempool self.global_state_lock - .lock_mut(|s| s.mempool.insert(&transaction)) - .await; + .lock_guard_mut() + .await + .mempool + .insert(*transaction)?; // do not shut down Ok(false) diff --git a/src/mine_loop.rs b/src/mine_loop.rs index e37f2acc..24b61694 100644 --- a/src/mine_loop.rs +++ b/src/mine_loop.rs @@ -51,7 +51,7 @@ use crate::util_types::mutator_set::mutator_set_accumulator::MutatorSetAccumulat const MOCK_MAX_BLOCK_SIZE: u32 = 1_000_000; /// Prepare a Block for mining -fn make_block_template( +pub(crate) fn make_block_template( previous_block: &Block, transaction: Transaction, mut block_timestamp: Timestamp, @@ -299,7 +299,7 @@ fn make_coinbase_transaction( /// Create the transaction that goes into the block template. The transaction is /// built from the mempool and from the coinbase transaction. Also returns the /// "sender randomness" used in the coinbase transaction. -fn create_block_transaction( +pub(crate) fn create_block_transaction( latest_block: &Block, global_state: &GlobalState, timestamp: Timestamp, @@ -604,7 +604,7 @@ mod mine_loop_tests { premine_receiver_global_state .mempool - .insert(&tx_by_preminer); + .insert(tx_by_preminer)?; assert_eq!(1, premine_receiver_global_state.mempool.len()); // Build transaction diff --git a/src/models/state/mempool.rs b/src/models/state/mempool.rs index fe5d24cf..5b2a8183 100644 --- a/src/models/state/mempool.rs +++ b/src/models/state/mempool.rs @@ -13,6 +13,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::iter::Rev; +use anyhow::Result; use bytesize::ByteSize; use get_size::GetSize; use num_traits::Zero; @@ -59,7 +60,13 @@ pub const TRANSACTION_NOTIFICATION_AGE_LIMIT_IN_SECS: u64 = 60 * 60 * 24; type LookupItem<'a> = (Digest, &'a Transaction); -#[derive(Debug, Clone, PartialEq, Eq, GetSize)] +#[derive(Debug, Clone)] +pub enum MempoolEvent { + AddTx(Transaction), + RemoveTx(Transaction), +} + +#[derive(Debug, GetSize)] pub struct Mempool { max_total_size: usize, @@ -75,6 +82,13 @@ pub struct Mempool { /// Records the digest of the block that the transactions were synced to. /// Used to discover reorganizations. tip_digest: Digest, + + /// a mpmc channel for interested parties to listen to mempool events + #[get_size(ignore)] // does not impl GetSize + event_channel: ( + tokio::sync::broadcast::Sender, + tokio::sync::broadcast::Receiver, + ), } impl Mempool { @@ -88,6 +102,7 @@ impl Mempool { tx_dictionary: table, queue, tip_digest, + event_channel: tokio::sync::broadcast::channel(100), } } @@ -140,7 +155,7 @@ impl Mempool { /// this method accepts only fully proven transactions (or, for the time being, faith witnesses). /// The caller must also ensure that the transaction does not have a timestamp /// in the too distant future. - pub fn insert(&mut self, transaction: &Transaction) -> Option { + pub fn insert(&mut self, transaction: Transaction) -> Result { match transaction.witness.vast.witness_type { WitnessType::RawWitness(_) => panic!("Can only insert fully proven transactions into mempool; not accepting raw witnesses."), WitnessType::Decomposition => panic!("Can only insert fully proven transactions into mempool; not accepting decompositions."), @@ -150,53 +165,57 @@ impl Mempool { } // If transaction to be inserted conflicts with a transaction that's already // in the mempool we preserve only the one with the highest fee density. - if let Some((txid, tx)) = self.transaction_conflicts_with(transaction) { + if let Some((txid, tx)) = self.transaction_conflicts_with(&transaction) { if tx.fee_density() < transaction.fee_density() { // If new transaction has a higher fee density than the one previously seen // remove the old one. - self.remove(txid); + self.remove(txid)?; } else { // If new transaction has a lower fee density than the one previous seen, // ignore it. Stop execution here. - return Some(txid); + return Ok(txid); } }; - let transaction_id: Digest = Hash::hash(transaction); + let transaction_id: Digest = Hash::hash(&transaction); self.queue.push(transaction_id, transaction.fee_density()); self.tx_dictionary - .insert(transaction_id, transaction.to_owned()); + .insert(transaction_id, transaction.clone()); assert_eq!( self.tx_dictionary.len(), self.queue.len(), "mempool's table and queue length must agree prior to shrink" ); - self.shrink_to_max_size(); + self.shrink_to_max_size()?; assert_eq!( self.tx_dictionary.len(), self.queue.len(), "mempool's table and queue length must agree after shrink" ); - None + self.sender().send(MempoolEvent::AddTx(transaction))?; + + Ok(transaction_id) } /// remove a transaction from the `Mempool` - pub fn remove(&mut self, transaction_id: Digest) -> Option { - if let rv @ Some(_) = self.tx_dictionary.remove(&transaction_id) { - self.queue.remove(&transaction_id); - debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - return rv; + pub fn remove(&mut self, transaction_id: Digest) -> Result { + match self.tx_dictionary.remove(&transaction_id) { + Some(tx) => { + self.queue.remove(&transaction_id); + debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); + self.sender().send(MempoolEvent::RemoveTx(tx))?; + Ok(true) + } + None => Ok(false), } - - None } /// Delete all transactions from the mempool. - pub fn clear(&mut self) { - self.queue.clear(); - self.tx_dictionary.clear(); + pub fn clear(&mut self) -> Result<()> { + // note: this causes event listeners to be notified of each removed tx. + self.retain(|_| false) } /// Return the number of transactions currently stored in the Mempool. @@ -248,28 +267,36 @@ impl Mempool { /// /// Computes in θ(lg N) #[allow(dead_code)] - pub fn pop_max(&mut self) -> Option<(Transaction, FeeDensity)> { + pub fn pop_max(&mut self) -> Result> { if let Some((transaction_digest, fee_density)) = self.queue.pop_max() { - let transaction = self.tx_dictionary.remove(&transaction_digest).unwrap(); - debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - Some((transaction, fee_density)) - } else { - None + if let Some(transaction) = self.tx_dictionary.remove(&transaction_digest) { + debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); + + self.sender() + .send(MempoolEvent::RemoveTx(transaction.clone()))?; + + return Ok(Some((transaction, fee_density))); + } } + Ok(None) } /// Removes the transaction with the lowest [`FeeDensity`] from the mempool. /// Returns the removed value. /// /// Computes in θ(lg N) - pub fn pop_min(&mut self) -> Option<(Transaction, FeeDensity)> { + pub fn pop_min(&mut self) -> Result> { if let Some((transaction_digest, fee_density)) = self.queue.pop_min() { - let transaction = self.tx_dictionary.remove(&transaction_digest).unwrap(); - debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - Some((transaction, fee_density)) - } else { - None + if let Some(transaction) = self.tx_dictionary.remove(&transaction_digest) { + debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); + + self.sender() + .send(MempoolEvent::RemoveTx(transaction.clone()))?; + + return Ok(Some((transaction, fee_density))); + } } + Ok(None) } /// Removes all transactions from the mempool that do not satisfy the @@ -277,7 +304,7 @@ impl Mempool { /// Modelled after [HashMap::retain](std::collections::HashMap::retain()) /// /// Computes in O(capacity) >= O(N) - pub fn retain(&mut self, mut predicate: F) + pub fn retain(&mut self, mut predicate: F) -> Result<()> where F: FnMut(LookupItem) -> bool, { @@ -291,25 +318,27 @@ impl Mempool { } for t in victims { - self.remove(t); + self.remove(t)?; } debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - self.shrink_to_fit() + self.shrink_to_fit(); + + Ok(()) } /// Remove transactions from mempool that are older than the specified /// timestamp. Prunes base on the transaction's timestamp. /// /// Computes in O(n) - pub fn prune_stale_transactions(&mut self) { + pub fn prune_stale_transactions(&mut self) -> Result<()> { let cutoff = Timestamp::now() - Timestamp::seconds(MEMPOOL_TX_THRESHOLD_AGE_IN_SECS); let keep = |(_transaction_id, transaction): LookupItem| -> bool { cutoff < transaction.kernel.timestamp }; - self.retain(keep); + self.retain(keep) } /// Remove from the mempool all transactions that become invalid because @@ -319,14 +348,14 @@ impl Mempool { &mut self, previous_mutator_set_accumulator: MutatorSetAccumulator, block: &Block, - ) { + ) -> Result<()> { // If we discover a reorganization, we currently just clear the mempool, // as we don't have the ability to roll transaction removal record integrity // proofs back to previous blocks. It would be nice if we could handle a // reorganization that's at least a few blocks deep though. let previous_block_digest = block.header().prev_block_digest; if self.tip_digest != previous_block_digest { - self.clear(); + self.clear()?; } // The general strategy is to check whether the SWBF index set of a given @@ -368,7 +397,7 @@ impl Mempool { }; // Remove the transactions that become invalid with this block - self.retain(keep); + self.retain(keep)?; // Update the remaining transactions so their mutator set data is still valid for tx in self.tx_dictionary.values_mut() { @@ -380,22 +409,25 @@ impl Mempool { // Maintaining the mutator set data could have increased the size of the // transactions in the mempool. So we should shrink it to max size after // applying the block. - self.shrink_to_max_size(); + self.shrink_to_max_size()?; // Update the sync-label to keep track of reorganizations let current_block_digest = block.hash(); self.set_tip_digest_sync_label(current_block_digest); + + Ok(()) } /// Shrink the memory pool to the value of its `max_size` field. /// Likely computes in O(n) - fn shrink_to_max_size(&mut self) { + fn shrink_to_max_size(&mut self) -> Result<()> { // Repeately remove the least valuable transaction - while self.get_size() > self.max_total_size && self.pop_min().is_some() { + while self.get_size() > self.max_total_size && self.pop_min()?.is_some() { continue; } - self.shrink_to_fit() + self.shrink_to_fit(); + Ok(()) } /// Shrinks internal data structures as much as possible. @@ -433,6 +465,15 @@ impl Mempool { let dpq_clone = self.queue.clone(); dpq_clone.into_sorted_iter().rev() } + + pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver { + self.sender().subscribe() + } + + fn sender(&self) -> &tokio::sync::broadcast::Sender { + let (sender, _) = &self.event_channel; + sender + } } #[cfg(test)] @@ -470,7 +511,7 @@ mod tests { use super::*; #[tokio::test] - pub async fn insert_then_get_then_remove_then_get() { + pub async fn insert_then_get_then_remove_then_get() -> Result<()> { let network = Network::Alpha; let genesis_block = Block::genesis_block(network); let mut mempool = Mempool::new(ByteSize::gb(1), genesis_block.hash()); @@ -484,20 +525,20 @@ mod tests { ); let transaction_digest = Hash::hash(&transaction); assert!(!mempool.contains(transaction_digest)); - mempool.insert(&transaction); + mempool.insert(transaction.clone())?; assert!(mempool.contains(transaction_digest)); let transaction_get_option = mempool.get(transaction_digest); assert_eq!(Some(&transaction), transaction_get_option); assert!(mempool.contains(transaction_digest)); - let transaction_remove_option = mempool.remove(transaction_digest); - assert_eq!(Some(transaction), transaction_remove_option); + assert!(mempool.remove(transaction_digest)?); + assert!(!mempool.contains(transaction_digest)); + + assert!(!mempool.remove(transaction_digest)?); assert!(!mempool.contains(transaction_digest)); - let transaction_second_remove_option = mempool.remove(transaction_digest); - assert_eq!(None, transaction_second_remove_option); - assert!(!mempool.contains(transaction_digest)) + Ok(()) } // Create a mempool with n transactions. @@ -513,7 +554,7 @@ mod tests { &wallet_state, None, ); - mempool.insert(&t); + mempool.insert(t).unwrap(); } mempool } @@ -553,7 +594,7 @@ mod tests { #[traced_test] #[tokio::test] - async fn prune_stale_transactions() { + async fn prune_stale_transactions() -> Result<()> { let network = Network::Alpha; let genesis_block = Block::genesis_block(network); let mut mempool = Mempool::new(ByteSize::gb(1), genesis_block.hash()); @@ -574,7 +615,7 @@ mod tests { &wallet_state, timestamp, ); - mempool.insert(&t); + mempool.insert(t)?; } for i in 0u32..5 { @@ -585,12 +626,14 @@ mod tests { &wallet_state, None, ); - mempool.insert(&t); + mempool.insert(t)?; } assert_eq!(mempool.len(), 11); - mempool.prune_stale_transactions(); - assert_eq!(mempool.len(), 5) + mempool.prune_stale_transactions()?; + assert_eq!(mempool.len(), 5); + + Ok(()) } #[traced_test] @@ -695,7 +738,7 @@ mod tests { // Add this transaction to a mempool let mut mempool = Mempool::new(ByteSize::gb(1), block_1.hash()); - mempool.insert(&tx_by_preminer); + mempool.insert(tx_by_preminer.clone())?; // Create another transaction that's valid to be included in block 2, but isn't actually // included by the miner. This transaction is inserted into the mempool, but since it's @@ -725,7 +768,7 @@ mod tests { .await .unwrap(); - mempool.insert(&tx_by_other_original); + mempool.insert(tx_by_other_original)?; // Create next block which includes preminer's transaction let (mut block_2, _, _) = @@ -741,7 +784,7 @@ mod tests { block_1.kernel.body.mutator_set_accumulator.clone(), &block_2, ) - .await; + .await?; assert_eq!(1, mempool.len()); // Create a new block to verify that the non-mined transaction contains @@ -785,7 +828,7 @@ mod tests { previous_block.kernel.body.mutator_set_accumulator.clone(), &next_block, ) - .await; + .await?; previous_block = next_block; } @@ -810,7 +853,7 @@ mod tests { previous_block.kernel.body.mutator_set_accumulator.clone(), &block_14, ) - .await; + .await?; assert!( mempool.is_empty(), @@ -822,7 +865,7 @@ mod tests { #[traced_test] #[tokio::test] - async fn reorganization_does_not_crash_mempool() { + async fn reorganization_does_not_crash_mempool() -> Result<()> { // Verify that reorganizations do not crash the client, and other // qualities. @@ -873,7 +916,7 @@ mod tests { .await .unwrap(); - premine_receiver_global_state.mempool.insert(&unmined_tx); + premine_receiver_global_state.mempool.insert(unmined_tx)?; let mut rng = thread_rng(); @@ -957,6 +1000,8 @@ mod tests { "All retained txs in the mempool must be confirmable relative to the new block. Or the mempool must be empty." ); + + Ok(()) } #[traced_test] @@ -994,7 +1039,9 @@ mod tests { .unwrap(); assert_eq!(0, preminer_state.mempool.len()); - preminer_state.mempool.insert(&tx_by_preminer_low_fee); + preminer_state + .mempool + .insert(tx_by_preminer_low_fee.clone())?; assert_eq!(1, preminer_state.mempool.len()); assert_eq!( @@ -1021,7 +1068,9 @@ mod tests { .await .unwrap(); - preminer_state.mempool.insert(&tx_by_preminer_high_fee); + preminer_state + .mempool + .insert(tx_by_preminer_high_fee.clone())?; assert_eq!(1, preminer_state.mempool.len()); assert_eq!( &tx_by_preminer_high_fee, @@ -1047,7 +1096,7 @@ mod tests { .await .unwrap(); - preminer_state.mempool.insert(&tx_by_preminer_medium_fee); + preminer_state.mempool.insert(tx_by_preminer_medium_fee)?; assert_eq!(1, preminer_state.mempool.len()); assert_eq!( &tx_by_preminer_high_fee, diff --git a/src/models/state/mod.rs b/src/models/state/mod.rs index 52d78ab7..28654a1f 100644 --- a/src/models/state/mod.rs +++ b/src/models/state/mod.rs @@ -1332,7 +1332,7 @@ impl GlobalState { myself .mempool .update_with_block(previous_ms_accumulator, &new_block) - .await; + .await?; myself.chain.light_state_mut().set_block(new_block); diff --git a/src/models/state/wallet/rusty_wallet_database.rs b/src/models/state/wallet/rusty_wallet_database.rs index 683f33a5..73a0e588 100644 --- a/src/models/state/wallet/rusty_wallet_database.rs +++ b/src/models/state/wallet/rusty_wallet_database.rs @@ -18,6 +18,9 @@ pub struct RustyWalletDatabase { // list of utxos we have already received in a block monitored_utxos: DbtVec, + // list of utxos presently in the mempool + // monitored_mempool_utxos: DbtVec, + // list of off-chain utxos we are expecting to receive in a future block expected_utxos: DbtVec, diff --git a/src/models/state/wallet/wallet_state.rs b/src/models/state/wallet/wallet_state.rs index 9865a5f2..06aa5fea 100644 --- a/src/models/state/wallet/wallet_state.rs +++ b/src/models/state/wallet/wallet_state.rs @@ -6,6 +6,7 @@ use std::path::PathBuf; use anyhow::bail; use anyhow::Result; use itertools::Itertools; +use num_traits::CheckedSub; use num_traits::Zero; use serde_derive::Deserialize; use serde_derive::Serialize; @@ -37,6 +38,7 @@ use crate::models::blockchain::type_scripts::native_currency::NativeCurrency; use crate::models::blockchain::type_scripts::neptune_coins::NeptuneCoins; use crate::models::consensus::tasm::program::ConsensusProgram; use crate::models::consensus::timestamp::Timestamp; +use crate::models::state::mempool::MempoolEvent; use crate::models::state::wallet::monitored_utxo::MonitoredUtxo; use crate::prelude::twenty_first; use crate::util_types::mutator_set::addition_record::AdditionRecord; @@ -63,8 +65,10 @@ pub struct WalletState { pub wallet_db: RustyWalletDatabase, pub wallet_secret: WalletSecret, pub number_of_mps_per_utxo: usize, - wallet_directory_path: PathBuf, + + mempool_spent_utxos: HashMap>, + mempool_unspent_utxos: HashMap>, } /// Contains the cryptographic (non-public) data that is needed to recover the mutator set @@ -200,6 +204,8 @@ impl WalletState { wallet_secret, number_of_mps_per_utxo: cli_args.number_of_mps_per_utxo, wallet_directory_path: data_dir.wallet_directory_path(), + mempool_spent_utxos: Default::default(), + mempool_unspent_utxos: Default::default(), }; // Wallet state has to be initialized with the genesis block, otherwise the outputs @@ -237,6 +243,78 @@ impl WalletState { wallet_state } + pub async fn handle_mempool_event(&mut self, event: MempoolEvent) -> Result<()> { + match event { + MempoolEvent::AddTx(tx) => { + debug!("handling mempool AddTx event."); + + let spent_utxos = self.scan_for_spent_utxos(&tx).await; + + let announced_utxos = self + .scan_for_announced_utxos(&tx) + .chain(self.scan_for_expected_utxos(&tx).await) + .collect_vec(); + + let tx_hash = Hash::hash(&tx); + self.mempool_spent_utxos.insert(tx_hash, spent_utxos); + self.mempool_unspent_utxos.insert(tx_hash, announced_utxos); + } + MempoolEvent::RemoveTx(tx) => { + debug!("handling mempool RemoveTx event."); + let tx_hash = Hash::hash(&tx); + self.mempool_spent_utxos.remove(&tx_hash); + self.mempool_unspent_utxos.remove(&tx_hash); + } + } + Ok(()) + } + + pub fn mempool_spent_utxos_iter(&self) -> impl Iterator { + self.mempool_spent_utxos + .values() + .flatten() + .map(|(utxo, ..)| utxo) + } + + pub fn mempool_unspent_utxos_iter(&self) -> impl Iterator { + self.mempool_unspent_utxos + .values() + .flatten() + .map(|a| &a.utxo) + } + + pub async fn confirmed_balance( + &self, + tip_digest: Digest, + timestamp: Timestamp, + ) -> NeptuneCoins { + let wallet_status = self.get_wallet_status_from_lock(tip_digest).await; + + wallet_status.synced_unspent_available_amount(timestamp) + } + + pub async fn unconfirmed_balance( + &self, + tip_digest: Digest, + timestamp: Timestamp, + ) -> NeptuneCoins { + self.confirmed_balance(tip_digest, timestamp) + .await + .checked_sub( + &self + .mempool_spent_utxos_iter() + .map(|u| u.get_native_currency_amount()) + .sum(), + ) + .unwrap() + .safe_add( + self.mempool_unspent_utxos_iter() + .map(|u| u.get_native_currency_amount()) + .sum(), + ) + .unwrap() + } + // note: does not verify we do not have any dups. pub(crate) async fn add_expected_utxo(&mut self, expected_utxo: ExpectedUtxo) { self.wallet_db @@ -1288,6 +1366,111 @@ mod tests { } } + mod wallet_balance { + use generation_address::GenerationReceivingAddress; + + use super::*; + use crate::models::blockchain::transaction::UtxoNotifyMethod; + use crate::models::state::wallet::address::ReceivingAddress; + use crate::tests::shared::mine_block_to_wallet; + + #[traced_test] + #[tokio::test] + async fn confirmed_and_unconfirmed_balance() -> Result<()> { + let mut rng = thread_rng(); + let network = Network::RegTest; + let mut global_state_lock = + mock_genesis_global_state(network, 0, WalletSecret::new_random()).await; + let _wallet_task_jh = crate::spawn_wallet_task(global_state_lock.clone()).await?; + let change_key = global_state_lock + .lock_guard_mut() + .await + .wallet_state + .next_unused_spending_key(KeyType::Generation); + let coinbase_amt = NeptuneCoins::new(100); + let send_amt = NeptuneCoins::new(5); + + let tip_digest = mine_block_to_wallet(&mut global_state_lock).await?.hash(); + + let tx = { + let gs = global_state_lock.lock_guard().await; + assert_eq!( + gs.wallet_state + .confirmed_balance(tip_digest, Timestamp::now()) + .await, + coinbase_amt + ); + assert_eq!( + gs.wallet_state + .unconfirmed_balance(tip_digest, Timestamp::now()) + .await, + coinbase_amt + ); + + // --- Setup. generate an output that our wallet cannot claim. --- + let outputs = vec![( + ReceivingAddress::from(GenerationReceivingAddress::derive_from_seed(rng.gen())), + send_amt, + )]; + + let mut tx_outputs = gs.generate_tx_outputs(outputs, UtxoNotifyMethod::OnChain)?; + + gs.create_transaction( + &mut tx_outputs, + change_key, + UtxoNotifyMethod::OnChain, + NeptuneCoins::zero(), + Timestamp::now(), + ) + .await? + }; + + global_state_lock + .lock_guard_mut() + .await + .mempool + .insert(tx)?; + + // we must yield so the wallet task can process the mempool events + tokio::task::yield_now().await; + + { + let gs = global_state_lock.lock_guard().await; + assert_eq!( + gs.wallet_state + .confirmed_balance(tip_digest, Timestamp::now()) + .await, + coinbase_amt + ); + debug!("calculated confirmed balance"); + assert_eq!( + gs.wallet_state + .unconfirmed_balance(tip_digest, Timestamp::now()) + .await, + coinbase_amt.checked_sub(&send_amt).unwrap() + ); + debug!("calculated unconfirmed balance"); + } + + global_state_lock.lock_guard_mut().await.mempool.clear()?; + + // we must yield so the wallet task can process the mempool events + tokio::task::yield_now().await; + + assert_eq!( + global_state_lock + .lock_guard() + .await + .wallet_state + .unconfirmed_balance(tip_digest, Timestamp::now()) + .await, + coinbase_amt + ); + + Ok(()) + } + } + mod expected_utxos { use crate::models::blockchain::transaction::utxo::LockScript; use crate::tests::shared::make_mock_transaction; diff --git a/src/peer_loop.rs b/src/peer_loop.rs index a7a4a0d3..11f5616b 100644 --- a/src/peer_loop.rs +++ b/src/peer_loop.rs @@ -2491,7 +2491,7 @@ mod peer_loop_tests { .lock_guard_mut() .await .mempool - .insert(&transaction_1); + .insert(transaction_1)?; assert!( !state_lock.lock_guard().await.mempool.is_empty(), "Mempool must be non-empty after insertion" diff --git a/src/tests/shared.rs b/src/tests/shared.rs index 33b6533b..95a343e0 100644 --- a/src/tests/shared.rs +++ b/src/tests/shared.rs @@ -989,3 +989,27 @@ pub async fn mock_genesis_archival_state( (archival_state, peer_db, data_dir) } + +// this will create and store the next block including any transactions +// presently in the mempool. The coinbase will go to our own wallet. +// +// the stored block does NOT have valid proof-of-work. +pub async fn mine_block_to_wallet(global_state_lock: &mut GlobalStateLock) -> Result { + let state = global_state_lock.lock_guard().await; + let tip_block = state.chain.light_state(); + + let timestamp = Timestamp::now(); + let (transaction, coinbase_expected_utxo) = + crate::mine_loop::create_block_transaction(tip_block, &state, timestamp); + + let (header, body) = + crate::mine_loop::make_block_template(tip_block, transaction, timestamp, None); + let block = Block::new(header, body, Block::mk_std_block_type(None)); + drop(state); + + global_state_lock + .store_coinbase_block(block.clone(), coinbase_expected_utxo) + .await?; + + Ok(block) +} From da500843b182ea65c2afa8a42a5a49b25bc323ae Mon Sep 17 00:00:00 2001 From: danda Date: Wed, 25 Sep 2024 20:55:29 -0700 Subject: [PATCH 2/3] refactor: wallet updates now atomic with mempool removes the mempool broadcast channel and wallet listener task. Instead all mempool mutations go through GlobalState methods which inform wallet of the changes. This makes changes atomic over mempool+wallet so they are always in sync. Changes: * remove Mempool::event_channel * Mempool &mut methods only callable by super * Mempool &mut methods return MempoolEvent(s) * add MempoolEvent::UpdateTxMutatorSet. (unused) * add GlobalState methods: mempool_clear, mempool_insert, mempool_prune_stale_transactions * remove spawn_wallet_task from lib.rs * add/improve doc-comments --- src/lib.rs | 55 ++------- src/locks/tokio/atomic_rw.rs | 3 + src/main_loop.rs | 10 +- src/mine_loop.rs | 4 +- src/models/state/mempool.rs | 114 ++++++++++-------- src/models/state/mod.rs | 18 +++ .../state/wallet/rusty_wallet_database.rs | 3 - src/models/state/wallet/wallet_state.rs | 73 ++++++++--- src/peer_loop.rs | 4 +- 9 files changed, 164 insertions(+), 120 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4f2eaf5c..64f43b93 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,6 +90,16 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { DataDirectory::create_dir_if_not_exists(&data_dir.root_dir_path()).await?; info!("Data directory is {}", data_dir); + // Get wallet object, create various wallet secret files + let wallet_dir = data_dir.wallet_directory_path(); + DataDirectory::create_dir_if_not_exists(&wallet_dir).await?; + let (wallet_secret, _) = + WalletSecret::read_from_file_or_create(&data_dir.wallet_directory_path())?; + info!("Now getting wallet state. This may take a while if the database needs pruning."); + let wallet_state = + WalletState::new_from_wallet_secret(&data_dir, wallet_secret, &cli_args).await; + info!("Got wallet state."); + // Connect to or create databases for block index, peers, mutator set, block sync let block_index_db = ArchivalState::initialize_block_index_database(&data_dir).await?; info!("Got block index database"); @@ -101,7 +111,7 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { info!("Got archival mutator set"); let archival_state = ArchivalState::new( - data_dir.clone(), + data_dir, block_index_db, archival_mutator_set, cli_args.network, @@ -139,17 +149,6 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { }; let blockchain_state = BlockchainState::Archival(blockchain_archival_state); let mempool = Mempool::new(cli_args.max_mempool_size, latest_block.hash()); - - // Get wallet object, create various wallet secret files - let wallet_dir = data_dir.wallet_directory_path(); - DataDirectory::create_dir_if_not_exists(&wallet_dir).await?; - let (wallet_secret, _) = - WalletSecret::read_from_file_or_create(&data_dir.wallet_directory_path())?; - info!("Now getting wallet state. This may take a while if the database needs pruning."); - let wallet_state = - WalletState::new_from_wallet_secret(&data_dir, wallet_secret, &cli_args).await; - info!("Got wallet state."); - let mut global_state_lock = GlobalStateLock::new( wallet_state, blockchain_state, @@ -177,11 +176,8 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { .await?; info!("UTXO restoration check complete"); - let mut task_join_handles = vec![]; - - task_join_handles.push(spawn_wallet_task(global_state_lock.clone()).await?); - // Connect to peers, and provide each peer task with a thread-safe copy of the state + let mut task_join_handles = vec![]; for peer_address in global_state_lock.cli().peers.clone() { let peer_state_var = global_state_lock.clone(); // bump arc refcount let main_to_peer_broadcast_rx_clone: broadcast::Receiver = @@ -288,33 +284,6 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { .await } -pub(crate) async fn spawn_wallet_task( - mut global_state_lock: GlobalStateLock, -) -> Result> { - let mut mempool_subscriber = global_state_lock.lock_guard().await.mempool.subscribe(); - - let wallet_join_handle = tokio::task::Builder::new() - .name("wallet_mempool_listener") - .spawn(async move { - let mut events: std::collections::VecDeque<_> = Default::default(); - - while let Ok(event) = mempool_subscriber.recv().await { - events.push_back(event); - - if let Ok(mut gs) = global_state_lock.try_lock_guard_mut() { - while let Some(e) = events.pop_front() { - gs.wallet_state - .handle_mempool_event(e) - .await - .expect("Wallet should handle mempool event without error"); - } - } - } - })?; - - Ok(wallet_join_handle) -} - /// Time a fn call. Duration is returned as a float in seconds. pub fn time_fn_call(f: impl FnOnce() -> O) -> (O, f64) { let start = Instant::now(); diff --git a/src/locks/tokio/atomic_rw.rs b/src/locks/tokio/atomic_rw.rs index b372119a..9d6559d0 100644 --- a/src/locks/tokio/atomic_rw.rs +++ b/src/locks/tokio/atomic_rw.rs @@ -241,6 +241,9 @@ impl AtomicRw { AtomicRwWriteGuard::new(guard, &self.lock_callback_info) } + /// Attempt to acquire write lock immediately. + /// + /// If the lock cannot be acquired without waiting, an error is returned. pub fn try_lock_guard_mut(&mut self) -> Result, TryLockError> { self.try_acquire_write_cb(); let guard = self.inner.try_write()?; diff --git a/src/main_loop.rs b/src/main_loop.rs index 8e64760c..0f50b6a7 100644 --- a/src/main_loop.rs +++ b/src/main_loop.rs @@ -528,8 +528,8 @@ impl MainLoopHandler { // Insert into mempool global_state_mut - .mempool - .insert(pt2m_transaction.transaction.to_owned())?; + .mempool_insert(pt2m_transaction.transaction.to_owned()) + .await?; // send notification to peers let transaction_notification: TransactionNotification = @@ -970,7 +970,7 @@ impl MainLoopHandler { // Handle mempool cleanup, i.e. removing stale/too old txs from mempool _ = &mut mempool_cleanup_timer => { debug!("Timer: mempool-cleaner job"); - self.global_state_lock.lock_guard_mut().await.mempool.prune_stale_transactions()?; + self.global_state_lock.lock_guard_mut().await.mempool_prune_stale_transactions().await?; // Reset the timer to run this branch again in P seconds mempool_cleanup_timer.as_mut().reset(tokio::time::Instant::now() + mempool_cleanup_timer_interval); @@ -1028,8 +1028,8 @@ impl MainLoopHandler { self.global_state_lock .lock_guard_mut() .await - .mempool - .insert(*transaction)?; + .mempool_insert(*transaction) + .await?; // do not shut down Ok(false) diff --git a/src/mine_loop.rs b/src/mine_loop.rs index 24b61694..10f8461c 100644 --- a/src/mine_loop.rs +++ b/src/mine_loop.rs @@ -603,8 +603,8 @@ mod mine_loop_tests { .await?; premine_receiver_global_state - .mempool - .insert(tx_by_preminer)?; + .mempool_insert(tx_by_preminer) + .await?; assert_eq!(1, premine_receiver_global_state.mempool.len()); // Build transaction diff --git a/src/models/state/mempool.rs b/src/models/state/mempool.rs index 5b2a8183..d676fb9f 100644 --- a/src/models/state/mempool.rs +++ b/src/models/state/mempool.rs @@ -60,10 +60,21 @@ pub const TRANSACTION_NOTIFICATION_AGE_LIMIT_IN_SECS: u64 = 60 * 60 * 24; type LookupItem<'a> = (Digest, &'a Transaction); +/// Represents a mempool state change. +/// +/// For purpose of notifying interested parties #[derive(Debug, Clone)] pub enum MempoolEvent { + /// a transaction was added to the mempool AddTx(Transaction), + + /// a transaction was removed from the mempool RemoveTx(Transaction), + + /// the mutator-set of a transaction was updated in the mempool. + /// + /// (Digest of Tx before update, Tx after mutator-set updated) + UpdateTxMutatorSet(Digest, Transaction), } #[derive(Debug, GetSize)] @@ -76,21 +87,20 @@ pub struct Mempool { /// Allows the mempool to report transactions sorted by [`FeeDensity`] in /// both descending and ascending order. - #[get_size(ignore)] // This is relatively small compared to `LookupTable` + #[get_size(ignore)] // This is relatively small compared to `tx_dictionary` queue: DoublePriorityQueue, /// Records the digest of the block that the transactions were synced to. /// Used to discover reorganizations. tip_digest: Digest, - - /// a mpmc channel for interested parties to listen to mempool events - #[get_size(ignore)] // does not impl GetSize - event_channel: ( - tokio::sync::broadcast::Sender, - tokio::sync::broadcast::Receiver, - ), } +/// note that all methods that modify state and result in a MempoolEvent +/// notification are private or pub(super). This enforces that these methods +/// can only be called from/via GlobalState. +/// +/// Mempool updates must go through GlobalState so that it can +/// forward mempool events to the wallet in atomic fashion. impl Mempool { /// instantiate a new, empty `Mempool` pub fn new(max_total_size: ByteSize, tip_digest: Digest) -> Self { @@ -102,12 +112,11 @@ impl Mempool { tx_dictionary: table, queue, tip_digest, - event_channel: tokio::sync::broadcast::channel(100), } } /// Update the block digest to which all transactions are synced. - fn set_tip_digest_sync_label(&mut self, tip_digest: Digest) { + pub(super) fn set_tip_digest_sync_label(&mut self, tip_digest: Digest) { self.tip_digest = tip_digest; } @@ -155,7 +164,14 @@ impl Mempool { /// this method accepts only fully proven transactions (or, for the time being, faith witnesses). /// The caller must also ensure that the transaction does not have a timestamp /// in the too distant future. - pub fn insert(&mut self, transaction: Transaction) -> Result { + /// + /// this method may return: + /// 2 events: RemoveTx,AddTx. tx replaces an older one with lower fee. + /// 1 event: AddTx. tx does not replace an older one. + /// 0 events: tx not added because an older matching tx has a higher fee. + pub(super) fn insert(&mut self, transaction: Transaction) -> Result> { + let mut events = vec![]; + match transaction.witness.vast.witness_type { WitnessType::RawWitness(_) => panic!("Can only insert fully proven transactions into mempool; not accepting raw witnesses."), WitnessType::Decomposition => panic!("Can only insert fully proven transactions into mempool; not accepting decompositions."), @@ -169,11 +185,13 @@ impl Mempool { if tx.fee_density() < transaction.fee_density() { // If new transaction has a higher fee density than the one previously seen // remove the old one. - self.remove(txid)?; + if let Some(e) = self.remove(txid)? { + events.push(e); + } } else { // If new transaction has a lower fee density than the one previous seen, // ignore it. Stop execution here. - return Ok(txid); + return Ok(events); } }; @@ -194,26 +212,33 @@ impl Mempool { "mempool's table and queue length must agree after shrink" ); - self.sender().send(MempoolEvent::AddTx(transaction))?; + events.push(MempoolEvent::AddTx(transaction)); - Ok(transaction_id) + Ok(events) } /// remove a transaction from the `Mempool` - pub fn remove(&mut self, transaction_id: Digest) -> Result { + pub(super) fn remove(&mut self, transaction_id: Digest) -> Result> { match self.tx_dictionary.remove(&transaction_id) { Some(tx) => { self.queue.remove(&transaction_id); debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - self.sender().send(MempoolEvent::RemoveTx(tx))?; - Ok(true) + Ok(Some(MempoolEvent::RemoveTx(tx))) } - None => Ok(false), + None => Ok(None), } } /// Delete all transactions from the mempool. - pub fn clear(&mut self) -> Result<()> { + /// + /// note that this will return a MempoolEvent for every removed Tx. + /// In the case of a full block, that could be a lot of Tx and + /// significant memory usage. Of course the mempool itself will + /// be emptied at the same time. + /// + /// If the mem usage ever becomes a problem we could accept a closure + /// to handle the events individually as each Tx is removed. + pub(super) fn clear(&mut self) -> Result> { // note: this causes event listeners to be notified of each removed tx. self.retain(|_| false) } @@ -267,15 +292,14 @@ impl Mempool { /// /// Computes in θ(lg N) #[allow(dead_code)] - pub fn pop_max(&mut self) -> Result> { + fn pop_max(&mut self) -> Result> { if let Some((transaction_digest, fee_density)) = self.queue.pop_max() { if let Some(transaction) = self.tx_dictionary.remove(&transaction_digest) { debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - self.sender() - .send(MempoolEvent::RemoveTx(transaction.clone()))?; + let event = MempoolEvent::RemoveTx(transaction); - return Ok(Some((transaction, fee_density))); + return Ok(Some((event, fee_density))); } } Ok(None) @@ -285,15 +309,14 @@ impl Mempool { /// Returns the removed value. /// /// Computes in θ(lg N) - pub fn pop_min(&mut self) -> Result> { + fn pop_min(&mut self) -> Result> { if let Some((transaction_digest, fee_density)) = self.queue.pop_min() { if let Some(transaction) = self.tx_dictionary.remove(&transaction_digest) { debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - self.sender() - .send(MempoolEvent::RemoveTx(transaction.clone()))?; + let event = MempoolEvent::RemoveTx(transaction); - return Ok(Some((transaction, fee_density))); + return Ok(Some((event, fee_density))); } } Ok(None) @@ -304,7 +327,7 @@ impl Mempool { /// Modelled after [HashMap::retain](std::collections::HashMap::retain()) /// /// Computes in O(capacity) >= O(N) - pub fn retain(&mut self, mut predicate: F) -> Result<()> + fn retain(&mut self, mut predicate: F) -> Result> where F: FnMut(LookupItem) -> bool, { @@ -317,21 +340,24 @@ impl Mempool { } } + let mut events = Vec::with_capacity(victims.len()); for t in victims { - self.remove(t)?; + if let Some(e) = self.remove(t)? { + events.push(e); + } } debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); self.shrink_to_fit(); - Ok(()) + Ok(events) } /// Remove transactions from mempool that are older than the specified /// timestamp. Prunes base on the transaction's timestamp. /// /// Computes in O(n) - pub fn prune_stale_transactions(&mut self) -> Result<()> { + pub(super) fn prune_stale_transactions(&mut self) -> Result> { let cutoff = Timestamp::now() - Timestamp::seconds(MEMPOOL_TX_THRESHOLD_AGE_IN_SECS); let keep = |(_transaction_id, transaction): LookupItem| -> bool { @@ -344,11 +370,11 @@ impl Mempool { /// Remove from the mempool all transactions that become invalid because /// of a newly received block. Also update all mutator set data for mempool /// transactions that were not removed. - pub async fn update_with_block( + pub(super) async fn update_with_block( &mut self, previous_mutator_set_accumulator: MutatorSetAccumulator, block: &Block, - ) -> Result<()> { + ) -> Result> { // If we discover a reorganization, we currently just clear the mempool, // as we don't have the ability to roll transaction removal record integrity // proofs back to previous blocks. It would be nice if we could handle a @@ -397,13 +423,14 @@ impl Mempool { }; // Remove the transactions that become invalid with this block - self.retain(keep)?; + let mut events = self.retain(keep)?; // Update the remaining transactions so their mutator set data is still valid - for tx in self.tx_dictionary.values_mut() { + for (tx_id, tx) in self.tx_dictionary.iter_mut() { *tx = tx .new_with_updated_mutator_set_records(&previous_mutator_set_accumulator, block) .expect("Updating mempool transaction must succeed"); + events.push(MempoolEvent::UpdateTxMutatorSet(*tx_id, (*tx).clone())); } // Maintaining the mutator set data could have increased the size of the @@ -415,7 +442,7 @@ impl Mempool { let current_block_digest = block.hash(); self.set_tip_digest_sync_label(current_block_digest); - Ok(()) + Ok(events) } /// Shrink the memory pool to the value of its `max_size` field. @@ -465,15 +492,6 @@ impl Mempool { let dpq_clone = self.queue.clone(); dpq_clone.into_sorted_iter().rev() } - - pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver { - self.sender().subscribe() - } - - fn sender(&self) -> &tokio::sync::broadcast::Sender { - let (sender, _) = &self.event_channel; - sender - } } #[cfg(test)] @@ -532,10 +550,10 @@ mod tests { assert_eq!(Some(&transaction), transaction_get_option); assert!(mempool.contains(transaction_digest)); - assert!(mempool.remove(transaction_digest)?); + assert!(mempool.remove(transaction_digest)?.is_some()); assert!(!mempool.contains(transaction_digest)); - assert!(!mempool.remove(transaction_digest)?); + assert!(mempool.remove(transaction_digest)?.is_none()); assert!(!mempool.contains(transaction_digest)); Ok(()) diff --git a/src/models/state/mod.rs b/src/models/state/mod.rs index 28654a1f..a3755411 100644 --- a/src/models/state/mod.rs +++ b/src/models/state/mod.rs @@ -1382,6 +1382,24 @@ impl GlobalState { pub fn cli(&self) -> &cli_args::Args { &self.cli } + + /// clears all Tx from mempool and notifies wallet of changes. + pub async fn mempool_clear(&mut self) -> Result<()> { + let events = self.mempool.clear()?; + self.wallet_state.handle_mempool_events(events).await + } + + /// adds Tx to mempool and notifies wallet of change. + pub async fn mempool_insert(&mut self, transaction: Transaction) -> Result<()> { + let events = self.mempool.insert(transaction)?; + self.wallet_state.handle_mempool_events(events).await + } + + /// prunes stale tx in mempool and notifies wallet of changes. + pub async fn mempool_prune_stale_transactions(&mut self) -> Result<()> { + let events = self.mempool.prune_stale_transactions()?; + self.wallet_state.handle_mempool_events(events).await + } } #[cfg(test)] diff --git a/src/models/state/wallet/rusty_wallet_database.rs b/src/models/state/wallet/rusty_wallet_database.rs index 73a0e588..683f33a5 100644 --- a/src/models/state/wallet/rusty_wallet_database.rs +++ b/src/models/state/wallet/rusty_wallet_database.rs @@ -18,9 +18,6 @@ pub struct RustyWalletDatabase { // list of utxos we have already received in a block monitored_utxos: DbtVec, - // list of utxos presently in the mempool - // monitored_mempool_utxos: DbtVec, - // list of off-chain utxos we are expecting to receive in a future block expected_utxos: DbtVec, diff --git a/src/models/state/wallet/wallet_state.rs b/src/models/state/wallet/wallet_state.rs index 06aa5fea..1a020f56 100644 --- a/src/models/state/wallet/wallet_state.rs +++ b/src/models/state/wallet/wallet_state.rs @@ -18,6 +18,7 @@ use tokio::io::BufWriter; use tracing::debug; use tracing::error; use tracing::info; +use tracing::trace; use tracing::warn; use twenty_first::math::bfield_codec::BFieldCodec; use twenty_first::math::digest::Digest; @@ -67,6 +68,8 @@ pub struct WalletState { pub number_of_mps_per_utxo: usize, wallet_directory_path: PathBuf, + /// these two fields are for monitoring wallet-affecting utxos in the mempool. + /// key is Tx hash. for removing watched utxos when a tx is removed from mempool. mempool_spent_utxos: HashMap>, mempool_unspent_utxos: HashMap>, } @@ -243,10 +246,30 @@ impl WalletState { wallet_state } - pub async fn handle_mempool_event(&mut self, event: MempoolEvent) -> Result<()> { + /// handles a list of mempool events + pub(in crate::models::state) async fn handle_mempool_events( + &mut self, + events: impl IntoIterator, + ) -> Result<()> { + for event in events { + self.handle_mempool_event(event).await? + } + Ok(()) + } + + /// handles a single mempool event. + /// + /// note: the wallet watches the mempool in order to keep track of + /// unconfirmed utxos sent from or to the wallet. This enables + /// calculation of unconfirmed balance. It also lays foundation for + /// spending unconfirmed utxos. (issue #189) + pub(in crate::models::state) async fn handle_mempool_event( + &mut self, + event: MempoolEvent, + ) -> Result<()> { match event { MempoolEvent::AddTx(tx) => { - debug!("handling mempool AddTx event."); + trace!("handling mempool AddTx event."); let spent_utxos = self.scan_for_spent_utxos(&tx).await; @@ -260,11 +283,14 @@ impl WalletState { self.mempool_unspent_utxos.insert(tx_hash, announced_utxos); } MempoolEvent::RemoveTx(tx) => { - debug!("handling mempool RemoveTx event."); + trace!("handling mempool RemoveTx event."); let tx_hash = Hash::hash(&tx); self.mempool_spent_utxos.remove(&tx_hash); self.mempool_unspent_utxos.remove(&tx_hash); } + MempoolEvent::UpdateTxMutatorSet(_tx_hash_pre_update, _tx_post_update) => { + // Utxos are not affected by MutatorSet update, so this is a no-op. + } } Ok(()) } @@ -1374,6 +1400,15 @@ mod tests { use crate::models::state::wallet::address::ReceivingAddress; use crate::tests::shared::mine_block_to_wallet; + /// basic test for confirmed and unconfirmed balance. + /// + /// This test: + /// 1. mines a block to self worth 100 + /// 2. sends 5 to a 3rd party, and 95 change back to self. + /// 3. verifies that confirmed balance is 100 + /// 4. verifies that unconfirmed balance is 95 + /// 5. empties the mempool (removing our unconfirmed tx) + /// 6. verifies that unconfirmed balance is 100 #[traced_test] #[tokio::test] async fn confirmed_and_unconfirmed_balance() -> Result<()> { @@ -1381,18 +1416,20 @@ mod tests { let network = Network::RegTest; let mut global_state_lock = mock_genesis_global_state(network, 0, WalletSecret::new_random()).await; - let _wallet_task_jh = crate::spawn_wallet_task(global_state_lock.clone()).await?; let change_key = global_state_lock .lock_guard_mut() .await .wallet_state .next_unused_spending_key(KeyType::Generation); + let coinbase_amt = NeptuneCoins::new(100); let send_amt = NeptuneCoins::new(5); + // mine a block to our wallet. we should have 100 coins after. let tip_digest = mine_block_to_wallet(&mut global_state_lock).await?.hash(); let tx = { + // verify that confirmed and unconfirmed balance are both 100. let gs = global_state_lock.lock_guard().await; assert_eq!( gs.wallet_state @@ -1407,14 +1444,14 @@ mod tests { coinbase_amt ); - // --- Setup. generate an output that our wallet cannot claim. --- + // generate an output that our wallet cannot claim. let outputs = vec![( ReceivingAddress::from(GenerationReceivingAddress::derive_from_seed(rng.gen())), send_amt, )]; + // create tx, with 5 coins going to 3rd party and 95 coins change back to self. let mut tx_outputs = gs.generate_tx_outputs(outputs, UtxoNotifyMethod::OnChain)?; - gs.create_transaction( &mut tx_outputs, change_key, @@ -1425,16 +1462,16 @@ mod tests { .await? }; + // add the tx to the mempool. + // note that the wallet should be notified of these changes. global_state_lock .lock_guard_mut() .await - .mempool - .insert(tx)?; - - // we must yield so the wallet task can process the mempool events - tokio::task::yield_now().await; + .mempool_insert(tx) + .await?; { + // verify that confirmed balance is still 100 let gs = global_state_lock.lock_guard().await; assert_eq!( gs.wallet_state @@ -1442,21 +1479,23 @@ mod tests { .await, coinbase_amt ); - debug!("calculated confirmed balance"); + // verify that unconfirmed balance is now 95. assert_eq!( gs.wallet_state .unconfirmed_balance(tip_digest, Timestamp::now()) .await, coinbase_amt.checked_sub(&send_amt).unwrap() ); - debug!("calculated unconfirmed balance"); } - global_state_lock.lock_guard_mut().await.mempool.clear()?; - - // we must yield so the wallet task can process the mempool events - tokio::task::yield_now().await; + // clear the mempool, which drops our unconfirmed tx. + global_state_lock + .lock_guard_mut() + .await + .mempool_clear() + .await?; + // verify that wallet's unconfirmed balance is 100 again. assert_eq!( global_state_lock .lock_guard() diff --git a/src/peer_loop.rs b/src/peer_loop.rs index 11f5616b..10d265e2 100644 --- a/src/peer_loop.rs +++ b/src/peer_loop.rs @@ -2490,8 +2490,8 @@ mod peer_loop_tests { state_lock .lock_guard_mut() .await - .mempool - .insert(transaction_1)?; + .mempool_insert(transaction_1) + .await?; assert!( !state_lock.lock_guard().await.mempool.is_empty(), "Mempool must be non-empty after insertion" From 9ec1de6cb968367f8ca8622c282a60edacb80cc8 Mon Sep 17 00:00:00 2001 From: danda Date: Thu, 3 Oct 2024 23:24:18 -0700 Subject: [PATCH 3/3] fix: prevent spending same input twice. closes #189. input selection now ignores spent inputs from unconfirmed tx in the mempool. Also fixes an input selection bug where the balance check considers available funds (not timelocked) but the input utxo selection does not. Adds a test to verify that same input can no longer be spent twice. --- src/models/state/wallet/wallet_state.rs | 92 ++++++++++++++++++++++-- src/models/state/wallet/wallet_status.rs | 14 ++-- 2 files changed, 97 insertions(+), 9 deletions(-) diff --git a/src/models/state/wallet/wallet_state.rs b/src/models/state/wallet/wallet_state.rs index 1a020f56..fd496a4d 100644 --- a/src/models/state/wallet/wallet_state.rs +++ b/src/models/state/wallet/wallet_state.rs @@ -1009,14 +1009,24 @@ impl WalletState { // membership proofs. let wallet_status = self.get_wallet_status_from_lock(tip_digest).await; + // filter out any utxos that are already spent in the mempool. + let unspent_utxos = wallet_status + .synced_unspent_available_iter(timestamp) + .filter(|(wse, _)| !self.mempool_spent_utxos_iter().any(|u| *u == wse.utxo)) + .collect_vec(); + let unspent_available_amount = unspent_utxos + .iter() + .map(|(wse, _)| wse.utxo.get_native_currency_amount()) + .sum::(); + // First check that we have enough. Otherwise return an error. - if wallet_status.synced_unspent_available_amount(timestamp) < requested_amount { + if unspent_available_amount < requested_amount { bail!( "Insufficient synced amount to create transaction. Requested: {}, Total synced UTXOs: {}. Total synced amount: {}. Synced unspent available amount: {}. Synced unspent timelocked amount: {}. Total unsynced UTXOs: {}. Unsynced unspent amount: {}. Block is: {}", requested_amount, wallet_status.synced_unspent.len(), wallet_status.synced_unspent.iter().map(|(wse, _msmp)| wse.utxo.get_native_currency_amount()).sum::(), - wallet_status.synced_unspent_available_amount(timestamp), + unspent_available_amount, wallet_status.synced_unspent_timelocked_amount(timestamp), wallet_status.unsynced_unspent.len(), wallet_status.unsynced_unspent_amount(), @@ -1027,8 +1037,7 @@ impl WalletState { let mut allocated_amount = NeptuneCoins::zero(); while allocated_amount < requested_amount { - let (wallet_status_element, membership_proof) = - wallet_status.synced_unspent[ret.len()].clone(); + let (wallet_status_element, membership_proof) = unspent_utxos[ret.len()].clone(); // find spending key for this utxo. let spending_key = match self.find_spending_key_for_utxo(&wallet_status_element.utxo) { @@ -1392,7 +1401,7 @@ mod tests { } } - mod wallet_balance { + mod unconfirmed_tx { use generation_address::GenerationReceivingAddress; use super::*; @@ -1508,6 +1517,79 @@ mod tests { Ok(()) } + + // this test attempts to spend the same input twice in the same block. + // this results in an "insufficient funds" error in this case because + // the input selection code ignores the spent utxo on the 2nd attempt + // and no other input utxos are available to fund the tx. + #[traced_test] + #[tokio::test] + async fn attempt_spend_input_in_mempool() -> Result<()> { + let mut rng = thread_rng(); + let network = Network::RegTest; + let mut global_state_lock = + mock_genesis_global_state(network, 0, WalletSecret::new_random()).await; + let change_key = global_state_lock + .lock_guard_mut() + .await + .wallet_state + .next_unused_spending_key(KeyType::Generation); + + let send_amt = NeptuneCoins::new(5); + + // mine a block to our wallet. we should have 100 coins after. + mine_block_to_wallet(&mut global_state_lock).await?.hash(); + + // generate an output that our wallet cannot claim. + let outputs = vec![( + ReceivingAddress::from(GenerationReceivingAddress::derive_from_seed(rng.gen())), + send_amt, + )]; + + let tx = { + let gs = global_state_lock.lock_guard().await; + + // create tx, with 5 coins going to 3rd party and 95 coins change back to self. + let mut tx_outputs = + gs.generate_tx_outputs(outputs.clone(), UtxoNotifyMethod::OnChain)?; + gs.create_transaction( + &mut tx_outputs, + change_key, + UtxoNotifyMethod::OnChain, + NeptuneCoins::zero(), + Timestamp::now(), + ) + .await? + }; + + // add the tx to the mempool. + // note that the wallet should be notified of these changes. + global_state_lock + .lock_guard_mut() + .await + .mempool_insert(tx) + .await?; + + let gs = global_state_lock.lock_guard().await; + let mut tx_outputs = gs.generate_tx_outputs(outputs, UtxoNotifyMethod::OnChain)?; + let result = gs + .create_transaction( + &mut tx_outputs, + change_key, + UtxoNotifyMethod::OnChain, + NeptuneCoins::zero(), + Timestamp::now(), + ) + .await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Insufficient synced amount to create transaction")); + + Ok(()) + } } mod expected_utxos { diff --git a/src/models/state/wallet/wallet_status.rs b/src/models/state/wallet/wallet_status.rs index 4c09bdb9..0ab9b288 100644 --- a/src/models/state/wallet/wallet_status.rs +++ b/src/models/state/wallet/wallet_status.rs @@ -40,12 +40,18 @@ pub struct WalletStatus { } impl WalletStatus { - pub fn synced_unspent_available_amount(&self, timestamp: Timestamp) -> NeptuneCoins { + pub fn synced_unspent_available_iter( + &self, + timestamp: Timestamp, + ) -> impl Iterator { self.synced_unspent .iter() - .map(|(wse, _msmp)| &wse.utxo) - .filter(|utxo| utxo.can_spend_at(timestamp)) - .map(|utxo| utxo.get_native_currency_amount()) + .filter(move |(wse, _)| wse.utxo.can_spend_at(timestamp)) + } + + pub fn synced_unspent_available_amount(&self, timestamp: Timestamp) -> NeptuneCoins { + self.synced_unspent_available_iter(timestamp) + .map(|(wse, _)| wse.utxo.get_native_currency_amount()) .sum::() } pub fn synced_unspent_timelocked_amount(&self, timestamp: Timestamp) -> NeptuneCoins {