From c43b6a188403f101f9389a7c30ef96a7d566bfed Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Tue, 27 Feb 2024 20:02:08 -0700 Subject: [PATCH] zcash_client_backend: Make `AccountId` an associated type of `WalletRead` This PR was extracted from https://github.com/zcash/librustzcash/pull/1175 in order to make the changes to `zcash_client_backend` usable without the additional generalizations to `zcash_client_sqlite` made by that PR. Co-authored-by: Andrew Arnott --- zcash_client_backend/CHANGELOG.md | 35 +++- zcash_client_backend/src/data_api.rs | 165 ++++++++++-------- zcash_client_backend/src/data_api/chain.rs | 6 +- zcash_client_backend/src/data_api/wallet.rs | 108 +++++++++--- .../src/data_api/wallet/input_selection.rs | 8 +- zcash_client_backend/src/decrypt.rs | 17 +- zcash_client_backend/src/scanning.rs | 57 +++--- zcash_client_backend/src/wallet.rs | 38 ++-- zcash_client_sqlite/src/lib.rs | 74 ++++---- zcash_client_sqlite/src/testing.rs | 5 +- zcash_client_sqlite/src/wallet.rs | 89 ++++++---- zcash_client_sqlite/src/wallet/init.rs | 2 +- zcash_client_sqlite/src/wallet/sapling.rs | 8 +- 13 files changed, 393 insertions(+), 219 deletions(-) diff --git a/zcash_client_backend/CHANGELOG.md b/zcash_client_backend/CHANGELOG.md index 8dc4c1e333..98a7f02610 100644 --- a/zcash_client_backend/CHANGELOG.md +++ b/zcash_client_backend/CHANGELOG.md @@ -56,6 +56,9 @@ and this library adheres to Rust's notion of }` - `WalletSummary::next_sapling_subtree_index` - `wallet`: + - `HDSeedAccount` + - `ImportedAccount` + - `Account` - `propose_standard_transfer_to_address` - `create_proposed_transactions` - `input_selection`: @@ -96,6 +99,34 @@ and this library adheres to Rust's notion of - `parse::Param::name` ### Changed +- Several structs and functions had `AccountId` added as a generic type + parameter in order to decouple the concepts of an account identifier + and a ZIP-32 account index. Many APIs that previously referenced `zcash_primitives::zip32::AccountId` now reference the generic type. Impacted types and functions are: + - `zcash_client_backend::data_api::wallet::InputSelector::propose_transaction` + - `zcash_client_backend::data_api::wallet::propose_transfer` + - `zcash_client_backend::data_api::wallet::propose_standard_transfer_to_address` + - `zcash_client_backend::data_api::WalletRead::get_account_birthday` + - `zcash_client_backend::data_api::WalletRead::get_current_address` + - `zcash_client_backend::data_api::WalletRead::get_unified_full_viewing_keys` + - `zcash_client_backend::data_api::WalletRead::get_account_for_ufvk` + - `zcash_client_backend::data_api::WalletRead::get_wallet_summary` + - `zcash_client_backend::data_api::WalletRead::get_sapling_nullifiers` + - `zcash_client_backend::data_api::WalletRead::get_transparent_receivers` + - `zcash_client_backend::data_api::WalletRead::get_transparent_balances` + - `zcash_client_backend::data_api::WalletRead::get_account_ids` + - `zcash_client_backend::data_api::ScannedBlock` + - `zcash_client_backend::data_api::DecryptedTransaction` + - `zcash_client_backend::data_api::SentTransaction` + - `zcash_client_backend::data_api::SentTransactionOutput` + - `zcash_client_backend::data_api::WalletWrite::create_account` + - `zcash_client_backend::data_api::WalletWrite::get_next_available_address` + - `zcash_client_backend::decrypt::DecryptedOutput` + - `zcash_client_backend::decrypt::decrypt_transaction` + - `zcash_client_backend::scanning::scan_block` + - `zcash_client_backend::wallet::Recipient` + - `zcash_client_backend::wallet::WalletTx` + - `zcash_client_backend::wallet::WalletSaplingSpend` + - `zcash_client_backend::wallet::WalletSaplingOutput` - `zcash_client_backend::data_api`: - `BlockMetadata::sapling_tree_size` now returns an `Option` instead of a `u32` for future consistency with Orchard. @@ -114,6 +145,8 @@ and this library adheres to Rust's notion of - `WalletSummary::new` now takes an additional `next_sapling_subtree_index` argument. - Changes to the `WalletRead` trait: + - Added associated type `AccountId`. + - Added `get_account` function. - `get_checkpoint_depth` has been removed without replacement. This is no longer needed given the change to use the stored anchor height for transaction proposal execution. @@ -273,7 +306,7 @@ and this library adheres to Rust's notion of ### Removed - `zcash_client_backend::wallet`: - `ReceivedSaplingNote` (use `zcash_client_backend::ReceivedNote` instead). - - `input_selection::{Proposal, ProposalError}` (moved to + - `input_selection::{Proposal, ShieldedInputs, ProposalError}` (moved to `zcash_client_backend::proposal`). - `SentTransactionOutput::sapling_change_to` - the note created by an internal transfer is now conveyed in the `recipient` field. diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index 16e0f757f2..ff5a3f4ae1 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -1,12 +1,22 @@ //! Interfaces for wallet data persistence & low-level wallet utilities. use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, fmt::Debug, io, num::{NonZeroU32, TryFromIntError}, }; +use self::scanning::ScanRange; +use self::{chain::CommitmentTreeRoot, wallet::Account}; +use crate::{ + address::UnifiedAddress, + decrypt::DecryptedOutput, + keys::{UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey}, + proto::service::TreeState, + wallet::{Note, NoteId, ReceivedNote, Recipient, WalletTransparentOutput, WalletTx}, + ShieldedProtocol, +}; use incrementalmerkletree::{frontier::Frontier, Retention}; use secrecy::SecretVec; use shardtree::{error::ShardTreeError, store::ShardStore, ShardTree}; @@ -18,21 +28,9 @@ use zcash_primitives::{ components::amount::{Amount, BalanceError, NonNegativeAmount}, Transaction, TxId, }, - zip32::{AccountId, Scope}, -}; - -use crate::{ - address::UnifiedAddress, - decrypt::DecryptedOutput, - keys::{UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey}, - proto::service::TreeState, - wallet::{Note, NoteId, ReceivedNote, Recipient, WalletTransparentOutput, WalletTx}, - ShieldedProtocol, + zip32::Scope, }; -use self::chain::CommitmentTreeRoot; -use self::scanning::ScanRange; - #[cfg(feature = "transparent-inputs")] use { crate::wallet::TransparentAddressMetadata, @@ -291,18 +289,18 @@ impl Ratio { /// this circumstance it is possible that a newly created transaction could conflict with a /// not-yet-mined transaction in the mempool. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct WalletSummary { - account_balances: BTreeMap, +pub struct WalletSummary { + account_balances: HashMap, chain_tip_height: BlockHeight, fully_scanned_height: BlockHeight, scan_progress: Option>, next_sapling_subtree_index: u64, } -impl WalletSummary { +impl WalletSummary { /// Constructs a new [`WalletSummary`] from its constituent parts. pub fn new( - account_balances: BTreeMap, + account_balances: HashMap, chain_tip_height: BlockHeight, fully_scanned_height: BlockHeight, scan_progress: Option>, @@ -318,7 +316,7 @@ impl WalletSummary { } /// Returns the balances of accounts in the wallet, keyed by account ID. - pub fn account_balances(&self) -> &BTreeMap { + pub fn account_balances(&self) -> &HashMap { &self.account_balances } @@ -361,6 +359,9 @@ pub trait InputSource { /// The type of errors produced by a wallet backend. type Error; + /// The type used to track unique account identifiers. + type AccountId; + /// Backend-specific note identifier. /// /// For example, this might be a database identifier type @@ -384,7 +385,7 @@ pub trait InputSource { /// be included. fn select_spendable_notes( &self, - account: AccountId, + account: &Self::AccountId, target_value: Amount, sources: &[ShieldedProtocol], anchor_height: BlockHeight, @@ -425,6 +426,12 @@ pub trait WalletRead { /// The type of errors that may be generated when querying a wallet data store. type Error; + /// The type used to track unique account identifiers. + type AccountId: Eq + std::hash::Hash; + + /// Gets the parameters that went into creating an account (e.g. seed+index or uvk). + fn get_account(&self, account_id: &Self::AccountId) -> Result, Self::Error>; + /// Returns the height of the chain as known to the wallet as of the most recent call to /// [`WalletWrite::update_chain_tip`]. /// @@ -503,7 +510,7 @@ pub trait WalletRead { /// Returns the birthday height for the given account, or an error if the account is not known /// to the wallet. - fn get_account_birthday(&self, account: AccountId) -> Result; + fn get_account_birthday(&self, account: &Self::AccountId) -> Result; /// Returns the most recently generated unified address for the specified account, if the /// account identifier specified refers to a valid account for this wallet. @@ -512,26 +519,26 @@ pub trait WalletRead { /// account. fn get_current_address( &self, - account: AccountId, + account: &Self::AccountId, ) -> Result, Self::Error>; /// Returns all unified full viewing keys known to this wallet. fn get_unified_full_viewing_keys( &self, - ) -> Result, Self::Error>; + ) -> Result, Self::Error>; /// Returns the account id corresponding to a given [`UnifiedFullViewingKey`], if any. fn get_account_for_ufvk( &self, ufvk: &UnifiedFullViewingKey, - ) -> Result, Self::Error>; + ) -> Result, Self::Error>; /// Returns the wallet balances and sync status for an account given the specified minimum /// number of confirmations, or `Ok(None)` if the wallet has no balance data available. fn get_wallet_summary( &self, min_confirmations: u32, - ) -> Result, Self::Error>; + ) -> Result>, Self::Error>; /// Returns the memo for a note. /// @@ -549,7 +556,7 @@ pub trait WalletRead { fn get_sapling_nullifiers( &self, query: NullifierQuery, - ) -> Result, Self::Error>; + ) -> Result, Self::Error>; /// Returns the nullifiers for Orchard notes that the wallet is tracking, along with their /// associated account IDs, that are either unspent or have not yet been confirmed as spent (in @@ -558,7 +565,7 @@ pub trait WalletRead { fn get_orchard_nullifiers( &self, query: NullifierQuery, - ) -> Result, Self::Error>; + ) -> Result, Self::Error>; /// Returns the set of all transparent receivers associated with the given account. /// @@ -568,7 +575,7 @@ pub trait WalletRead { #[cfg(feature = "transparent-inputs")] fn get_transparent_receivers( &self, - _account: AccountId, + _account: &Self::AccountId, ) -> Result>, Self::Error> { Ok(HashMap::new()) } @@ -578,14 +585,14 @@ pub trait WalletRead { #[cfg(feature = "transparent-inputs")] fn get_transparent_balances( &self, - _account: AccountId, + _account: &Self::AccountId, _max_height: BlockHeight, ) -> Result, Self::Error> { Ok(HashMap::new()) } /// Returns a vector with the IDs of all accounts known to this wallet. - fn get_account_ids(&self) -> Result, Self::Error>; + fn get_account_ids(&self) -> Result, Self::Error>; } /// Metadata describing the sizes of the zcash note commitment trees as of a particular block. @@ -700,23 +707,23 @@ pub struct ScannedBlockCommitments { /// decrypted and extracted from a [`CompactBlock`]. /// /// [`CompactBlock`]: crate::proto::compact_formats::CompactBlock -pub struct ScannedBlock { +pub struct ScannedBlock { block_height: BlockHeight, block_hash: BlockHash, block_time: u32, - transactions: Vec>, + transactions: Vec>, sapling: ScannedBundles, #[cfg(feature = "orchard")] orchard: ScannedBundles, } -impl ScannedBlock { +impl ScannedBlock { /// Constructs a new `ScannedBlock` pub(crate) fn from_parts( block_height: BlockHeight, block_hash: BlockHash, block_time: u32, - transactions: Vec>, + transactions: Vec>, sapling: ScannedBundles, #[cfg(feature = "orchard")] orchard: ScannedBundles< orchard::note::NoteCommitment, @@ -750,7 +757,7 @@ impl ScannedBlock { } /// Returns the list of transactions from this block that are relevant to the wallet. - pub fn transactions(&self) -> &[WalletTx] { + pub fn transactions(&self) -> &[WalletTx] { &self.transactions } @@ -794,9 +801,9 @@ impl ScannedBlock { /// /// The purpose of this struct is to permit atomic updates of the /// wallet database when transactions are successfully decrypted. -pub struct DecryptedTransaction<'a> { +pub struct DecryptedTransaction<'a, AccountId> { pub tx: &'a Transaction, - pub sapling_outputs: &'a Vec>, + pub sapling_outputs: &'a Vec>, } /// A transaction that was constructed and sent by the wallet. @@ -804,11 +811,11 @@ pub struct DecryptedTransaction<'a> { /// The purpose of this struct is to permit atomic updates of the /// wallet database when transactions are created and submitted /// to the network. -pub struct SentTransaction<'a> { +pub struct SentTransaction<'a, AccountId> { pub tx: &'a Transaction, pub created: time::OffsetDateTime, pub account: AccountId, - pub outputs: Vec, + pub outputs: Vec>, pub fee_amount: Amount, #[cfg(feature = "transparent-inputs")] pub utxos_spent: Vec, @@ -817,14 +824,14 @@ pub struct SentTransaction<'a> { /// An output of a transaction generated by the wallet. /// /// This type is capable of representing both shielded and transparent outputs. -pub struct SentTransactionOutput { +pub struct SentTransactionOutput { output_index: usize, - recipient: Recipient, + recipient: Recipient, value: NonNegativeAmount, memo: Option, } -impl SentTransactionOutput { +impl SentTransactionOutput { /// Constructs a new [`SentTransactionOutput`] from its constituent parts. /// /// ### Fields: @@ -836,7 +843,7 @@ impl SentTransactionOutput { /// * `memo` - the memo that was sent with this output pub fn from_parts( output_index: usize, - recipient: Recipient, + recipient: Recipient, value: NonNegativeAmount, memo: Option, ) -> Self { @@ -859,7 +866,7 @@ impl SentTransactionOutput { } /// Returns the recipient address of the transaction, or the account id and /// resulting note for wallet-internal outputs. - pub fn recipient(&self) -> &Recipient { + pub fn recipient(&self) -> &Recipient { &self.recipient } /// Returns the value of the newly created output. @@ -999,6 +1006,7 @@ pub trait WalletWrite: WalletRead { /// /// Returns the account identifier for the newly-created wallet database entry, along with the /// associated [`UnifiedSpendingKey`]. + /// Note that the unique account identifier should *not* be assumed equivalent to the ZIP-32 account index. /// /// If `birthday.height()` is below the current chain tip, this operation will /// trigger a re-scan of the blocks at and above the provided height. The birthday height is @@ -1024,7 +1032,7 @@ pub trait WalletWrite: WalletRead { &mut self, seed: &SecretVec, birthday: AccountBirthday, - ) -> Result<(AccountId, UnifiedSpendingKey), Self::Error>; + ) -> Result<(Self::AccountId, UnifiedSpendingKey), Self::Error>; /// Generates and persists the next available diversified address, given the current /// addresses known to the wallet. @@ -1033,7 +1041,7 @@ pub trait WalletWrite: WalletRead { /// account. fn get_next_available_address( &mut self, - account: AccountId, + account: &Self::AccountId, request: UnifiedAddressRequest, ) -> Result, Self::Error>; @@ -1044,7 +1052,7 @@ pub trait WalletWrite: WalletRead { /// `blocks` must be sequential, in order of increasing block height fn put_blocks( &mut self, - blocks: Vec>, + blocks: Vec>, ) -> Result<(), Self::Error>; /// Updates the wallet's view of the blockchain. @@ -1057,11 +1065,17 @@ pub trait WalletWrite: WalletRead { fn update_chain_tip(&mut self, tip_height: BlockHeight) -> Result<(), Self::Error>; /// Caches a decrypted transaction in the persistent wallet store. - fn store_decrypted_tx(&mut self, received_tx: DecryptedTransaction) -> Result<(), Self::Error>; + fn store_decrypted_tx( + &mut self, + received_tx: DecryptedTransaction, + ) -> Result<(), Self::Error>; /// Saves information about a transaction that was constructed and sent by the wallet to the /// persistent wallet store. - fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result<(), Self::Error>; + fn store_sent_tx( + &mut self, + sent_tx: &SentTransaction, + ) -> Result<(), Self::Error>; /// Truncates the wallet database to the specified height. /// @@ -1167,7 +1181,7 @@ pub mod testing { consensus::{BlockHeight, Network}, memo::Memo, transaction::{components::Amount, Transaction, TxId}, - zip32::{AccountId, Scope}, + zip32::Scope, }; use crate::{ @@ -1218,6 +1232,7 @@ pub mod testing { impl InputSource for MockWalletDb { type Error = (); type NoteRef = u32; + type AccountId = u32; fn get_spendable_note( &self, @@ -1230,7 +1245,7 @@ pub mod testing { fn select_spendable_notes( &self, - _account: AccountId, + _account: &Self::AccountId, _target_value: Amount, _sources: &[ShieldedProtocol], _anchor_height: BlockHeight, @@ -1242,6 +1257,14 @@ pub mod testing { impl WalletRead for MockWalletDb { type Error = (); + type AccountId = u32; + + fn get_account( + &self, + _account_id: &Self::AccountId, + ) -> Result, Self::Error> { + Ok(None) + } fn chain_height(&self) -> Result, Self::Error> { Ok(None) @@ -1296,34 +1319,37 @@ pub mod testing { Ok(None) } - fn get_account_birthday(&self, _account: AccountId) -> Result { + fn get_account_birthday( + &self, + _account: &Self::AccountId, + ) -> Result { Err(()) } fn get_current_address( &self, - _account: AccountId, + _account: &Self::AccountId, ) -> Result, Self::Error> { Ok(None) } fn get_unified_full_viewing_keys( &self, - ) -> Result, Self::Error> { + ) -> Result, Self::Error> { Ok(HashMap::new()) } fn get_account_for_ufvk( &self, _ufvk: &UnifiedFullViewingKey, - ) -> Result, Self::Error> { + ) -> Result, Self::Error> { Ok(None) } fn get_wallet_summary( &self, _min_confirmations: u32, - ) -> Result, Self::Error> { + ) -> Result>, Self::Error> { Ok(None) } @@ -1338,7 +1364,7 @@ pub mod testing { fn get_sapling_nullifiers( &self, _query: NullifierQuery, - ) -> Result, Self::Error> { + ) -> Result, Self::Error> { Ok(Vec::new()) } @@ -1346,14 +1372,14 @@ pub mod testing { fn get_orchard_nullifiers( &self, _query: NullifierQuery, - ) -> Result, Self::Error> { + ) -> Result, Self::Error> { Ok(Vec::new()) } #[cfg(feature = "transparent-inputs")] fn get_transparent_receivers( &self, - _account: AccountId, + _account: &Self::AccountId, ) -> Result>, Self::Error> { Ok(HashMap::new()) @@ -1362,13 +1388,13 @@ pub mod testing { #[cfg(feature = "transparent-inputs")] fn get_transparent_balances( &self, - _account: AccountId, + _account: &Self::AccountId, _max_height: BlockHeight, ) -> Result, Self::Error> { Ok(HashMap::new()) } - fn get_account_ids(&self) -> Result, Self::Error> { + fn get_account_ids(&self) -> Result, Self::Error> { Ok(Vec::new()) } } @@ -1380,16 +1406,16 @@ pub mod testing { &mut self, seed: &SecretVec, _birthday: AccountBirthday, - ) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> { - let account = AccountId::ZERO; + ) -> Result<(Self::AccountId, UnifiedSpendingKey), Self::Error> { + let account = zip32::AccountId::ZERO; UnifiedSpendingKey::from_seed(&self.network, seed.expose_secret(), account) - .map(|k| (account, k)) + .map(|k| (u32::from(account), k)) .map_err(|_| ()) } fn get_next_available_address( &mut self, - _account: AccountId, + _account: &Self::AccountId, _request: UnifiedAddressRequest, ) -> Result, Self::Error> { Ok(None) @@ -1398,7 +1424,7 @@ pub mod testing { #[allow(clippy::type_complexity)] fn put_blocks( &mut self, - _blocks: Vec>, + _blocks: Vec>, ) -> Result<(), Self::Error> { Ok(()) } @@ -1409,12 +1435,15 @@ pub mod testing { fn store_decrypted_tx( &mut self, - _received_tx: DecryptedTransaction, + _received_tx: DecryptedTransaction, ) -> Result<(), Self::Error> { Ok(()) } - fn store_sent_tx(&mut self, _sent_tx: &SentTransaction) -> Result<(), Self::Error> { + fn store_sent_tx( + &mut self, + _sent_tx: &SentTransaction, + ) -> Result<(), Self::Error> { Ok(()) } diff --git a/zcash_client_backend/src/data_api/chain.rs b/zcash_client_backend/src/data_api/chain.rs index 1555785f85..996beadc67 100644 --- a/zcash_client_backend/src/data_api/chain.rs +++ b/zcash_client_backend/src/data_api/chain.rs @@ -146,6 +146,7 @@ use std::ops::Range; use sapling::note_encryption::PreparedIncomingViewingKey; +use subtle::ConditionallySelectable; use zcash_primitives::{ consensus::{self, BlockHeight}, zip32::Scope, @@ -161,6 +162,8 @@ use crate::{ pub mod error; use error::Error; +use super::WalletRead; + /// A struct containing metadata about a subtree root of the note commitment tree. /// /// This stores the block height at which the leaf that completed the subtree was @@ -277,6 +280,7 @@ where ParamsT: consensus::Parameters + Send + 'static, BlockSourceT: BlockSource, DbT: WalletWrite, + ::AccountId: Clone + ConditionallySelectable + Default + Send + 'static, { // Fetch the UnifiedFullViewingKeys we are tracking let ufvks = data_db @@ -374,7 +378,7 @@ where sapling_nullifiers.extend(scanned_block.transactions.iter().flat_map(|tx| { tx.sapling_outputs .iter() - .map(|out| (out.account(), *out.nf())) + .map(|out| (*out.account(), *out.nf())) })); prior_block_metadata = Some(scanned_block.to_block_metadata()); diff --git a/zcash_client_backend/src/data_api/wallet.rs b/zcash_client_backend/src/data_api/wallet.rs index c8f383a169..33c14b6599 100644 --- a/zcash_client_backend/src/data_api/wallet.rs +++ b/zcash_client_backend/src/data_api/wallet.rs @@ -1,11 +1,15 @@ -use std::num::NonZeroU32; - use nonempty::NonEmpty; use rand_core::OsRng; use sapling::{ note_encryption::{try_sapling_note_decryption, PreparedIncomingViewingKey}, prover::{OutputProver, SpendProver}, }; +use std::num::NonZeroU32; + +use zcash_keys::{ + address::UnifiedAddress, + keys::{UnifiedAddressRequest, UnifiedFullViewingKey}, +}; use zcash_primitives::{ consensus::{self, BlockHeight, NetworkUpgrade}, memo::MemoBytes, @@ -15,9 +19,11 @@ use zcash_primitives::{ fees::{zip317::FeeError as Zip317FeeError, FeeRule, StandardFeeRule}, Transaction, TxId, }, - zip32::{AccountId, Scope}, + zip32::Scope, }; +use zip32::DiversifierIndex; +use super::InputSource; use crate::{ address::Address, data_api::{ @@ -33,13 +39,6 @@ use crate::{ PoolType, ShieldedProtocol, }; -pub mod input_selection; -use input_selection::{ - GreedyInputSelector, GreedyInputSelectorError, InputSelector, InputSelectorError, -}; - -use super::InputSource; - #[cfg(feature = "transparent-inputs")] use { input_selection::ShieldingSelector, @@ -49,6 +48,50 @@ use { zcash_primitives::transaction::components::{OutPoint, TxOut}, }; +pub mod input_selection; +use input_selection::{ + GreedyInputSelector, GreedyInputSelectorError, InputSelector, InputSelectorError, +}; + +/// The inputs for adding an account to a wallet. +#[derive(Debug, Clone)] +pub enum Account { + /// An account that was derived from a ZIP-32 HD seed and account index. + Zip32 { + account_id: zip32::AccountId, + ufvk: UnifiedFullViewingKey, + }, + /// An account for which the seed and ZIP-32 account ID are not known. + ImportedUfvk(UnifiedFullViewingKey), +} + +impl Account { + /// Gets the default UA for the account. + pub fn default_address( + &self, + request: UnifiedAddressRequest, + ) -> (UnifiedAddress, DiversifierIndex) { + match self { + Account::Zip32 { ufvk, .. } => ufvk.default_address(request), + Account::ImportedUfvk(ufvk) => ufvk.default_address(request), + } + } + + /// Gets the unified full viewing key for this account, if available. + /// + /// Accounts initialized with an incoming viewing key will not have a unified full viewing key. + pub fn ufvk(&self) -> Option<&UnifiedFullViewingKey> { + match self { + Account::Zip32 { ufvk, .. } => Some(ufvk), + Account::ImportedUfvk(ufvk) => Some(ufvk), + } + } + + // TODO: When a UnifiedIncomingViewingKey is available, add a function that + // will return it. Even if the Account was initialized with a UFVK, we can always + // derive a UIVK from that. +} + /// Scans a [`Transaction`] for any information that can be decrypted by the accounts in /// the wallet, and saves it to the wallet. pub fn decrypt_and_store_transaction( @@ -59,6 +102,7 @@ pub fn decrypt_and_store_transaction( where ParamsT: consensus::Parameters, DbT: WalletWrite, + ::AccountId: Clone, { // Fetch the UnifiedFullViewingKeys we are tracking let ufvks = data.get_unified_full_viewing_keys()?; @@ -222,7 +266,13 @@ pub fn create_spend_to_address( > where ParamsT: consensus::Parameters + Clone, - DbT: WalletWrite + WalletCommitmentTrees + InputSource::Error>, + DbT: InputSource, + DbT: WalletWrite< + Error = ::Error, + AccountId = ::AccountId, + >, + DbT: WalletCommitmentTrees, + ::AccountId: Clone, ::NoteRef: Copy + Eq + Ord, { let account = wallet_db @@ -328,7 +378,13 @@ pub fn spend( >, > where - DbT: WalletWrite + WalletCommitmentTrees + InputSource::Error>, + DbT: InputSource, + DbT: WalletWrite< + Error = ::Error, + AccountId = ::AccountId, + >, + DbT: WalletCommitmentTrees, + ::AccountId: Clone, ::NoteRef: Copy + Eq + Ord, ParamsT: consensus::Parameters + Clone, InputsT: InputSelector, @@ -366,7 +422,7 @@ where pub fn propose_transfer( wallet_db: &mut DbT, params: &ParamsT, - spend_from_account: AccountId, + spend_from_account: ::AccountId, input_selector: &InputsT, request: zip321::TransactionRequest, min_confirmations: NonZeroU32, @@ -433,7 +489,7 @@ pub fn propose_standard_transfer_to_address( wallet_db: &mut DbT, params: &ParamsT, fee_rule: StandardFeeRule, - spend_from_account: AccountId, + spend_from_account: ::AccountId, min_confirmations: NonZeroU32, to: &Address, amount: NonNegativeAmount, @@ -451,7 +507,12 @@ pub fn propose_standard_transfer_to_address( > where ParamsT: consensus::Parameters + Clone, - DbT: WalletRead + InputSource::Error>, + DbT: InputSource, + DbT: WalletRead< + Error = ::Error, + AccountId = ::AccountId, + >, + ::AccountId: Clone, DbT::NoteRef: Copy + Eq + Ord, { let request = zip321::TransactionRequest::new(vec![Payment { @@ -559,6 +620,7 @@ pub fn create_proposed_transactions( > where DbT: WalletWrite + WalletCommitmentTrees, + ::AccountId: Clone, ParamsT: consensus::Parameters + Clone, FeeRuleT: FeeRule, { @@ -612,6 +674,7 @@ fn create_proposed_transaction( > where DbT: WalletWrite + WalletCommitmentTrees, + ::AccountId: Clone, ParamsT: consensus::Parameters + Clone, FeeRuleT: FeeRule, { @@ -756,7 +819,7 @@ where #[cfg(feature = "transparent-inputs")] let utxos_spent = { let known_addrs = wallet_db - .get_transparent_receivers(account) + .get_transparent_receivers(&account) .map_err(Error::DataSource)?; let mut utxos_spent: Vec = vec![]; @@ -1006,7 +1069,7 @@ where )?; sapling_output_meta.push(( Recipient::InternalAccount( - account, + account.clone(), PoolType::Shielded(ShieldedProtocol::Sapling), ), change_value.value(), @@ -1029,7 +1092,7 @@ where )?; orchard_output_meta.push(( Recipient::InternalAccount( - account, + account.clone(), PoolType::Shielded(ShieldedProtocol::Orchard), ), change_value.value(), @@ -1057,7 +1120,7 @@ where .expect("An action should exist in the transaction for each Orchard output."); let recipient = recipient - .map_internal_account(|pool| { + .map_internal_note(|pool| { assert!(pool == PoolType::Shielded(ShieldedProtocol::Orchard)); build_result .transaction() @@ -1068,7 +1131,7 @@ where .map(|(note, _, _)| Note::Orchard(note)) }) }) - .internal_account_transpose_option() + .internal_note_transpose_option() .expect("Wallet-internal outputs must be decryptable with the wallet's IVK"); SentTransactionOutput::from_parts(output_index, recipient, value, memo) @@ -1087,7 +1150,7 @@ where .expect("An output should exist in the transaction for each Sapling payment."); let recipient = recipient - .map_internal_account(|pool| { + .map_internal_note(|pool| { assert!(pool == PoolType::Shielded(ShieldedProtocol::Sapling)); build_result .transaction() @@ -1104,7 +1167,7 @@ where .map(|(note, _, _)| Note::Sapling(note)) }) }) - .internal_account_transpose_option() + .internal_note_transpose_option() .expect("Wallet-internal outputs must be decryptable with the wallet's IVK"); SentTransactionOutput::from_parts(output_index, recipient, value, memo) @@ -1206,6 +1269,7 @@ pub fn shield_transparent_funds( where ParamsT: consensus::Parameters, DbT: WalletWrite + WalletCommitmentTrees + InputSource::Error>, + ::AccountId: Clone, InputsT: ShieldingSelector, { let proposal = propose_shielding( diff --git a/zcash_client_backend/src/data_api/wallet/input_selection.rs b/zcash_client_backend/src/data_api/wallet/input_selection.rs index 45502c379f..2e56f32f06 100644 --- a/zcash_client_backend/src/data_api/wallet/input_selection.rs +++ b/zcash_client_backend/src/data_api/wallet/input_selection.rs @@ -17,7 +17,6 @@ use zcash_primitives::{ }, fees::FeeRule, }, - zip32::AccountId, }; use crate::{ @@ -149,7 +148,7 @@ pub trait InputSelector { wallet_db: &Self::InputSource, target_height: BlockHeight, anchor_height: BlockHeight, - account: AccountId, + account: ::AccountId, transaction_request: TransactionRequest, ) -> Result< Proposal::NoteRef>, @@ -315,6 +314,7 @@ impl GreedyInputSelector { impl InputSelector for GreedyInputSelector where DbT: InputSource, + ::AccountId: Clone, ChangeT: ChangeStrategy, ChangeT::FeeRule: Clone, { @@ -329,7 +329,7 @@ where wallet_db: &Self::InputSource, target_height: BlockHeight, anchor_height: BlockHeight, - account: AccountId, + account: ::AccountId, transaction_request: TransactionRequest, ) -> Result< Proposal, @@ -462,7 +462,7 @@ where shielded_inputs = wallet_db .select_spendable_notes( - account, + &account, amount_required.into(), selectable_pools, anchor_height, diff --git a/zcash_client_backend/src/decrypt.rs b/zcash_client_backend/src/decrypt.rs index 62c1fb8bdf..e658b0658a 100644 --- a/zcash_client_backend/src/decrypt.rs +++ b/zcash_client_backend/src/decrypt.rs @@ -7,7 +7,7 @@ use zcash_primitives::{ consensus::{self, BlockHeight}, memo::MemoBytes, transaction::Transaction, - zip32::{AccountId, Scope}, + zip32::Scope, }; use crate::keys::UnifiedFullViewingKey; @@ -27,7 +27,7 @@ pub enum TransferType { } /// A decrypted shielded output. -pub struct DecryptedOutput { +pub struct DecryptedOutput { /// The index of the output within [`shielded_outputs`]. /// /// [`shielded_outputs`]: zcash_primitives::transaction::TransactionData @@ -47,12 +47,12 @@ pub struct DecryptedOutput { /// Scans a [`Transaction`] for any information that can be decrypted by the set of /// [`UnifiedFullViewingKey`]s. -pub fn decrypt_transaction( +pub fn decrypt_transaction( params: &P, height: BlockHeight, tx: &Transaction, - ufvks: &HashMap, -) -> Vec> { + ufvks: &HashMap, +) -> Vec> { let zip212_enforcement = consensus::sapling_zip212_enforcement(params, height); tx.sapling_bundle() .iter() @@ -60,7 +60,9 @@ pub fn decrypt_transaction( ufvks .iter() .flat_map(move |(account, ufvk)| { - ufvk.sapling().into_iter().map(|dfvk| (*account, dfvk)) + ufvk.sapling() + .into_iter() + .map(|dfvk| (account.to_owned(), dfvk)) }) .flat_map(move |(account, dfvk)| { let ivk_external = @@ -74,6 +76,7 @@ pub fn decrypt_transaction( .iter() .enumerate() .flat_map(move |(index, output)| { + let account = account.clone(); try_sapling_note_decryption(&ivk_external, output, zip212_enforcement) .map(|ret| (ret, TransferType::Incoming)) .or_else(|| { @@ -92,7 +95,7 @@ pub fn decrypt_transaction( .map(move |((note, _, memo), transfer_type)| DecryptedOutput { index, note, - account, + account: account.clone(), memo: MemoBytes::from_bytes(&memo).expect("correct length"), transfer_type, }) diff --git a/zcash_client_backend/src/scanning.rs b/zcash_client_backend/src/scanning.rs index 38d03881fb..988f939c0b 100644 --- a/zcash_client_backend/src/scanning.rs +++ b/zcash_client_backend/src/scanning.rs @@ -3,6 +3,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; use std::fmt::{self, Debug}; +use std::hash::Hash; use incrementalmerkletree::{Position, Retention}; use sapling::{ @@ -13,10 +14,7 @@ use sapling::{ use subtle::{ConditionallySelectable, ConstantTimeEq, CtOption}; use zcash_note_encryption::batch; use zcash_primitives::consensus::{BlockHeight, NetworkUpgrade}; -use zcash_primitives::{ - consensus, - zip32::{AccountId, Scope}, -}; +use zcash_primitives::{consensus, zip32::Scope}; use crate::data_api::{BlockMetadata, ScannedBlock, ScannedBundles}; use crate::{ @@ -251,14 +249,18 @@ impl fmt::Display for ScanError { /// [`IncrementalWitness`]: sapling::IncrementalWitness /// [`WalletSaplingOutput`]: crate::wallet::WalletSaplingOutput /// [`WalletTx`]: crate::wallet::WalletTx -pub fn scan_block( +pub fn scan_block< + P: consensus::Parameters + Send + 'static, + K: ScanningKey, + A: Clone + Default + Eq + Hash + Send + ConditionallySelectable + 'static, +>( params: &P, block: CompactBlock, - vks: &[(&AccountId, &K)], - sapling_nullifiers: &[(AccountId, sapling::Nullifier)], + vks: &[(&A, &K)], + sapling_nullifiers: &[(A, sapling::Nullifier)], prior_block_metadata: Option<&BlockMetadata>, -) -> Result, ScanError> { - scan_block_with_runner::<_, _, ()>( +) -> Result, ScanError> { + scan_block_with_runner::<_, _, (), A>( params, block, vks, @@ -268,20 +270,20 @@ pub fn scan_block( ) } -type TaggedBatch = - Batch<(AccountId, S), SaplingDomain, CompactOutputDescription, CompactDecryptor>; -type TaggedBatchRunner = - BatchRunner<(AccountId, S), SaplingDomain, CompactOutputDescription, CompactDecryptor, T>; +type TaggedBatch = Batch<(A, S), SaplingDomain, CompactOutputDescription, CompactDecryptor>; +type TaggedBatchRunner = + BatchRunner<(A, S), SaplingDomain, CompactOutputDescription, CompactDecryptor, T>; #[tracing::instrument(skip_all, fields(height = block.height))] -pub(crate) fn add_block_to_runner( +pub(crate) fn add_block_to_runner( params: &P, block: CompactBlock, - batch_runner: &mut TaggedBatchRunner, + batch_runner: &mut TaggedBatchRunner, ) where P: consensus::Parameters + Send + 'static, S: Clone + Send + 'static, - T: Tasks>, + T: Tasks>, + A: Clone + Default + Eq + Send + 'static, { let block_hash = block.hash(); let block_height = block.height(); @@ -333,15 +335,16 @@ fn check_hash_continuity( pub(crate) fn scan_block_with_runner< P: consensus::Parameters + Send + 'static, K: ScanningKey, - T: Tasks> + Sync, + T: Tasks> + Sync, + A: Send + Clone + Default + Eq + Hash + ConditionallySelectable + 'static, >( params: &P, block: CompactBlock, - vks: &[(&AccountId, K)], - nullifiers: &[(AccountId, sapling::Nullifier)], + vks: &[(&A, K)], + nullifiers: &[(A, sapling::Nullifier)], prior_block_metadata: Option<&BlockMetadata>, - mut batch_runner: Option<&mut TaggedBatchRunner>, -) -> Result, ScanError> { + mut batch_runner: Option<&mut TaggedBatchRunner>, +) -> Result, ScanError> { if let Some(scan_error) = check_hash_continuity(&block, prior_block_metadata) { return Err(scan_error); } @@ -444,7 +447,7 @@ pub(crate) fn scan_block_with_runner< )?; let compact_block_tx_count = block.vtx.len(); - let mut wtxs: Vec> = vec![]; + let mut wtxs: Vec> = vec![]; let mut sapling_nullifier_map = Vec::with_capacity(block.vtx.len()); let mut sapling_note_commitments: Vec<(sapling::Node, Retention)> = vec![]; for (tx_idx, tx) in block.vtx.into_iter().enumerate() { @@ -468,7 +471,7 @@ pub(crate) fn scan_block_with_runner< let spend = nullifiers .iter() .map(|&(account, nf)| CtOption::new(account, nf.ct_eq(&spend_nf))) - .fold(CtOption::new(AccountId::ZERO, 0.into()), |first, next| { + .fold(CtOption::new(A::default(), 0.into()), |first, next| { CtOption::conditional_select(&next, &first, first.is_some()) }) .map(|account| WalletSaplingSpend::from_parts(index, spend_nf, account)); @@ -498,7 +501,7 @@ pub(crate) fn scan_block_with_runner< u32::try_from(tx.actions.len()).expect("Orchard action count cannot exceed a u32"); // Check for incoming notes while incrementing tree and witnesses - let mut shielded_outputs: Vec> = vec![]; + let mut shielded_outputs: Vec> = vec![]; { let decoded = &tx .outputs @@ -874,7 +877,7 @@ mod tests { assert_eq!(tx.sapling_spends.len(), 0); assert_eq!(tx.sapling_outputs.len(), 1); assert_eq!(tx.sapling_outputs[0].index(), 0); - assert_eq!(tx.sapling_outputs[0].account(), account); + assert_eq!(*tx.sapling_outputs[0].account(), account); assert_eq!(tx.sapling_outputs[0].note().value().inner(), 5); assert_eq!( tx.sapling_outputs[0].note_commitment_tree_position(), @@ -955,7 +958,7 @@ mod tests { assert_eq!(tx.sapling_spends.len(), 0); assert_eq!(tx.sapling_outputs.len(), 1); assert_eq!(tx.sapling_outputs[0].index(), 0); - assert_eq!(tx.sapling_outputs[0].account(), AccountId::ZERO); + assert_eq!(*tx.sapling_outputs[0].account(), AccountId::ZERO); assert_eq!(tx.sapling_outputs[0].note().value().inner(), 5); assert_eq!( @@ -1010,7 +1013,7 @@ mod tests { assert_eq!(tx.sapling_outputs.len(), 0); assert_eq!(tx.sapling_spends[0].index(), 0); assert_eq!(tx.sapling_spends[0].nf(), &nf); - assert_eq!(tx.sapling_spends[0].account(), account); + assert_eq!(tx.sapling_spends[0].account().to_owned(), account); assert_eq!( scanned_block diff --git a/zcash_client_backend/src/wallet.rs b/zcash_client_backend/src/wallet.rs index c12143df76..d33dc5565a 100644 --- a/zcash_client_backend/src/wallet.rs +++ b/zcash_client_backend/src/wallet.rs @@ -14,7 +14,7 @@ use zcash_primitives::{ fees::transparent as transparent_fees, TxId, }, - zip32::{AccountId, Scope}, + zip32::Scope, }; use crate::{address::UnifiedAddress, fees::sapling as sapling_fees, PoolType, ShieldedProtocol}; @@ -65,15 +65,15 @@ impl NoteId { /// internal account ID and the pool to which funds were sent in the case of a wallet-internal /// output. #[derive(Debug, Clone)] -pub enum Recipient { +pub enum Recipient { Transparent(TransparentAddress), Sapling(sapling::PaymentAddress), Unified(UnifiedAddress, PoolType), InternalAccount(AccountId, N), } -impl Recipient { - pub fn map_internal_account B>(self, f: F) -> Recipient { +impl Recipient { + pub fn map_internal_note B>(self, f: F) -> Recipient { match self { Recipient::Transparent(t) => Recipient::Transparent(t), Recipient::Sapling(s) => Recipient::Sapling(s), @@ -83,8 +83,8 @@ impl Recipient { } } -impl Recipient> { - pub fn internal_account_transpose_option(self) -> Option> { +impl Recipient> { + pub fn internal_note_transpose_option(self) -> Option> { match self { Recipient::Transparent(t) => Some(Recipient::Transparent(t)), Recipient::Sapling(s) => Some(Recipient::Sapling(s)), @@ -97,11 +97,11 @@ impl Recipient> { /// A subset of a [`Transaction`] relevant to wallets and light clients. /// /// [`Transaction`]: zcash_primitives::transaction::Transaction -pub struct WalletTx { +pub struct WalletTx { pub txid: TxId, pub index: usize, - pub sapling_spends: Vec, - pub sapling_outputs: Vec>, + pub sapling_spends: Vec>, + pub sapling_outputs: Vec>, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -161,13 +161,13 @@ impl transparent_fees::InputView for WalletTransparentOutput { /// A subset of a [`SpendDescription`] relevant to wallets and light clients. /// /// [`SpendDescription`]: sapling::bundle::SpendDescription -pub struct WalletSaplingSpend { +pub struct WalletSaplingSpend { index: usize, nf: sapling::Nullifier, account: AccountId, } -impl WalletSaplingSpend { +impl WalletSaplingSpend { pub fn from_parts(index: usize, nf: sapling::Nullifier, account: AccountId) -> Self { Self { index, nf, account } } @@ -178,8 +178,8 @@ impl WalletSaplingSpend { pub fn nf(&self) -> &sapling::Nullifier { &self.nf } - pub fn account(&self) -> AccountId { - self.account + pub fn account(&self) -> &AccountId { + &self.account } } @@ -195,11 +195,11 @@ impl WalletSaplingSpend { /// `()` for sent notes. /// /// [`OutputDescription`]: sapling::bundle::OutputDescription -pub struct WalletSaplingOutput { +pub struct WalletSaplingOutput { index: usize, cmu: sapling::note::ExtractedNoteCommitment, ephemeral_key: EphemeralKeyBytes, - account: AccountId, + account: A, note: sapling::Note, is_change: bool, note_commitment_tree_position: Position, @@ -207,14 +207,14 @@ pub struct WalletSaplingOutput { recipient_key_scope: S, } -impl WalletSaplingOutput { +impl WalletSaplingOutput { /// Constructs a new `WalletSaplingOutput` value from its constituent parts. #[allow(clippy::too_many_arguments)] pub fn from_parts( index: usize, cmu: sapling::note::ExtractedNoteCommitment, ephemeral_key: EphemeralKeyBytes, - account: AccountId, + account: A, note: sapling::Note, is_change: bool, note_commitment_tree_position: Position, @@ -243,8 +243,8 @@ impl WalletSaplingOutput { pub fn ephemeral_key(&self) -> &EphemeralKeyBytes { &self.ephemeral_key } - pub fn account(&self) -> AccountId { - self.account + pub fn account(&self) -> &A { + &self.account } pub fn note(&self) -> &sapling::Note { &self.note diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index 365d9b1833..fbe15c53a5 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -179,6 +179,7 @@ impl WalletDb { impl, P: consensus::Parameters> InputSource for WalletDb { type Error = SqliteClientError; type NoteRef = ReceivedNoteId; + type AccountId = AccountId; fn get_spendable_note( &self, @@ -199,7 +200,7 @@ impl, P: consensus::Parameters> InputSource for fn select_spendable_notes( &self, - account: AccountId, + account: &AccountId, target_value: Amount, _sources: &[ShieldedProtocol], anchor_height: BlockHeight, @@ -208,7 +209,7 @@ impl, P: consensus::Parameters> InputSource for wallet::sapling::select_spendable_sapling_notes( self.conn.borrow(), &self.params, - account, + *account, target_value, anchor_height, exclude, @@ -242,6 +243,7 @@ impl, P: consensus::Parameters> InputSource for impl, P: consensus::Parameters> WalletRead for WalletDb { type Error = SqliteClientError; + type AccountId = AccountId; fn chain_height(&self) -> Result, Self::Error> { wallet::scan_queue_extrema(self.conn.borrow()) @@ -294,15 +296,15 @@ impl, P: consensus::Parameters> WalletRead for W wallet::wallet_birthday(self.conn.borrow()).map_err(SqliteClientError::from) } - fn get_account_birthday(&self, account: AccountId) -> Result { - wallet::account_birthday(self.conn.borrow(), account).map_err(SqliteClientError::from) + fn get_account_birthday(&self, account: &AccountId) -> Result { + wallet::account_birthday(self.conn.borrow(), *account).map_err(SqliteClientError::from) } fn get_current_address( &self, - account: AccountId, + account: &AccountId, ) -> Result, Self::Error> { - wallet::get_current_address(self.conn.borrow(), &self.params, account) + wallet::get_current_address(self.conn.borrow(), &self.params, *account) .map(|res| res.map(|(addr, _)| addr)) } @@ -322,7 +324,7 @@ impl, P: consensus::Parameters> WalletRead for W fn get_wallet_summary( &self, min_confirmations: u32, - ) -> Result, Self::Error> { + ) -> Result>, Self::Error> { // This will return a runtime error if we call `get_wallet_summary` from two // threads at the same time, as transactions cannot nest. wallet::get_wallet_summary( @@ -356,34 +358,41 @@ impl, P: consensus::Parameters> WalletRead for W } } + #[cfg(feature = "orchard")] + fn get_orchard_nullifiers( + &self, + _query: NullifierQuery, + ) -> Result, Self::Error> { + todo!() + } + #[cfg(feature = "transparent-inputs")] fn get_transparent_receivers( &self, - _account: AccountId, + account: &AccountId, ) -> Result>, Self::Error> { - wallet::get_transparent_receivers(self.conn.borrow(), &self.params, _account) + wallet::get_transparent_receivers(self.conn.borrow(), &self.params, *account) } #[cfg(feature = "transparent-inputs")] fn get_transparent_balances( &self, - _account: AccountId, - _max_height: BlockHeight, + account: &AccountId, + max_height: BlockHeight, ) -> Result, Self::Error> { - wallet::get_transparent_balances(self.conn.borrow(), &self.params, _account, _max_height) - } - - #[cfg(feature = "orchard")] - fn get_orchard_nullifiers( - &self, - _query: NullifierQuery, - ) -> Result, Self::Error> { - todo!() + wallet::get_transparent_balances(self.conn.borrow(), &self.params, *account, max_height) } fn get_account_ids(&self) -> Result, Self::Error> { wallet::get_account_ids(self.conn.borrow()) } + + fn get_account( + &self, + account_id: &Self::AccountId, + ) -> Result, Self::Error> { + wallet::get_account(self.conn.borrow(), &self.params, *account_id) + } } impl WalletWrite for WalletDb { @@ -412,14 +421,14 @@ impl WalletWrite for WalletDb fn get_next_available_address( &mut self, - account: AccountId, + account: &AccountId, request: UnifiedAddressRequest, ) -> Result, Self::Error> { self.transactionally( - |wdb| match wdb.get_unified_full_viewing_keys()?.get(&account) { + |wdb| match wdb.get_unified_full_viewing_keys()?.get(account) { Some(ufvk) => { let search_from = - match wallet::get_current_address(wdb.conn.0, &wdb.params, account)? { + match wallet::get_current_address(wdb.conn.0, &wdb.params, *account)? { Some((_, mut last_diversifier_index)) => { last_diversifier_index .increment() @@ -436,7 +445,7 @@ impl WalletWrite for WalletDb wallet::insert_address( wdb.conn.0, &wdb.params, - account, + *account, diversifier_index, &addr, )?; @@ -452,7 +461,7 @@ impl WalletWrite for WalletDb #[allow(clippy::type_complexity)] fn put_blocks( &mut self, - blocks: Vec>, + blocks: Vec>, ) -> Result<(), Self::Error> { self.transactionally(|wdb| { let start_positions = blocks.first().map(|block| { @@ -591,7 +600,10 @@ impl WalletWrite for WalletDb Ok(()) } - fn store_decrypted_tx(&mut self, d_tx: DecryptedTransaction) -> Result<(), Self::Error> { + fn store_decrypted_tx( + &mut self, + d_tx: DecryptedTransaction, + ) -> Result<(), Self::Error> { self.transactionally(|wdb| { let tx_ref = wallet::put_tx_data(wdb.conn.0, d_tx.tx, None, None)?; @@ -683,7 +695,7 @@ impl WalletWrite for WalletDb }) } - fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result<(), Self::Error> { + fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result<(), Self::Error> { self.transactionally(|wdb| { let tx_ref = wallet::put_tx_data( wdb.conn.0, @@ -1212,18 +1224,18 @@ mod tests { .build(); let account = AccountId::ZERO; - let current_addr = st.wallet().get_current_address(account).unwrap(); + let current_addr = st.wallet().get_current_address(&account).unwrap(); assert!(current_addr.is_some()); // TODO: Add Orchard let addr2 = st .wallet_mut() - .get_next_available_address(account, DEFAULT_UA_REQUEST) + .get_next_available_address(&account, DEFAULT_UA_REQUEST) .unwrap(); assert!(addr2.is_some()); assert_ne!(current_addr, addr2); - let addr2_cur = st.wallet().get_current_address(account).unwrap(); + let addr2_cur = st.wallet().get_current_address(&account).unwrap(); assert_eq!(addr2, addr2_cur); } @@ -1242,7 +1254,7 @@ mod tests { let receivers = st .wallet() - .get_transparent_receivers(AccountId::ZERO) + .get_transparent_receivers(&AccountId::ZERO) .unwrap(); // The receiver for the default UA should be in the set. diff --git a/zcash_client_sqlite/src/testing.rs b/zcash_client_sqlite/src/testing.rs index 43e1c8f54a..4c164b3d1c 100644 --- a/zcash_client_sqlite/src/testing.rs +++ b/zcash_client_sqlite/src/testing.rs @@ -731,7 +731,10 @@ impl TestState { }) } - pub(crate) fn get_wallet_summary(&self, min_confirmations: u32) -> Option { + pub(crate) fn get_wallet_summary( + &self, + min_confirmations: u32, + ) -> Option> { get_wallet_summary( &self.wallet().conn.unchecked_transaction().unwrap(), &self.wallet().params, diff --git a/zcash_client_sqlite/src/wallet.rs b/zcash_client_sqlite/src/wallet.rs index b0b1aa03e4..d25c0893e4 100644 --- a/zcash_client_sqlite/src/wallet.rs +++ b/zcash_client_sqlite/src/wallet.rs @@ -67,42 +67,42 @@ use incrementalmerkletree::Retention; use rusqlite::{self, named_params, OptionalExtension}; use shardtree::{error::ShardTreeError, store::ShardStore, ShardTree}; -use std::collections::{BTreeMap, HashMap}; +use std::collections::HashMap; use std::convert::TryFrom; use std::io::{self, Cursor}; use std::num::NonZeroU32; use std::ops::RangeInclusive; use tracing::debug; -use zcash_client_backend::data_api::{AccountBalance, Ratio, WalletSummary}; -use zcash_client_backend::wallet::Note; -use zcash_primitives::transaction::components::amount::NonNegativeAmount; -use zcash_primitives::zip32::Scope; - -use zcash_primitives::{ - block::BlockHash, - consensus::{self, BlockHeight, BranchId, NetworkUpgrade, Parameters}, - memo::{Memo, MemoBytes}, - merkle_tree::read_commitment_tree, - transaction::{components::Amount, Transaction, TransactionData, TxId}, - zip32::{AccountId, DiversifierIndex}, -}; use zcash_client_backend::{ address::{Address, UnifiedAddress}, data_api::{ scanning::{ScanPriority, ScanRange}, - AccountBirthday, BlockMetadata, SentTransactionOutput, SAPLING_SHARD_HEIGHT, + wallet::Account, + AccountBalance, AccountBirthday, BlockMetadata, Ratio, SentTransactionOutput, + WalletSummary, SAPLING_SHARD_HEIGHT, }, encoding::AddressCodec, keys::UnifiedFullViewingKey, - wallet::{NoteId, Recipient, WalletTx}, + wallet::{Note, NoteId, Recipient, WalletTx}, PoolType, ShieldedProtocol, }; +use zcash_primitives::{ + block::BlockHash, + consensus::{self, BlockHeight, BranchId, NetworkUpgrade, Parameters}, + memo::{Memo, MemoBytes}, + merkle_tree::read_commitment_tree, + transaction::{ + components::{amount::NonNegativeAmount, Amount}, + Transaction, TransactionData, TxId, + }, + zip32::{AccountId, DiversifierIndex, Scope}, +}; -use crate::wallet::commitment_tree::{get_max_checkpointed_height, SqliteShardStore}; -use crate::DEFAULT_UA_REQUEST; use crate::{ - error::SqliteClientError, SqlTransaction, WalletCommitmentTrees, WalletDb, PRUNING_DEPTH, + error::SqliteClientError, + wallet::commitment_tree::{get_max_checkpointed_height, SqliteShardStore}, + SqlTransaction, WalletCommitmentTrees, WalletDb, DEFAULT_UA_REQUEST, PRUNING_DEPTH, SAPLING_TABLES_PREFIX, }; @@ -603,7 +603,7 @@ pub(crate) fn get_wallet_summary( params: &P, min_confirmations: u32, progress: &impl ScanProgress, -) -> Result, SqliteClientError> { +) -> Result>, SqliteClientError> { let chain_tip_height = match scan_queue_extrema(tx)? { Some(range) => *range.end(), None => { @@ -655,7 +655,7 @@ pub(crate) fn get_wallet_summary( .map_err(|_| SqliteClientError::AccountIdOutOfRange) .map(|a| (a, AccountBalance::ZERO)) }) - .collect::, _>>()?; + .collect::, _>>()?; let sapling_trace = tracing::info_span!("stmt_select_notes").entered(); let mut stmt_select_notes = tx.prepare_cached( @@ -985,6 +985,29 @@ pub(crate) fn block_height_extrema( }) } +pub(crate) fn get_account( + conn: &rusqlite::Connection, + params: &P, + account_id: AccountId, +) -> Result, SqliteClientError> { + conn.query_row( + r#" + SELECT ufvk + FROM accounts + WHERE id = :account_id + "#, + named_params![":account_id": u32::from(account_id)], + |row| row.get::<_, String>(0), + ) + .optional()? + .map(|ufvk_bytes| { + let ufvk = UnifiedFullViewingKey::decode(params, &ufvk_bytes) + .map_err(SqliteClientError::CorruptedData)?; + Ok(Account::Zip32 { account_id, ufvk }) + }) + .transpose() +} + /// Returns the minimum and maximum heights of blocks in the chain which may be scanned. pub(crate) fn scan_queue_extrema( conn: &rusqlite::Connection, @@ -1608,7 +1631,7 @@ pub(crate) fn put_block( /// contain a note related to this wallet into the database. pub(crate) fn put_tx_meta( conn: &rusqlite::Connection, - tx: &WalletTx, + tx: &WalletTx, height: BlockHeight, ) -> Result { // It isn't there, so insert our transaction into the database. @@ -1795,7 +1818,7 @@ pub(crate) fn update_expired_notes( // and `put_sent_output` fn recipient_params( params: &P, - to: &Recipient, + to: &Recipient, ) -> (Option, Option, PoolType) { match to { Recipient::Transparent(addr) => (Some(addr.encode(params)), None, PoolType::Transparent), @@ -1819,7 +1842,7 @@ pub(crate) fn insert_sent_output( params: &P, tx_ref: i64, from_account: AccountId, - output: &SentTransactionOutput, + output: &SentTransactionOutput, ) -> Result<(), SqliteClientError> { let mut stmt_insert_sent_output = conn.prepare_cached( "INSERT INTO sent_notes ( @@ -1865,7 +1888,7 @@ pub(crate) fn put_sent_output( from_account: AccountId, tx_ref: i64, output_index: usize, - recipient: &Recipient, + recipient: &Recipient, value: NonNegativeAmount, memo: Option<&MemoBytes>, ) -> Result<(), SqliteClientError> { @@ -2034,7 +2057,7 @@ pub(crate) fn query_nullifier_map, S>( // change or explicit in-wallet recipient. put_tx_meta( conn, - &WalletTx:: { + &WalletTx:: { txid, index, sapling_spends: vec![], @@ -2111,14 +2134,14 @@ mod tests { // The default address is set for the test account assert_matches!( - st.wallet().get_current_address(AccountId::ZERO), + st.wallet().get_current_address(&AccountId::ZERO), Ok(Some(_)) ); // No default address is set for an un-initialized account assert_matches!( st.wallet() - .get_current_address(AccountId::try_from(1).unwrap()), + .get_current_address(&AccountId::try_from(1).unwrap()), Ok(None) ); } @@ -2135,7 +2158,7 @@ mod tests { let (account_id, _, _) = st.test_account().unwrap(); let uaddr = st .wallet() - .get_current_address(account_id) + .get_current_address(&account_id) .unwrap() .unwrap(); let taddr = uaddr.transparent().unwrap(); @@ -2143,7 +2166,7 @@ mod tests { let height_1 = BlockHeight::from_u32(12345); let bal_absent = st .wallet() - .get_transparent_balances(account_id, height_1) + .get_transparent_balances(&account_id, height_1) .unwrap(); assert!(bal_absent.is_empty()); @@ -2195,7 +2218,7 @@ mod tests { ); assert_matches!( - st.wallet().get_transparent_balances(account_id, height_2), + st.wallet().get_transparent_balances(&account_id, height_2), Ok(h) if h.get(taddr) == Some(&value.into()) ); @@ -2240,7 +2263,7 @@ mod tests { let (account_id, usk, _) = st.test_account().unwrap(); let uaddr = st .wallet() - .get_current_address(account_id) + .get_current_address(&account_id) .unwrap() .unwrap(); let taddr = uaddr.transparent().unwrap(); @@ -2269,7 +2292,7 @@ mod tests { let max_height = st.wallet().chain_height().unwrap().unwrap() + 1 - min_confirmations; assert_eq!( st.wallet() - .get_transparent_balances(account_id, max_height) + .get_transparent_balances(&account_id, max_height) .unwrap() .get(taddr) .cloned() diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index 14bf46d30f..9876511db1 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -1099,7 +1099,7 @@ mod tests { assert_eq!(tv.unified_addr, ua.encode(&Network::MainNetwork)); db_data - .get_next_available_address(account, DEFAULT_UA_REQUEST) + .get_next_available_address(&account, DEFAULT_UA_REQUEST) .unwrap() .expect("get_next_available_address generated an address"); } else { diff --git a/zcash_client_sqlite/src/wallet/sapling.rs b/zcash_client_sqlite/src/wallet/sapling.rs index 7061e62e8a..3661cd3d30 100644 --- a/zcash_client_sqlite/src/wallet/sapling.rs +++ b/zcash_client_sqlite/src/wallet/sapling.rs @@ -38,12 +38,12 @@ pub(crate) trait ReceivedSaplingOutput { fn recipient_key_scope(&self) -> Scope; } -impl ReceivedSaplingOutput for WalletSaplingOutput { +impl ReceivedSaplingOutput for WalletSaplingOutput { fn index(&self) -> usize { self.index() } fn account(&self) -> AccountId { - WalletSaplingOutput::account(self) + *WalletSaplingOutput::account(self) } fn note(&self) -> &sapling::Note { WalletSaplingOutput::note(self) @@ -66,7 +66,7 @@ impl ReceivedSaplingOutput for WalletSaplingOutput { } } -impl ReceivedSaplingOutput for DecryptedOutput { +impl ReceivedSaplingOutput for DecryptedOutput { fn index(&self) -> usize { self.index } @@ -1641,7 +1641,7 @@ pub(crate) mod tests { let uaddr = st .wallet() - .get_current_address(account_id) + .get_current_address(&account_id) .unwrap() .unwrap(); let taddr = uaddr.transparent().unwrap();