Skip to content

Commit

Permalink
fix: avoid overflow when counting scalars (PROOF-906) (#38)
Browse files Browse the repository at this point in the history
* avoid overflow

* fix clippy

* drop dead code
  • Loading branch information
rnburn authored Sep 12, 2024
1 parent 01882f7 commit 05db1e4
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions src/compute/fixed_msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ use ark_ec::short_weierstrass::Affine;
use rayon::prelude::*;
use std::marker::PhantomData;

fn count_scalars_per_output(scalars_len: usize, output_bit_table: &[u32]) -> u32 {
let bit_sum: usize = output_bit_table.iter().map(|s| *s as usize).sum();
let num_output_bytes = (bit_sum + 7) / 8;
assert!(scalars_len % num_output_bytes == 0);
(scalars_len / num_output_bytes).try_into().unwrap()
}

/// Handle to compute multi-scalar multiplications (MSMs) with pre-specified generators
///
/// # Example 1 - compute an MSM using the handle
Expand Down Expand Up @@ -96,10 +103,7 @@ impl<T: CurveId> MsmHandle<T> {
/// exponents for generator g_i with the output scalars packed contiguously and padded with zeros.
pub fn packed_msm(&self, res: &mut [T], output_bit_table: &[u32], scalars: &[u8]) {
let num_outputs = res.len() as u32;
let bit_sum: u32 = output_bit_table.iter().sum();
let num_output_bytes = (bit_sum + 7) / 8;
assert!(scalars.len() as u32 % num_output_bytes == 0);
let n = scalars.len() as u32 / num_output_bytes;
let n = count_scalars_per_output(scalars.len(), output_bit_table);
unsafe {
blitzar_sys::sxt_fixed_packed_multiexponentiation(
res.as_ptr() as *mut std::ffi::c_void,
Expand Down Expand Up @@ -235,3 +239,24 @@ impl<C: SwCurveConfig + Clone> SwMsmHandle for MsmHandle<ElementP2<C>> {
});
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn we_can_count_the_number_of_scalars_per_output() {
let output_bit_table = [1];
let n = count_scalars_per_output(1, &output_bit_table);
assert_eq!(n, 1);

let output_bit_table = [14, 2];
let n = count_scalars_per_output(10, &output_bit_table);
assert_eq!(n, 5);

// we handle cases that overflow
let output_bit_table = [u32::MAX, 1];
let n = count_scalars_per_output((u32::MAX as usize) + 1, &output_bit_table);
assert_eq!(n, 8);
}
}

0 comments on commit 05db1e4

Please sign in to comment.