From 52c64632b937f38fb8a73b8fd578ccbf07ea8804 Mon Sep 17 00:00:00 2001 From: kevaundray Date: Mon, 23 Sep 2024 22:25:29 +0100 Subject: [PATCH] chore: Add a more aggressive caching/precompuation strategy for first five elements in CRS (#98) --- banderwagon/Cargo.toml | 8 + banderwagon/benches/benchmark.rs | 62 +++++++ banderwagon/src/lib.rs | 1 + banderwagon/src/msm_windowed_sign.rs | 241 +++++++++++++++++++++++++++ ipa-multipoint/src/committer.rs | 24 ++- 5 files changed, 333 insertions(+), 3 deletions(-) create mode 100644 banderwagon/benches/benchmark.rs create mode 100644 banderwagon/src/msm_windowed_sign.rs diff --git a/banderwagon/Cargo.toml b/banderwagon/Cargo.toml index 3faa7e5..6b6d004 100644 --- a/banderwagon/Cargo.toml +++ b/banderwagon/Cargo.toml @@ -11,9 +11,17 @@ ark-ff = { version = "^0.4.2", default-features = false } ark-ec = { version = "^0.4.2", default-features = false } ark-serialize = { version = "^0.4.2", default-features = false } rayon = "*" + [dev-dependencies] hex = "0.4.3" +criterion = "0.5.1" +rand = "0.8.4" +sha3 = "0.10.8" [features] default = ["parallel"] parallel = ["ark-ff/parallel", "ark-ff/asm", "ark-ec/parallel"] + +[[bench]] +name = "benchmark" +harness = false diff --git a/banderwagon/benches/benchmark.rs b/banderwagon/benches/benchmark.rs new file mode 100644 index 0000000..da3c437 --- /dev/null +++ b/banderwagon/benches/benchmark.rs @@ -0,0 +1,62 @@ +use banderwagon::{msm::MSMPrecompWnaf, msm_windowed_sign::MSMPrecompWindowSigned, Element, Fr}; +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::RngCore; + +pub fn msm_wnaf(c: &mut Criterion) { + const NUM_ELEMENTS: usize = 5; + + let bases = random_point(120, NUM_ELEMENTS); + let scalars = random_scalars(NUM_ELEMENTS, 16); + + let precomp = MSMPrecompWnaf::new(&bases, 12); + + c.bench_function(&format!("msm wnaf: {}", NUM_ELEMENTS), |b| { + b.iter(|| precomp.mul(&scalars)) + }); + + let precomp = MSMPrecompWindowSigned::new(&bases, 16); + c.bench_function(&format!("msm precomp 16: {}", NUM_ELEMENTS), |b| { + b.iter(|| precomp.mul(&scalars)) + }); +} + +pub fn keccak_32bytes(c: &mut Criterion) { + use rand::Rng; + use sha3::{Digest, Keccak256}; + + c.bench_function("keccak 64 bytes", |b| { + b.iter_with_setup( + // Setup function: generates new random data for each iteration + || { + let keccak = Keccak256::default(); + let mut rand_buffer = [0u8; 64]; + rand::thread_rng().fill(&mut rand_buffer); + (keccak, rand_buffer) + }, + |(mut keccak, rand_buffer)| { + keccak.update(&rand_buffer); + keccak.finalize() + }, + ) + }); +} + +fn random_point(seed: u64, num_points: usize) -> Vec { + (0..num_points) + .map(|i| Element::prime_subgroup_generator() * Fr::from((seed + i as u64 + 1) as u64)) + .collect() +} +fn random_scalars(num_points: usize, num_bytes: usize) -> Vec { + use ark_ff::PrimeField; + + (0..num_points) + .map(|_| { + let mut bytes = vec![0u8; num_bytes]; + rand::thread_rng().fill_bytes(&mut bytes[..]); + Fr::from_le_bytes_mod_order(&bytes) + }) + .collect() +} + +criterion_group!(benches, msm_wnaf, keccak_32bytes); +criterion_main!(benches); diff --git a/banderwagon/src/lib.rs b/banderwagon/src/lib.rs index eed146f..eaefd81 100644 --- a/banderwagon/src/lib.rs +++ b/banderwagon/src/lib.rs @@ -1,4 +1,5 @@ pub mod msm; +pub mod msm_windowed_sign; pub mod trait_impls; mod element; diff --git a/banderwagon/src/msm_windowed_sign.rs b/banderwagon/src/msm_windowed_sign.rs new file mode 100644 index 0000000..491fa39 --- /dev/null +++ b/banderwagon/src/msm_windowed_sign.rs @@ -0,0 +1,241 @@ +use crate::Element; +use ark_ec::CurveGroup; +use ark_ed_on_bls12_381_bandersnatch::{EdwardsAffine, EdwardsProjective, Fr}; +use ark_ff::Zero; +use ark_ff::{BigInteger, BigInteger256}; +use std::ops::Neg; + +#[derive(Debug, Clone)] +pub struct MSMPrecompWindowSigned { + tables: Vec>, + num_windows: usize, + window_size: usize, +} + +impl MSMPrecompWindowSigned { + pub fn new(bases: &[Element], window_size: usize) -> MSMPrecompWindowSigned { + use ark_ff::PrimeField; + + let number_of_windows = Fr::MODULUS_BIT_SIZE as usize / window_size + 1; + + let precomputed_points: Vec<_> = bases + .iter() + .map(|point| { + Self::precompute_points( + window_size, + number_of_windows, + EdwardsAffine::from(point.0), + ) + }) + .collect(); + + MSMPrecompWindowSigned { + window_size, + tables: precomputed_points, + num_windows: number_of_windows, + } + } + + fn precompute_points( + window_size: usize, + number_of_windows: usize, + point: EdwardsAffine, + ) -> Vec { + let window_size_scalar = Fr::from(1 << window_size); + use ark_ff::Field; + + use rayon::prelude::*; + + let all_tables: Vec<_> = (0..number_of_windows) + .into_par_iter() + .flat_map(|window_index| { + let window_scalar = window_size_scalar.pow([window_index as u64]); + let mut lookup_table = Vec::with_capacity(1 << (window_size - 1)); + let point = EdwardsProjective::from(point) * window_scalar; + let mut current = point; + // Compute and store multiples + for _ in 0..(1 << (window_size - 1)) { + lookup_table.push(current); + current += point; + } + EdwardsProjective::normalize_batch(&lookup_table) + }) + .collect(); + + all_tables + } + + pub fn mul(&self, scalars: &[Fr]) -> Element { + let scalars_bytes: Vec<_> = scalars + .iter() + .map(|a| { + let bigint: BigInteger256 = (*a).into(); + bigint.to_bytes_le() + }) + .collect(); + + let mut points_to_add = Vec::new(); + + for window_idx in 0..self.num_windows { + for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() { + let sub_table = &self.tables[scalar_idx]; + let point_idx = + get_booth_index(window_idx, self.window_size, scalar_bytes.as_ref()); + + if point_idx == 0 { + continue; + } + let sign = point_idx.is_positive(); + let point_idx = point_idx.unsigned_abs() as usize - 1; + + // Scale the point index by the window index to figure out whether + // we need P, 2^wP, 2^{2w}P, etc + let scaled_point_index = window_idx * (1 << (self.window_size - 1)) + point_idx; + let mut point = sub_table[scaled_point_index]; + + if !sign { + point = -point; + } + + points_to_add.push(point); + } + } + + let mut result = EdwardsProjective::zero(); + for point in points_to_add { + result += point; + } + + Element(result) + } +} + +// TODO: Link to halo2 file + docs + comments +pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { + // Booth encoding: + // * step by `window` size + // * slice by size of `window + 1`` + // * each window overlap by 1 bit + // * append a zero bit to the least significant end + // Indexing rule for example window size 3 where we slice by 4 bits: + // `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]`` + // So we can reduce the bucket size without preprocessing scalars + // and remembering them as in classic signed digit encoding + + let skip_bits = (window_index * window_size).saturating_sub(1); + let skip_bytes = skip_bits / 8; + + // fill into a u32 + let mut v: [u8; 4] = [0; 4]; + for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) { + *dst = *src + } + let mut tmp = u32::from_le_bytes(v); + + // pad with one 0 if slicing the least significant window + if window_index == 0 { + tmp <<= 1; + } + + // remove further bits + tmp >>= skip_bits - (skip_bytes * 8); + // apply the booth window + tmp &= (1 << (window_size + 1)) - 1; + + let sign = tmp & (1 << window_size) == 0; + + // div ceil by 2 + tmp = (tmp + 1) >> 1; + + // find the booth action index + if sign { + tmp as i32 + } else { + ((!(tmp - 1) & ((1 << window_size) - 1)) as i32).neg() + } +} + +#[test] +fn smoke_test_interop_strauss() { + use ark_ff::UniformRand; + + let length = 5; + let scalars: Vec<_> = (0..length) + .map(|_| Fr::rand(&mut rand::thread_rng())) + .collect(); + let points: Vec<_> = (0..length) + .map(|_| Element::prime_subgroup_generator() * Fr::rand(&mut rand::thread_rng())) + .collect(); + + let precomp = MSMPrecompWindowSigned::new(&points, 2); + let result = precomp.mul(&scalars); + + let mut expected = Element::zero(); + for (scalar, point) in scalars.into_iter().zip(points) { + expected += point * scalar + } + + assert_eq!(expected, result) +} + +#[cfg(test)] +mod booth_tests { + use std::ops::Neg; + + use ark_ed_on_bls12_381_bandersnatch::Fr; + use ark_ff::{BigInteger, BigInteger256, Field, PrimeField}; + + use super::get_booth_index; + use crate::Element; + + #[test] + fn smoke_scalar_mul() { + let gen = Element::prime_subgroup_generator(); + let s = -Fr::ONE; + + let res = gen * s; + + let got = mul(&s, &gen, 4); + + assert_eq!(Element::from(res), got) + } + + fn mul(scalar: &Fr, point: &Element, window: usize) -> Element { + let u_bigint: BigInteger256 = (*scalar).into(); + use ark_ff::Field; + let u = u_bigint.to_bytes_le(); + let n = Fr::MODULUS_BIT_SIZE as usize / window + 1; + + let table = (0..=1 << (window - 1)) + .map(|i| point * &Fr::from(i as u64)) + .collect::>(); + + let table_scalars = (0..=1 << (window - 1)) + .map(|i| Fr::from(i as u64)) + .collect::>(); + + let mut acc: Element = Element::zero(); + let mut acc_scalar = Fr::ZERO; + for i in (0..n).rev() { + for _ in 0..window { + acc = acc + acc; + acc_scalar = acc_scalar + acc_scalar; + } + + let idx = get_booth_index(i as usize, window, u.as_ref()); + + if idx.is_negative() { + acc += table[idx.unsigned_abs() as usize].neg(); + acc_scalar -= table_scalars[idx.unsigned_abs() as usize]; + } + if idx.is_positive() { + acc += table[idx.unsigned_abs() as usize]; + acc_scalar += table_scalars[idx.unsigned_abs() as usize]; + } + } + + assert_eq!(acc_scalar, *scalar); + + acc.into() + } +} diff --git a/ipa-multipoint/src/committer.rs b/ipa-multipoint/src/committer.rs index 968877e..3e4d2ef 100644 --- a/ipa-multipoint/src/committer.rs +++ b/ipa-multipoint/src/committer.rs @@ -1,4 +1,4 @@ -use banderwagon::{msm::MSMPrecompWnaf, Element, Fr}; +use banderwagon::{msm::MSMPrecompWnaf, msm_windowed_sign::MSMPrecompWindowSigned, Element, Fr}; // This is the functionality that commits to the branch nodes and computes the delta optimization // For consistency with the Pcs, ensure that this component uses the same CRS as the Pcs @@ -24,19 +24,31 @@ pub trait Committer { #[derive(Clone, Debug)] pub struct DefaultCommitter { + precomp_first_five: MSMPrecompWindowSigned, precomp: MSMPrecompWnaf, } impl DefaultCommitter { pub fn new(points: &[Element]) -> Self { + // Take the first five elements and use a more aggressive optimization strategy + // since they are used for computing storage keys. + + let (points_five, _) = points.split_at(5); + let precomp_first_five = MSMPrecompWindowSigned::new(points_five, 16); let precomp = MSMPrecompWnaf::new(points, 12); - Self { precomp } + Self { + precomp, + precomp_first_five, + } } } impl Committer for DefaultCommitter { fn commit_lagrange(&self, evaluations: &[Fr]) -> Element { + if evaluations.len() <= 5 { + return self.precomp_first_five.mul(evaluations); + } // Preliminary benchmarks indicate that the parallel version is faster // for vectors of length 64 or more if evaluations.len() >= 64 { @@ -47,6 +59,12 @@ impl Committer for DefaultCommitter { } fn scalar_mul(&self, value: Fr, lagrange_index: usize) -> Element { - self.precomp.mul_index(value, lagrange_index) + if lagrange_index < 5 { + let mut arr = [Fr::from(0u64); 5]; + arr[lagrange_index] = value; + self.precomp_first_five.mul(&arr) + } else { + self.precomp.mul_index(value, lagrange_index) + } } }