Skip to content

Commit

Permalink
refactor: Generalize ndarray helper par_zeros
Browse files Browse the repository at this point in the history
changelog: ignore
  • Loading branch information
jan-ferdinand committed Oct 14, 2024
1 parent d3da84a commit b4765f6
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
36 changes: 25 additions & 11 deletions triton-vm/src/ndarray_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ pub fn contiguous_column_slices(column_indices: &[usize]) -> Vec<usize> {
.concat()
}

pub fn fast_zeros_column_major<FF: Zero + Send + Sync + Copy>(
num_rows: usize,
num_columns: usize,
) -> Array2<FF> {
let mut array = Array2::uninit((num_rows, num_columns).f());
/// Faster than [`Array2::zeros`] through parallelism.
pub fn par_zeros<FF>(shape: impl ShapeBuilder<Dim = ndarray::Dim<[Ix; 2]>>) -> Array2<FF>
where
FF: Zero + Send + Sync + Copy,
{
let mut array = Array2::uninit(shape);
array.par_mapv_inplace(|_| MaybeUninit::new(FF::zero()));

unsafe {
Expand Down Expand Up @@ -184,30 +185,43 @@ mod test {
}

#[proptest]
fn fast_zeros_column_major_has_right_dimensions(
fn par_zeros_has_right_dimensions(
#[strategy(0usize..1000)] height: usize,
#[strategy(0usize..1000)] width: usize,
column_majority: bool,
) {
let matrix = fast_zeros_column_major::<XFieldElement>(height, width);
let shape = (height, width).set_f(column_majority);
let matrix = par_zeros::<XFieldElement>(shape);
prop_assert_eq!(height, matrix.nrows());
prop_assert_eq!(width, matrix.ncols());
}

#[proptest]
fn fast_zeros_column_major_is_not_standard_layout(
fn par_zeros_row_major_is_standard_layout(
#[strategy(2usize..1000)] height: usize,
#[strategy(2usize..1000)] width: usize,
) {
let matrix = fast_zeros_column_major::<XFieldElement>(height, width);
let matrix = par_zeros::<XFieldElement>((height, width));
prop_assert!(matrix.is_standard_layout());
}

#[proptest]
fn par_zeros_column_major_is_not_standard_layout(
#[strategy(2usize..1000)] height: usize,
#[strategy(2usize..1000)] width: usize,
) {
let matrix = par_zeros::<XFieldElement>((height, width).f());
prop_assert!(!matrix.is_standard_layout());
}

#[proptest]
fn fast_zeros_column_major_is_all_zeros(
fn par_zeros_is_all_zeros(
#[strategy(0usize..1000)] height: usize,
#[strategy(0usize..1000)] width: usize,
column_majority: bool,
) {
let matrix = fast_zeros_column_major::<XFieldElement>(height, width);
let shape = (height, width).set_f(column_majority);
let matrix = par_zeros::<XFieldElement>(shape);
prop_assert!(matrix.iter().all(|e| e.is_zero()));
}
}
11 changes: 7 additions & 4 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::error::ProvingError;
use crate::error::VerificationError;
use crate::fri;
use crate::fri::Fri;
use crate::ndarray_helper::fast_zeros_column_major;
use crate::ndarray_helper;
use crate::profiler::profiler;
use crate::proof::Claim;
use crate::proof::Proof;
Expand Down Expand Up @@ -965,10 +965,13 @@ impl Stark {

// for every coset, evaluate constraints
profiler!(start "zero-initialization");
// column majority (“`F`”) for contiguous column slices
let mut quotient_multicoset_evaluations =
fast_zeros_column_major(num_rows, NUM_QUOTIENT_SEGMENTS);
let mut main_columns = fast_zeros_column_major(num_rows, MasterMainTable::NUM_COLUMNS);
let mut aux_columns = fast_zeros_column_major(num_rows, MasterAuxTable::NUM_COLUMNS);
ndarray_helper::par_zeros((num_rows, NUM_QUOTIENT_SEGMENTS).f());
let mut main_columns =
ndarray_helper::par_zeros((num_rows, MasterMainTable::NUM_COLUMNS).f());
let mut aux_columns =
ndarray_helper::par_zeros((num_rows, MasterAuxTable::NUM_COLUMNS).f());
profiler!(stop "zero-initialization");

profiler!(start "calculate quotients");
Expand Down
12 changes: 8 additions & 4 deletions triton-vm/src/table/master_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ use crate::arithmetic_domain::ArithmeticDomain;
use crate::challenges::Challenges;
use crate::config::CacheDecision;
use crate::error::ProvingError;
use crate::ndarray_helper::fast_zeros_column_major;
use crate::ndarray_helper;
use crate::ndarray_helper::horizontal_multi_slice_mut;
use crate::ndarray_helper::partial_sums;
use crate::profiler::profiler;
Expand Down Expand Up @@ -826,7 +826,9 @@ impl MasterMainTable {
// For the current approach to trace randomizers to work, the randomized trace
// must be _exactly_ twice as long as the trace without trace randomizers.
let trace_domain = randomized_trace_domain.halve().unwrap();
let trace_table = fast_zeros_column_major(trace_domain.length, Self::NUM_COLUMNS);

// column majority (“`F`”) for contiguous column slices
let trace_table = ndarray_helper::par_zeros((trace_domain.length, Self::NUM_COLUMNS).f());

let mut master_main_table = Self {
num_trace_randomizers,
Expand Down Expand Up @@ -947,8 +949,10 @@ impl MasterMainTable {
/// adding some number of columns.
pub fn extend(&self, challenges: &Challenges) -> MasterAuxTable {
profiler!(start "initialize master table");
let mut aux_trace_table =
fast_zeros_column_major(self.trace_table().nrows(), MasterAuxTable::NUM_COLUMNS);
// column majority (“`F`”) for contiguous column slices
let mut aux_trace_table = ndarray_helper::par_zeros(
(self.trace_table().nrows(), MasterAuxTable::NUM_COLUMNS).f(),
);

let randomizers_start = MasterAuxTable::NUM_COLUMNS - NUM_RANDOMIZER_POLYNOMIALS;
aux_trace_table
Expand Down

0 comments on commit b4765f6

Please sign in to comment.