From 05db1e487f20cfb774beb9337bebad8f62d70357 Mon Sep 17 00:00:00 2001 From: Ryan Burn Date: Thu, 12 Sep 2024 09:07:28 -0700 Subject: [PATCH] fix: avoid overflow when counting scalars (PROOF-906) (#38) * avoid overflow * fix clippy * drop dead code --- src/compute/fixed_msm.rs | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/src/compute/fixed_msm.rs b/src/compute/fixed_msm.rs index 744c04c..dc25c59 100644 --- a/src/compute/fixed_msm.rs +++ b/src/compute/fixed_msm.rs @@ -4,6 +4,13 @@ use ark_ec::short_weierstrass::Affine; use rayon::prelude::*; use std::marker::PhantomData; +fn count_scalars_per_output(scalars_len: usize, output_bit_table: &[u32]) -> u32 { + let bit_sum: usize = output_bit_table.iter().map(|s| *s as usize).sum(); + let num_output_bytes = (bit_sum + 7) / 8; + assert!(scalars_len % num_output_bytes == 0); + (scalars_len / num_output_bytes).try_into().unwrap() +} + /// Handle to compute multi-scalar multiplications (MSMs) with pre-specified generators /// /// # Example 1 - compute an MSM using the handle @@ -96,10 +103,7 @@ impl MsmHandle { /// exponents for generator g_i with the output scalars packed contiguously and padded with zeros. pub fn packed_msm(&self, res: &mut [T], output_bit_table: &[u32], scalars: &[u8]) { let num_outputs = res.len() as u32; - let bit_sum: u32 = output_bit_table.iter().sum(); - let num_output_bytes = (bit_sum + 7) / 8; - assert!(scalars.len() as u32 % num_output_bytes == 0); - let n = scalars.len() as u32 / num_output_bytes; + let n = count_scalars_per_output(scalars.len(), output_bit_table); unsafe { blitzar_sys::sxt_fixed_packed_multiexponentiation( res.as_ptr() as *mut std::ffi::c_void, @@ -235,3 +239,24 @@ impl SwMsmHandle for MsmHandle> { }); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn we_can_count_the_number_of_scalars_per_output() { + let output_bit_table = [1]; + let n = count_scalars_per_output(1, &output_bit_table); + assert_eq!(n, 1); + + let output_bit_table = [14, 2]; + let n = count_scalars_per_output(10, &output_bit_table); + assert_eq!(n, 5); + + // we handle cases that overflow + let output_bit_table = [u32::MAX, 1]; + let n = count_scalars_per_output((u32::MAX as usize) + 1, &output_bit_table); + assert_eq!(n, 8); + } +}