diff --git a/src/wallet/coin_store.rs b/src/wallet/coin_store.rs index 8ba382f3..994698db 100644 --- a/src/wallet/coin_store.rs +++ b/src/wallet/coin_store.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, future::Future}; -use chia_protocol::CoinState; +use chia_protocol::{Coin, CoinState}; use parking_lot::Mutex; /// Keeps track of the state of coins in a wallet. @@ -8,6 +8,12 @@ pub trait CoinStore { /// Applies coin state updates. fn update_coin_state(&self, coin_states: Vec) -> impl Future + Send; + /// Gets a list of unspent coins. + fn unspent_coins(&self) -> impl Future> + Send; + + /// Gets the current state of a coin. + fn coin_state(&self, coin_id: [u8; 32]) -> impl Future> + Send; + /// Gets coin states for a given puzzle hash. fn is_used(&self, puzzle_hash: [u8; 32]) -> impl Future + Send; } @@ -30,22 +36,41 @@ impl CoinStore for MemoryCoinStore { async fn update_coin_state(&self, coin_states: Vec) { for coin_state in coin_states { let puzzle_hash = &coin_state.coin.puzzle_hash; + let puzzle_hash = <&[u8; 32]>::from(puzzle_hash); + let mut db = self.coin_states.lock(); - if let Some(items) = self - .coin_states - .lock() - .get_mut(<&[u8; 32]>::from(puzzle_hash)) - { + if let Some(items) = db.get_mut(puzzle_hash) { match items.iter_mut().find(|item| item.coin == coin_state.coin) { Some(value) => { *value = coin_state; } None => items.push(coin_state), } + } else { + db.insert(*puzzle_hash, vec![coin_state]); } } } + async fn unspent_coins(&self) -> Vec { + self.coin_states + .lock() + .values() + .flatten() + .filter(|coin_state| coin_state.spent_height.is_none()) + .map(|coin_state| coin_state.coin.clone()) + .collect() + } + + async fn coin_state(&self, coin_id: [u8; 32]) -> Option { + self.coin_states + .lock() + .values() + .flatten() + .find(|coin_state| coin_state.coin.coin_id() == coin_id) + .cloned() + } + async fn is_used(&self, puzzle_hash: [u8; 32]) -> bool { self.coin_states .lock() diff --git a/src/wallet/sync.rs b/src/wallet/sync.rs index 4d809d84..a546113c 100644 --- a/src/wallet/sync.rs +++ b/src/wallet/sync.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use chia_client::{Error, Peer, PeerEvent}; +use tokio::sync::mpsc; use crate::{CoinStore, DerivationStore}; @@ -25,6 +26,7 @@ pub async fn incremental_sync( derivation_store: Arc, coin_store: Arc, config: SyncConfig, + synced_sender: mpsc::Sender<()>, ) -> Result<(), Error<()>> { let mut event_receiver = peer.receiver().resubscribe(); @@ -47,6 +49,8 @@ pub async fn incremental_sync( ) .await?; + synced_sender.send(()).await.ok(); + while let Ok(event) = event_receiver.recv().await { if let PeerEvent::CoinStateUpdate(update) = event { coin_store.update_coin_state(update.items).await; @@ -57,13 +61,16 @@ pub async fn incremental_sync( &config, ) .await?; + + synced_sender.send(()).await.ok(); } } Ok(()) } -async fn subscribe( +/// Subscribe to another set of puzzle hashes. +pub async fn subscribe( peer: &Peer, coin_store: &impl CoinStore, puzzle_hashes: Vec<[u8; 32]>, @@ -75,7 +82,8 @@ async fn subscribe( Ok(()) } -async fn derive_more( +/// Create more derivations for a wallet. +pub async fn derive_more( peer: &Peer, derivation_store: &impl DerivationStore, coin_store: &impl CoinStore, @@ -93,21 +101,26 @@ async fn derive_more( subscribe(peer, coin_store, puzzle_hashes).await } -async fn unused_index( +/// Gets the last unused derivation index for a wallet. +pub async fn unused_index( derivation_store: &impl DerivationStore, coin_store: &impl CoinStore, ) -> Option { let derivations = derivation_store.derivations().await; + let mut unused_index = None; for index in (0..derivations).rev() { let puzzle_hash = derivation_store.puzzle_hash(index).await.unwrap(); if !coin_store.is_used(puzzle_hash).await { - return Some(index); + unused_index = Some(index); + } else { + break; } } - None + unused_index } -async fn sync_to_unused_index( +/// Syncs a wallet such that there are enough unused derivations. +pub async fn sync_to_unused_index( peer: &Peer, derivation_store: &impl DerivationStore, coin_store: &impl CoinStore, @@ -132,7 +145,7 @@ async fn sync_to_unused_index( if let Some(unused_index) = result { // Calculate the extra unused derivations after that index. - let extra_indices = derivations - 1 - unused_index; + let extra_indices = derivations - unused_index; // Make sure at least `gap` indices are available if needed. if extra_indices < config.minimum_unused_derivations {