Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid overflow when counting scalars (PROOF-906) #38

Merged
merged 4 commits into from
Sep 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/compute/fixed_msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ use ark_ec::short_weierstrass::Affine;
use rayon::prelude::*;
use std::marker::PhantomData;

fn count_scalars_per_output(scalars_len: usize, output_bit_table: &[u32]) -> u32 {
let bit_sum: usize = output_bit_table.iter().map(|s| *s as usize).sum();
let num_output_bytes = (bit_sum + 7) / 8;
assert!(scalars_len % num_output_bytes == 0);
(scalars_len / num_output_bytes).try_into().unwrap()
}

/// Handle to compute multi-scalar multiplications (MSMs) with pre-specified generators
///
/// # Example 1 - compute an MSM using the handle
Expand Down Expand Up @@ -96,10 +103,7 @@ impl<T: CurveId> MsmHandle<T> {
/// exponents for generator g_i with the output scalars packed contiguously and padded with zeros.
pub fn packed_msm(&self, res: &mut [T], output_bit_table: &[u32], scalars: &[u8]) {
let num_outputs = res.len() as u32;
let bit_sum: u32 = output_bit_table.iter().sum();
let num_output_bytes = (bit_sum + 7) / 8;
assert!(scalars.len() as u32 % num_output_bytes == 0);
let n = scalars.len() as u32 / num_output_bytes;
let n = count_scalars_per_output(scalars.len(), output_bit_table);
unsafe {
blitzar_sys::sxt_fixed_packed_multiexponentiation(
res.as_ptr() as *mut std::ffi::c_void,
Expand Down Expand Up @@ -235,3 +239,24 @@ impl<C: SwCurveConfig + Clone> SwMsmHandle for MsmHandle<ElementP2<C>> {
});
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn we_can_count_the_number_of_scalars_per_output() {
let output_bit_table = [1];
let n = count_scalars_per_output(1, &output_bit_table);
assert_eq!(n, 1);

let output_bit_table = [14, 2];
let n = count_scalars_per_output(10, &output_bit_table);
assert_eq!(n, 5);

// we handle cases that overflow
let output_bit_table = [u32::MAX, 1];
let n = count_scalars_per_output((u32::MAX as usize) + 1, &output_bit_table);
assert_eq!(n, 8);
}
}
Loading