Skip to content

Commit

Permalink
feat!: improve api to apply selection to transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
evanlinjin committed Aug 27, 2024
1 parent 434c6e8 commit 5d59aca
Show file tree
Hide file tree
Showing 11 changed files with 196 additions and 119 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ std = []
[dev-dependencies]
rand = "0.8"
proptest = "1.4"
bitcoin = "0.30"
bitcoin = "0.32"
32 changes: 16 additions & 16 deletions src/bnb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use alloc::collections::BinaryHeap;
/// An [`Iterator`] that iterates over rounds of branch and bound to minimize the score of the
/// provided [`BnbMetric`].
#[derive(Debug)]
pub(crate) struct BnbIter<'a, M: BnbMetric> {
queue: BinaryHeap<Branch<'a>>,
pub(crate) struct BnbIter<'a, C, M: BnbMetric> {
queue: BinaryHeap<Branch<'a, C>>,
best: Option<Ordf32>,
/// The `BnBMetric` that will score each selection
metric: M,
}

impl<'a, M: BnbMetric> Iterator for BnbIter<'a, M> {
type Item = Option<(CoinSelector<'a>, Ordf32)>;
impl<'a, C, M: BnbMetric> Iterator for BnbIter<'a, C, M> {
type Item = Option<(CoinSelector<'a, C>, Ordf32)>;

fn next(&mut self) -> Option<Self::Item> {
// {
Expand Down Expand Up @@ -70,8 +70,8 @@ impl<'a, M: BnbMetric> Iterator for BnbIter<'a, M> {
}
}

impl<'a, M: BnbMetric> BnbIter<'a, M> {
pub(crate) fn new(mut selector: CoinSelector<'a>, metric: M) -> Self {
impl<'a, C, M: BnbMetric> BnbIter<'a, C, M> {
pub(crate) fn new(mut selector: CoinSelector<'a, C>, metric: M) -> Self {
let mut iter = BnbIter {
queue: BinaryHeap::default(),
best: None,
Expand All @@ -87,7 +87,7 @@ impl<'a, M: BnbMetric> BnbIter<'a, M> {
iter
}

fn consider_adding_to_queue(&mut self, cs: &CoinSelector<'a>, is_exclusion: bool) {
fn consider_adding_to_queue(&mut self, cs: &CoinSelector<'a, C>, is_exclusion: bool) {
let bound = self.metric.bound(cs);
if let Some(bound) = bound {
let is_good_enough = match self.best {
Expand Down Expand Up @@ -127,7 +127,7 @@ impl<'a, M: BnbMetric> BnbIter<'a, M> {
}*/
}

fn insert_new_branches(&mut self, cs: &CoinSelector<'a>) {
fn insert_new_branches(&mut self, cs: &CoinSelector<'a, C>) {
let (next_index, next) = match cs.unselected().next() {
Some(c) => c,
None => return, // exhausted
Expand Down Expand Up @@ -161,13 +161,13 @@ impl<'a, M: BnbMetric> BnbIter<'a, M> {
}

#[derive(Debug, Clone)]
struct Branch<'a> {
struct Branch<'a, C> {
lower_bound: Ordf32,
selector: CoinSelector<'a>,
selector: CoinSelector<'a, C>,
is_exclusion: bool,
}

impl<'a> Ord for Branch<'a> {
impl<'a, C> Ord for Branch<'a, C> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
// NOTE: Reverse comparision `lower_bound` because we want a min-heap (by default BinaryHeap
// is a max-heap).
Expand All @@ -181,19 +181,19 @@ impl<'a> Ord for Branch<'a> {
}
}

impl<'a> PartialOrd for Branch<'a> {
impl<'a, C> PartialOrd for Branch<'a, C> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl<'a> PartialEq for Branch<'a> {
impl<'a, C> PartialEq for Branch<'a, C> {
fn eq(&self, other: &Self) -> bool {
self.lower_bound == other.lower_bound
}
}

impl<'a> Eq for Branch<'a> {}
impl<'a, C> Eq for Branch<'a, C> {}

/// A branch and bound metric where we minimize the [`Ordf32`] score.
///
Expand All @@ -202,7 +202,7 @@ pub trait BnbMetric {
/// Get the score of a given selection.
///
/// If this returns `None`, the selection is invalid.
fn score(&mut self, cs: &CoinSelector<'_>) -> Option<Ordf32>;
fn score<C>(&mut self, cs: &CoinSelector<'_, C>) -> Option<Ordf32>;

/// Get the lower bound score using a heuristic.
///
Expand All @@ -211,7 +211,7 @@ pub trait BnbMetric {
///
/// If this returns `None`, the current branch and all descendant branches will not have valid
/// solutions.
fn bound(&mut self, cs: &CoinSelector<'_>) -> Option<Ordf32>;
fn bound<C>(&mut self, cs: &CoinSelector<'_, C>) -> Option<Ordf32>;

/// Returns whether the metric requies we order candidates by descending value per weight unit.
fn requires_ordering_by_descending_value_pwu(&self) -> bool {
Expand Down
118 changes: 86 additions & 32 deletions src/coin_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,49 @@ use alloc::{borrow::Cow, collections::BTreeSet, vec::Vec};
///
/// [`select`]: CoinSelector::select
/// [`bnb_solutions`]: CoinSelector::bnb_solutions
#[derive(Debug, Clone)]
pub struct CoinSelector<'a> {
candidates: &'a [Candidate],
#[derive(Debug)]
pub struct CoinSelector<'a, C> {
inputs: &'a [C],
candidates: Cow<'a, [Candidate]>,
selected: Cow<'a, BTreeSet<usize>>,
banned: Cow<'a, BTreeSet<usize>>,
candidate_order: Cow<'a, Vec<usize>>,
}

impl<'a> CoinSelector<'a> {
/// Creates a new coin selector from some candidate inputs and a `base_weight`.
///
/// The `base_weight` is the weight of the transaction without any inputs and without a change
/// output.
///
/// The `CoinSelector` does not keep track of the final transaction's output count. The caller
/// is responsible for including the potential output-count varint weight change in the
/// corresponding [`DrainWeights`].
impl<'a, C> Clone for CoinSelector<'a, C> {
fn clone(&self) -> Self {
Self {
inputs: self.inputs,
candidates: self.candidates.clone(),
selected: self.selected.clone(),
banned: self.banned.clone(),
candidate_order: self.candidate_order.clone(),
}
}
}

impl<'a, C> CoinSelector<'a, C> {
/// Create a coin selector from raw `inputs` and a `to_candidate` closure.
///
/// Note that methods in `CoinSelector` will refer to inputs by the index in the `candidates`
/// slice you pass in.
pub fn new(candidates: &'a [Candidate]) -> Self {
/// `to_candidate` maps each raw input to a [`Candidate`] representation.
pub fn new<F>(inputs: &'a [C], to_candidate: F) -> Self
where
F: Fn(&C) -> Candidate,
{
Self {
candidates,
inputs,
candidates: inputs.iter().map(to_candidate).collect(),
selected: Cow::Owned(Default::default()),
banned: Cow::Owned(Default::default()),
candidate_order: Cow::Owned((0..candidates.len()).collect()),
candidate_order: Cow::Owned((0..inputs.len()).collect()),
}
}

/// Get a reference to the raw inputs.
pub fn raw_inputs(&self) -> &[C] {
self.inputs
}

/// Iterate over all the candidates in their currently sorted order. Each item has the original
/// index with the candidate.
pub fn candidates(
Expand All @@ -62,11 +76,39 @@ impl<'a> CoinSelector<'a> {
self.selected.to_mut().remove(&index)
}

/// Convienince method to pick elements of a slice by the indexes that are currently selected.
/// Obviously the slice must represent the inputs ordered in the same way as when they were
/// passed to `Candidates::new`.
pub fn apply_selection<T>(&self, candidates: &'a [T]) -> impl Iterator<Item = &'a T> + '_ {
self.selected.iter().map(move |i| &candidates[*i])
/// Apply the current coin selection.
///
/// `apply_action` is a closure that is meant to construct an unsigned transaction based on the
/// current selection. `apply_action` is a [`FnMut`] so it can mutate a structure of the
/// caller's liking (most likely a transaction). The input is a [`FinishAction`], which conveys
/// adding inputs or outputs.
///
/// # Errors
///
/// The selection must satisfy `target` otherwise an [`InsufficientFunds`] error is returned.
pub fn finish<F>(
self,
target: Target,
change_policy: ChangePolicy,
mut apply_action: F,
) -> Result<(), InsufficientFunds>
where
F: FnMut(FinishAction<'a, C>),
{
let excess = self.excess(target, Drain::NONE);
if excess < 0 {
let missing = excess.unsigned_abs();
return Err(InsufficientFunds { missing });
}
let drain = self.drain(target, change_policy);
for i in self.selected.iter().copied() {
apply_action(FinishAction::Input(&self.inputs[i]));
}
apply_action(FinishAction::TargetOutput(target));
if drain.is_some() {
apply_action(FinishAction::DrainOutput(drain));
}
Ok(())
}

/// Select the input at `index`. `index` refers to its position in the original `candidates`
Expand Down Expand Up @@ -331,7 +373,7 @@ impl<'a> CoinSelector<'a> {
let mut excess_waste = self.excess(target, drain).max(0) as f32;
// we allow caller to discount this waste depending on how wasteful excess actually is
// to them.
excess_waste *= excess_discount.max(0.0).min(1.0);
excess_waste *= excess_discount.clamp(0.0, 1.0);
waste += excess_waste;
} else {
waste +=
Expand Down Expand Up @@ -489,7 +531,7 @@ impl<'a> CoinSelector<'a> {
#[must_use]
pub fn select_until(
&mut self,
mut predicate: impl FnMut(&CoinSelector<'a>) -> bool,
mut predicate: impl FnMut(&CoinSelector<'a, C>) -> bool,
) -> Option<()> {
loop {
if predicate(&*self) {
Expand All @@ -503,7 +545,7 @@ impl<'a> CoinSelector<'a> {
}

/// Return an iterator that can be used to select candidates.
pub fn select_iter(self) -> SelectIter<'a> {
pub fn select_iter(self) -> SelectIter<'a, C> {
SelectIter { cs: self.clone() }
}

Expand All @@ -517,7 +559,7 @@ impl<'a> CoinSelector<'a> {
pub fn bnb_solutions<M: BnbMetric>(
&self,
metric: M,
) -> impl Iterator<Item = Option<(CoinSelector<'a>, Ordf32)>> {
) -> impl Iterator<Item = Option<(CoinSelector<'a, C>, Ordf32)>> {
crate::bnb::BnbIter::new(self.clone(), metric)
}

Expand Down Expand Up @@ -545,7 +587,7 @@ impl<'a> CoinSelector<'a> {
}
}

impl<'a> core::fmt::Display for CoinSelector<'a> {
impl<'a, C> core::fmt::Display for CoinSelector<'a, C> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "[")?;
let mut candidates = self.candidates().peekable();
Expand All @@ -572,12 +614,12 @@ impl<'a> core::fmt::Display for CoinSelector<'a> {
/// The `SelectIter` allows you to select candidates by calling [`Iterator::next`].
///
/// The [`Iterator::Item`] is a tuple of `(selector, last_selected_index, last_selected_candidate)`.
pub struct SelectIter<'a> {
cs: CoinSelector<'a>,
pub struct SelectIter<'a, C> {
cs: CoinSelector<'a, C>,
}

impl<'a> Iterator for SelectIter<'a> {
type Item = (CoinSelector<'a>, usize, Candidate);
impl<'a, C> Iterator for SelectIter<'a, C> {
type Item = (CoinSelector<'a, C>, usize, Candidate);

fn next(&mut self) -> Option<Self::Item> {
let (index, wv) = self.cs.unselected().next()?;
Expand All @@ -586,7 +628,7 @@ impl<'a> Iterator for SelectIter<'a> {
}
}

impl<'a> DoubleEndedIterator for SelectIter<'a> {
impl<'a, C> DoubleEndedIterator for SelectIter<'a, C> {
fn next_back(&mut self) -> Option<Self::Item> {
let (index, wv) = self.cs.unselected().next_back()?;
self.cs.select(index);
Expand Down Expand Up @@ -632,6 +674,18 @@ impl core::fmt::Display for NoBnbSolution {
#[cfg(feature = "std")]
impl std::error::Error for NoBnbSolution {}

/// Action to apply on a transaction.
///
/// This is used in [`CoinSelector::finish`] to populate a transaction with the current selection.
pub enum FinishAction<'a, C> {
/// Input to add to the transaction.
Input(&'a C),
/// Recipient output to add to the transaction.
TargetOutput(Target),
/// Drain (change) output to add to the transction.
DrainOutput(Drain),
}

/// A `Candidate` represents an input candidate for [`CoinSelector`].
///
/// This can either be a single UTXO, or a group of UTXOs that should be spent together.
Expand Down
18 changes: 11 additions & 7 deletions src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ pub use changeless::*;
//
// NOTE: this should stay private because it requires cs to be sorted such that all negative
// effective value candidates are next to each other.
fn change_lower_bound(cs: &CoinSelector, target: Target, change_policy: ChangePolicy) -> Drain {
fn change_lower_bound<C>(
cs: &CoinSelector<C>,
target: Target,
change_policy: ChangePolicy,
) -> Drain {
let has_change_now = cs.drain_value(target, change_policy).is_some();

if has_change_now {
Expand All @@ -38,7 +42,7 @@ macro_rules! impl_for_tuple {
where $($a: BnbMetric),*
{
#[allow(unused)]
fn score(&mut self, cs: &CoinSelector<'_>) -> Option<crate::float::Ordf32> {
fn score<C>(&mut self, cs: &CoinSelector<'_, C>) -> Option<crate::float::Ordf32> {
let mut acc = Option::<f32>::None;
for (score, ratio) in [$((self.$b.0.score(cs)?, self.$b.1)),*] {
let score: Ordf32 = score;
Expand All @@ -51,7 +55,7 @@ macro_rules! impl_for_tuple {
acc.map(Ordf32)
}
#[allow(unused)]
fn bound(&mut self, cs: &CoinSelector<'_>) -> Option<crate::float::Ordf32> {
fn bound<C>(&mut self, cs: &CoinSelector<'_, C>) -> Option<crate::float::Ordf32> {
let mut acc = Option::<f32>::None;
for (score, ratio) in [$((self.$b.0.bound(cs)?, self.$b.1)),*] {
let score: Ordf32 = score;
Expand All @@ -72,7 +76,7 @@ macro_rules! impl_for_tuple {
}

impl_for_tuple!();
impl_for_tuple!(A 0 B 1);
impl_for_tuple!(A 0 B 1 C 2);
impl_for_tuple!(A 0 B 1 C 2 D 3);
impl_for_tuple!(A 0 B 1 C 2 D 3 E 4);
impl_for_tuple!(TA 0 TB 1);
impl_for_tuple!(TA 0 TB 1 TC 2);
impl_for_tuple!(TA 0 TB 1 TC 2 TD 3);
impl_for_tuple!(TA 0 TB 1 TC 2 TD 3 TE 4);
4 changes: 2 additions & 2 deletions src/metrics/changeless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub struct Changeless {
}

impl BnbMetric for Changeless {
fn score(&mut self, cs: &CoinSelector<'_>) -> Option<Ordf32> {
fn score<C>(&mut self, cs: &CoinSelector<'_, C>) -> Option<Ordf32> {
if cs.is_target_met(self.target)
&& cs.drain_value(self.target, self.change_policy).is_none()
{
Expand All @@ -21,7 +21,7 @@ impl BnbMetric for Changeless {
}
}

fn bound(&mut self, cs: &CoinSelector<'_>) -> Option<Ordf32> {
fn bound<C>(&mut self, cs: &CoinSelector<'_, C>) -> Option<Ordf32> {
if change_lower_bound(cs, self.target, self.change_policy).is_some() {
None
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/metrics/lowest_fee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct LowestFee {
}

impl BnbMetric for LowestFee {
fn score(&mut self, cs: &CoinSelector<'_>) -> Option<Ordf32> {
fn score<C>(&mut self, cs: &CoinSelector<'_, C>) -> Option<Ordf32> {
if !cs.is_target_met(self.target) {
return None;
}
Expand All @@ -44,7 +44,7 @@ impl BnbMetric for LowestFee {
Some(Ordf32(long_term_fee as f32))
}

fn bound(&mut self, cs: &CoinSelector<'_>) -> Option<Ordf32> {
fn bound<C>(&mut self, cs: &CoinSelector<'_, C>) -> Option<Ordf32> {
if cs.is_target_met(self.target) {
let current_score = self.score(cs).unwrap();

Expand Down
Loading

0 comments on commit 5d59aca

Please sign in to comment.