Skip to content

Commit

Permalink
refactor: Inline used-once barycentric methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Oct 1, 2024
1 parent 50f24d1 commit 289c9fb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 49 deletions.
43 changes: 0 additions & 43 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use itertools::Itertools;
use ndarray::ArrayView1;
use num_traits::Zero;
use rayon::prelude::*;
use std::ops::Mul;
use twenty_first::math::traits::FiniteField;
use twenty_first::math::traits::PrimitiveRootOfUnity;
use twenty_first::prelude::*;

use crate::arithmetic_domain::ArithmeticDomain;
Expand Down Expand Up @@ -640,47 +638,6 @@ pub fn barycentric_evaluate<FF: FiniteField + Mul<XFieldElement, Output = XField
twenty_first::math::polynomial::barycentric_evaluate(codeword, indeterminate)
}

pub(crate) fn batch_barycentric_preprocess(
codeword_length: usize,
indeterminate: XFieldElement,
) -> (Vec<XFieldElement>, XFieldElement) {
let root_order = codeword_length.try_into().unwrap();
let generator = BFieldElement::primitive_root_of_unity(root_order).unwrap();
let domain_iter = (0..root_order).scan(bfe!(1), |acc, _| {
let to_yield = Some(*acc);
*acc *= generator;
to_yield
});

let domain_shift = domain_iter.clone().map(|d| indeterminate - d).collect();
let domain_shift_inverses = XFieldElement::batch_inversion(domain_shift);
let domain_over_domain_shift = domain_iter
.into_iter()
.zip(domain_shift_inverses)
.map(|(d, inv)| d * inv);
let denominator_inverse = domain_over_domain_shift
.clone()
.sum::<XFieldElement>()
.inverse();

(domain_over_domain_shift.collect_vec(), denominator_inverse)
}

pub(crate) fn batch_barycentric_evaluate<
FF: FiniteField + Mul<XFieldElement, Output = XFieldElement>,
>(
codeword: ArrayView1<FF>,
preprocessing_data: &(Vec<XFieldElement>, XFieldElement),
) -> XFieldElement {
let (domain_over_domain_shift, denominator_inverse) = preprocessing_data;
let numerator = domain_over_domain_shift
.iter()
.zip(codeword.iter())
.map(|(&dsi, &abscis)| abscis * dsi)
.sum::<XFieldElement>();
numerator * *denominator_inverse
}

#[cfg(test)]
mod tests {
use std::cmp::max;
Expand Down
31 changes: 25 additions & 6 deletions triton-vm/src/table/master_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ use crate::arithmetic_domain::ArithmeticDomain;
use crate::challenges::Challenges;
use crate::config::CacheDecision;
use crate::error::ProvingError;
use crate::fri::batch_barycentric_evaluate;
use crate::fri::batch_barycentric_preprocess;
use crate::ndarray_helper::fast_zeros_column_major;
use crate::ndarray_helper::horizontal_multi_slice_mut;
use crate::ndarray_helper::partial_sums;
Expand Down Expand Up @@ -319,13 +317,34 @@ where
XFieldElement::ZERO;
self.randomized_trace_table().ncols()
]);
let barycentric_preprocessing_data =
batch_barycentric_preprocess(self.randomized_trace_domain().length, indeterminate);

// The following is a batched version of barycentric Lagrangian evaluation.
// Since the method `barycentric_evaluate` is self-contained, not returning
// intermediate items necessary for batching, and since returning and reusing
// those indermediate items would produce a challenging interface, the relevant
// parts are reimplemented here.
let domain = self.randomized_trace_domain().domain_values();
let domain_shift = domain.iter().map(|&d| indeterminate - d).collect();
let domain_shift_inverses = XFieldElement::batch_inversion(domain_shift);
let domain_over_domain_shift = domain
.into_iter()
.zip(domain_shift_inverses)
.map(|(d, inv)| d * inv);
let denominator_inverse = domain_over_domain_shift
.clone()
.sum::<XFieldElement>()
.inverse();

Zip::from(ood_row.axis_iter_mut(Axis(0)))
.and(self.randomized_trace_table().axis_iter(Axis(1)))
.par_for_each(|v, codeword| {
let value = batch_barycentric_evaluate(codeword, &barycentric_preprocessing_data);
Array0::from_elem((), value).move_into(v);
let numerator = domain_over_domain_shift
.clone()
.zip(codeword)
.map(|(dsi, &abscis)| abscis * dsi)
.sum::<XFieldElement>();

Array0::from_elem((), numerator * denominator_inverse).move_into(v);
});
ood_row
}
Expand Down

0 comments on commit 289c9fb

Please sign in to comment.