diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 02d06fb3..087ba46d 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -6,6 +6,7 @@ use itertools::izip; use itertools::Itertools; use ndarray::prelude::*; use ndarray::Zip; +use num_traits::ConstOne; use num_traits::ConstZero; use num_traits::Zero; use rayon::prelude::*; @@ -163,7 +164,6 @@ impl Stark { Self::compute_quotient_segments( &master_main_table, &master_aux_table, - fri.domain, quotient_domain, &challenges, "ient_combination_weights, @@ -399,38 +399,21 @@ impl Stark { } fn compute_quotient_segments( - master_main_table: &MasterMainTable, - master_aux_table: &MasterAuxTable, - fri_domain: ArithmeticDomain, + main_table: &MasterMainTable, + aux_table: &MasterAuxTable, quotient_domain: ArithmeticDomain, challenges: &Challenges, quotient_combination_weights: &[XFieldElement], ) -> (Array2, Array1>) { let (Some(main_quotient_domain_codewords), Some(aux_quotient_domain_codewords)) = ( - master_main_table.quotient_domain_table(), - master_aux_table.quotient_domain_table(), + main_table.quotient_domain_table(), + aux_table.quotient_domain_table(), ) else { profiler!(start "quotient calculation (just-in-time)"); - profiler!(start "interpolate" ("LDE")); - let randomized_main_columns = (0..MasterMainTable::NUM_COLUMNS) - .into_par_iter() - .map(|i| master_main_table.randomized_column_interpolant(i)) - .collect::>() - .into(); - let randomized_aux_columns = (0..MasterAuxTable::NUM_COLUMNS) - .into_par_iter() - .map(|i| master_aux_table.randomized_column_interpolant(i)) - .collect::>() - .into(); - profiler!(stop "interpolate"); - let (fri_domain_quotient_segment_codewords, quotient_segment_polynomials) = Self::compute_quotient_segments_with_jit_lde( - randomized_main_columns, - randomized_aux_columns, - master_main_table.trace_domain(), - master_main_table.randomized_trace_domain(), - master_main_table.fri_domain(), + main_table, + aux_table, challenges, quotient_combination_weights, ); @@ -446,7 +429,7 @@ impl Stark { let quotient_codeword = all_quotients_combined( main_quotient_domain_codewords, aux_quotient_domain_codewords, - master_main_table.trace_domain(), + main_table.trace_domain(), quotient_domain, challenges, quotient_combination_weights, @@ -458,6 +441,8 @@ impl Stark { profiler!(start "quotient LDE" ("LDE")); let quotient_segment_polynomials = Self::interpolate_quotient_segments(quotient_codeword, quotient_domain); + + let fri_domain = main_table.fri_domain(); let fri_domain_quotient_segment_codewords = Self::fri_domain_segment_polynomials(quotient_segment_polynomials.view(), fri_domain); profiler!(stop "quotient LDE"); @@ -947,59 +932,85 @@ impl Stark { /// there. The resulting coset-quotients are linearly recombined to produce the /// quotient segment codewords. fn compute_quotient_segments_with_jit_lde( - randomized_main_columns: Array1>, - randomized_aux_columns: Array1>, - trace_domain: ArithmeticDomain, - randomized_trace_domain: ArithmeticDomain, - fri_domain: ArithmeticDomain, + main_table: &MasterMainTable, + aux_table: &MasterAuxTable, challenges: &Challenges, quotient_combination_weights: &[XFieldElement], ) -> (Array2, Array1>) { - let num_rows = randomized_trace_domain.length; - let root_order = (num_rows * NUM_QUOTIENT_SEGMENTS).try_into().unwrap(); + // A factor to trade off memory use against compute time. The higher the factor, + // - the smaller the domain the trace polynomials are evaluated over, with the + // randomized trace domain at one extreme and 1 (?) at the other, and + // - the more cosets the trace polynomials are evaluated over, with + // `NUM_QUOTIENT_SEGMENTS` at one extreme and a multiple of that at the other. + const EVALUATION_DOMAIN_SIZE_REDUCTION: usize = 2; + const NUM_COSETS: usize = NUM_QUOTIENT_SEGMENTS * EVALUATION_DOMAIN_SIZE_REDUCTION; + + assert!(EVALUATION_DOMAIN_SIZE_REDUCTION.is_power_of_two()); + + let mut working_domain = main_table.randomized_trace_domain(); + for _ in 0..EVALUATION_DOMAIN_SIZE_REDUCTION.ilog2() { + working_domain = working_domain.halve().unwrap(); + } + let working_domain = working_domain; + let fri_domain = main_table.fri_domain(); + + let num_rows = working_domain.length; + let coset_root_order = (num_rows * NUM_COSETS).try_into().unwrap(); - // the powers of ι define `num_quotient_segments`-many cosets of the randomized trace domain - let iota = BFieldElement::primitive_root_of_unity(root_order) + // the powers of ι define `num_quotient_segments`-many cosets of the trace domain + let iota = BFieldElement::primitive_root_of_unity(coset_root_order) .expect("Cannot find ι, a primitive nth root of unity of the right order n."); - let domain = ArithmeticDomain::of_length(num_rows).unwrap(); + let psi = fri_domain.offset; // for every coset, evaluate constraints profiler!(start "zero-initialization"); // column majority (“`F`”) for contiguous column slices let mut quotient_multicoset_evaluations = - ndarray_helper::par_zeros((num_rows, NUM_QUOTIENT_SEGMENTS).f()); + ndarray_helper::par_zeros((num_rows, NUM_COSETS).f()); let mut main_columns = ndarray_helper::par_zeros((num_rows, MasterMainTable::NUM_COLUMNS).f()); let mut aux_columns = ndarray_helper::par_zeros((num_rows, MasterAuxTable::NUM_COLUMNS).f()); profiler!(stop "zero-initialization"); + profiler!(start "poly interpolate" ("LDE")); + let randomized_main_columns = (0..MasterMainTable::NUM_COLUMNS) + .into_par_iter() + .map(|i| main_table.randomized_column_interpolant(i)) + .collect(); + let randomized_aux_columns = (0..MasterAuxTable::NUM_COLUMNS) + .into_par_iter() + .map(|i| aux_table.randomized_column_interpolant(i)) + .collect(); + let randomized_main_columns = Array1::from_vec(randomized_main_columns); + let randomized_aux_columns = Array1::from_vec(randomized_aux_columns); + profiler!(stop "poly interpolate"); + profiler!(start "calculate quotients"); - for (coset_index, quotient_column) in (0..u64::try_from(NUM_QUOTIENT_SEGMENTS).unwrap()) - .zip(quotient_multicoset_evaluations.columns_mut()) + for (coset_index, quotient_column) in + (0..).zip(quotient_multicoset_evaluations.columns_mut()) { // always also offset by fri domain offset to avoid division-by-zero errors - let domain = domain.with_offset(iota.mod_pow(coset_index) * fri_domain.offset); - profiler!(start "evaluate" ("LDE")); - Zip::from(randomized_main_columns.axis_iter(Axis(0))) + let working_domain = working_domain.with_offset(iota.mod_pow(coset_index) * psi); + profiler!(start "poly evaluate" ("LDE")); + Zip::from(randomized_main_columns.view()) .and(main_columns.axis_iter_mut(Axis(1))) - .into_par_iter() - .for_each(|(trace_poly, target_column)| { - Array1::from(domain.evaluate(&trace_poly[[]])).move_into(target_column); + .for_each(|trace_poly, target_column| { + Array1::from(working_domain.evaluate(trace_poly)).move_into(target_column); }); - Zip::from(randomized_aux_columns.axis_iter(Axis(0))) + Zip::from(randomized_aux_columns.view()) .and(aux_columns.axis_iter_mut(Axis(1))) - .into_par_iter() - .for_each(|(trace_poly, target_column)| { - Array1::from(domain.evaluate(&trace_poly[[]])).move_into(target_column); + .for_each(|trace_poly, target_column| { + Array1::from(working_domain.evaluate(trace_poly)).move_into(target_column); }); - profiler!(stop "evaluate"); + profiler!(stop "poly evaluate"); + profiler!(start "AIR evaluation" ("AIR")); let all_quotients = all_quotients_combined( main_columns.view(), aux_columns.view(), - trace_domain, - domain, + main_table.trace_domain(), + working_domain, challenges, quotient_combination_weights, ); @@ -1009,11 +1020,10 @@ impl Stark { profiler!(stop "calculate quotients"); profiler!(start "segmentify"); - let segmentification = Self::segmentify( + let segmentification = Self::segmentify::( quotient_multicoset_evaluations, - fri_domain.offset, + psi, iota, - randomized_trace_domain, fri_domain, ); profiler!(stop "segmentify"); @@ -1022,128 +1032,212 @@ impl Stark { } /// Map a matrix whose columns represent the evaluation of a high-degree - /// polynomial on all cosets of the trace domain, to + /// polynomial f on all cosets (i.e., a full partition) of some domain, to /// 1. a matrix of segment codewords (on the FRI domain), and /// 2. an array of matching segment polynomials, /// /// such that the segment polynomials correspond to the interleaving split of /// the original high-degree polynomial. /// - /// For example, let f(X) have degree 2N where N is the trace domain length. - /// Then the input is an Nx2 matrix representing the values of f(X) on the trace - /// domain and its coset. The segment polynomials are f_E(X) and f_O(X) such - /// that f(X) = f_E(X^2) + X*f_O(X^2) and the segment codewords are their - /// evaluations on the FRI domain. + /// For example, let f(X) have degree M·N where N is the chosen domain's length. + /// Then the input is an N×M matrix representing the values of f(X) on the + /// chosen domain and its cosets: + /// + /// ```txt + /// ⎛ ⋮ ⋮ ⋮ ⎞ ┬ + /// ⎜ f(coset_0) … f(coset_M-1) ⎟ domain length + /// ⎝ ⋮ ⋮ ⋮ ⎠ ┴ + /// + /// ├───────── NUM_COSETS ────────┤ + /// ``` + /// + /// The `NUM_SEGMENTS` (=:`K`) produced segment polynomials are f_i(X) such that + /// f(X) = Σ_k x^k · f_k(X^K). + /// For example, for `K = 2`, this is f(X) = f_E(X²) + X·f_O(X²). + /// + /// The produced segment codewords are the segment polynomial's evaluations on + /// the FRI domain: + /// + /// ```txt + /// ⎛ ⋮ ⋮ ⋮ ⎞ ┬ + /// ⎜ f_0(FRI_dom) … f_K-1(FRI_dom) ⎟ FRI domain length + /// ⎝ ⋮ ⋮ ⋮ ⎠ ┴ + /// + /// ├────────── NUM_SEGMENTS ─────────┤ + /// ``` + // + // The mechanics of this function are backed by some serious maths. The main + // idea is based on the segmentation formula. For K segments, this is + // + // f(X) = Σ_{k=0}^{K-1} X^k · f_k(X^K) + // + // where each f_k is one segment. When substituting X for X·ξ^l, where ξ is a + // Kth root of unity (i.e., ξ^K = 1), this gives rise to K equations, where + // l ∈ { 0, …, K-1 }: + // + // f(X·ξ^0) = Σ_{k=0}^{K-1} (X·ξ^0)^k · f_k(X^K) + // ⋮ + // f(X·ξ^(K-1)) = Σ_{k=0}^{K-1} (X·ξ^(K-1))^k · f_k(X^K) + // + // Note how the indeterminates of the f_k are identical for all rows. That is, + // the mapping between f's evaluations on (the “right”) cosets and f's segments + // is a linear one. // - // This method is factored out from `compute_quotient_segments` for the purpose - // of testing. Conceptually, it belongs there. - fn segmentify( + // ⎛ ⋮ ⎞ ⎛ ⋮ ⎞ ⎛ ⋮ ⎞ + // ⎜ f(X·ξ^l) ⎟ = ⎜ … X^k · ξ^(k·l) … ⎟ · ⎜ f_k(X^K) ⎟ + // ⎝ ⋮ ⎠ ⎝ ⋮ ⎠ ⎝ ⋮ ⎠ + // + // This function works by applying the inversion of the coefficient matrix to + // the function's input, i.e., to the left hand side of above equation. + // Inverting the coefficient matrix efficiently benefits from additional + // observations. As a first step, the coefficient matrix is decomposed. + // Operator “∘” denotes the Hadamard, i.e., element-wise product. + // + // ⎛ ⋮ ⎞ ⎛ ⎛ ⋮ ⎞ ⎛ ⋮ ⎞ ⎞ + // = ⎜ … ξ^(k·l) … ⎟ · ⎜ ⎜ X^k ⎟ ∘ ⎜ f_k(X^K) ⎟ ⎟ + // ⎝ ⋮ ⎠ ⎝ ⎝ ⋮ ⎠ ⎝ ⋮ ⎠ ⎠ + // + // The coefficient matrix has dimensions K×K. Since ξ is a Kth root of unity, + // above matrix is an NTT matrix. That means its application can be efficiently + // reverted by performing iNTTs. + // The final step is elementwise multiplication with the vector (X^(-k)) to + // get the segment polynomials. + // + // For reasons of efficiency, this function does not operate on polynomials in + // monomial coefficient form, but on polynomial evaluations on some domain, + // i.e., codewords. + // Also for reasons of efficiency, the domain length N is a power of two, and + // the evaluation points are multiples of an Nth root of unity, ω. In order to + // avoid divisions by zero, the domain is offset by Ψ. Furthermore, the offset + // of a coset is some power of ι, which itself is a root of unity of order N·M, + // where M is the number of cosets. That is, ι^M = ω, and ω^N = 1. + // Summarizing, this function's input is a matrix of the following form: + // + // ⎛ ⋮ ⎞╷ ┬ ⎛ ⋮ ⎞ + // ⎜ … f(Ψ · ι^j · ω^i) … ⎟i N = ⎜ … f(Ψ · ι^(j + iM)) … ⎟ + // ⎝ ⋮ ⎠↓ ┴ ⎝ ⋮ ⎠ + // ╶─────────── j ─────────→ + // ├─────────── M ──────────┤ + // + // In order to kick off the series of steps derived & outlined above, this + // matrix needs to be rearranged. The desired shape can be derived by + // substituting the indeterminate X for the points at which f is evaluated, + // Ψ · ι^j · ω^i. Let L such that N·M = L·K. Observe that ξ being a Kth root of + // unity implies ξ = ω^(N/K) = ι^(N·M/K) = ι^L. + // + // ⎛ ⋮ ⎞ ⎛ ⋮ ⎞ ┬ + // ⎜ f(X·ξ^l) ⎟ ↦ ⎜ … f(ψ · ι^(j + i·M + l·L)) … ⎟ L + // ⎝ ⋮ ⎠ ⎝ ⋮ ⎠ ┴ + // + // ├────────────── K ──────────────┤ + // + // Helpful insights to understand the matrix re-arrangement are: + // - the various powers of ι, i.e., { ι^(j + i·M) | 0 ⩽ i < N, 0 ⩽ j < M }, + // sweep the entire input matrix (which has dimensions N×M) + // - ι is a (primitive) (N·M)th root of unity and thus, _all_ powers of ι are + // required to index the entirety of the input matrix + // - the re-arranged matrix (which has dimensions L×K) has all the same entries + // as the input matrix + // + // The map that satisfies all of these re-arrangement constraints is + // (i, j) ↦ ((j + i·M) % L, (j + i·M) // L) + // which has the inverse + // (a, b) ↦ ((a + b·L) // M, (a + b·L) % M). + // + // Even though this function conceptually belongs in + // `compute_quotient_segments_with_jit_lde`, it is factored out to simplify + // testing and to allow for abstraction in the conceptual parent function. + fn segmentify( quotient_multicoset_evaluations: Array2, psi: BFieldElement, iota: BFieldElement, - randomized_trace_domain: ArithmeticDomain, fri_domain: ArithmeticDomain, ) -> (Array2, Array1>) { - let num_rows = randomized_trace_domain.length; - let num_segments = quotient_multicoset_evaluations.ncols(); + let num_input_rows = quotient_multicoset_evaluations.nrows(); + let num_cosets = quotient_multicoset_evaluations.ncols(); + let num_output_rows = num_input_rows * num_cosets / NUM_SEGMENTS; + + assert!(num_input_rows.is_power_of_two()); + assert!(num_cosets.is_power_of_two()); + assert!(num_output_rows.is_power_of_two()); + assert!(NUM_SEGMENTS.is_power_of_two()); assert!( - num_rows > num_segments, - "trace domain length: {num_rows} versus num segments: {num_segments}", + num_input_rows >= num_cosets, + "working domain length: {num_input_rows} versus num cosets: {num_cosets}", + ); + assert!( + num_cosets >= NUM_SEGMENTS, + "num cosets: {num_cosets} versus num segments: {NUM_SEGMENTS}", ); - // Matrix `quotients` contains q(Ψ · ι^j · ω^i) in location (i,j) where ω is the - // trace domain generator, and where iota is an Fth root of ω such that ι^F = ω, - // where F is `num_quotient_segments`. So `quotients` contains q(Ψ · ι^(j+i·F)). - - // We need F-tuples from this matrix of elements separated by N/F rows. - let step_size = num_rows / num_segments; - let quotient_segments = (0..num_rows) + // Re-arrange data in preparation for iNTT: + // Move appropriate powers of ξ^(k·l) with the same k into the same row. Change + // the matrix's dimensions from N×M to L×K, with row majority (“`C`”) to get + // contiguous row slices for iNTT. + let mut quotient_segments = ndarray_helper::par_zeros((num_output_rows, NUM_SEGMENTS)); + quotient_segments + .axis_iter_mut(Axis(0)) .into_par_iter() - .flat_map(|jif| { - let col_idx = jif % num_segments; - let start_row = (jif - col_idx) / num_segments; - quotient_multicoset_evaluations - .slice(s![start_row..; step_size, col_idx]) - .to_vec() - }) - .collect(); - let mut quotient_segments = - Array2::from_shape_vec((num_rows, num_segments), quotient_segments).unwrap(); - - // Matrix `quotient_segments` now contains q(Ψ · ι^(j+i·F+l·N/F)) in cell - // (j+i·F, l). So *row* j+i·F contains {q(Ψ · ι^(j+i·F+l·N/F)) for l in [0..F-1]}. + .enumerate() + .for_each(|(output_row_idx, mut output_row)| { + for (output_col_idx, cell) in output_row.iter_mut().enumerate() { + let exponent_of_iota = output_row_idx + output_col_idx * num_output_rows; + let input_row_idx = exponent_of_iota / num_cosets; + let input_col_idx = exponent_of_iota % num_cosets; + *cell = quotient_multicoset_evaluations[[input_row_idx, input_col_idx]]; + } + }); - // apply inverse of Vandermonde matrix for ω^(N/F) matrix to every row - let n_over_f = (num_rows / num_segments).try_into().unwrap(); - let xi = randomized_trace_domain.generator.mod_pow_u32(n_over_f); - assert_eq!(bfe!(1), xi.mod_pow(num_segments.try_into().unwrap())); - let logn = num_segments.ilog2(); + // apply inverse of Vandermonde matrix for ξ = ι^L to every row + let xi = iota.mod_pow_u32(num_output_rows.try_into().unwrap()); + debug_assert_eq!(bfe!(1), xi.mod_pow(NUM_SEGMENTS.try_into().unwrap())); + let log2_num_segments = NUM_SEGMENTS.ilog2(); quotient_segments .axis_iter_mut(Axis(0)) .into_par_iter() - .for_each(|mut row| { - // `.unwrap()` is safe because `quotient_segments` is in row-major order - let row = row.as_slice_mut().unwrap(); - intt(row, xi, logn); - }); + .for_each(|mut row| intt(row.as_slice_mut().unwrap(), xi, log2_num_segments)); - // scale every row by Ψ^-k · ι^(-k(j+i·F)) + // scale every row by Ψ^-k · ι^(-k(j+i·M)) let num_threads = std::thread::available_parallelism() .map(|t| t.get()) .unwrap_or(1); - let chunk_size = (num_rows / num_threads).max(1); + let chunk_size = (num_output_rows / num_threads).max(1); let iota_inverse = iota.inverse(); let psi_inverse = psi.inverse(); quotient_segments .axis_chunks_iter_mut(Axis(0), chunk_size) .into_par_iter() .enumerate() - .for_each(|(thread, mut chunk)| { - let chunk_start = thread * chunk_size; - let mut psi_iotajif_inv = - psi_inverse * iota_inverse.mod_pow(chunk_start.try_into().unwrap()); - for mut row in chunk.rows_mut() { - let mut psi_iotajif_invk = xfe!(1); + .for_each(|(chunk_idx, mut chunk_of_rows)| { + let chunk_start = (chunk_idx * chunk_size).try_into().unwrap(); + let mut psi_iotajim_inv = psi_inverse * iota_inverse.mod_pow(chunk_start); + for mut row in chunk_of_rows.rows_mut() { + let mut psi_iotajim_invk = XFieldElement::ONE; for cell in &mut row { - *cell *= psi_iotajif_invk; - psi_iotajif_invk *= psi_iotajif_inv; + *cell *= psi_iotajim_invk; + psi_iotajim_invk *= psi_iotajim_inv; } - psi_iotajif_inv *= iota_inverse; + psi_iotajim_inv *= iota_inverse; } }); - // Matrix `quotients_codewords` contains q_k(Ψ^F · ω^(j+i·F)) in cell (j+i·F, k). - // To see this, observe that - // - // ⎛ … ⎞ ⎛ ⎛ … ⎞ ⎛ … ⎞ ⎞ - // ⎜ … ξ^(l·k) … ⎟ · ⎜ ⎜ ψ^k · ι^(j·k+i·k·F) ⎟ ∘ ⎜ q_k(ψ^F · ω^(j+i·F)) ⎟ ⎟ - // ⎝ … ⎠ ⎝ ⎝ … ⎠ ⎝ … ⎠ ⎠ - // = - // ⎛ … ⎞ ⎛ … ⎞ - // ⎜ … ψ^k · ι^(j·k+i·k·F+l·k·N/F) … ⎟ · ⎜ q_k(ψ^F · ω^(j+i·F)) ⎟ - // ⎝ … ⎠ ⎝ … ⎠ - // = - // ⎛ … ⎞ - // ⎜ q(ψ · ι^j · ω^(i + l · N/F)) ⎟ - // ⎝ … ⎠ - // low-degree extend columns from trace to FRI domain - let mut quotient_codewords = Array2::zeros([fri_domain.length, num_segments]); - let mut quotient_polynomials = Array1::zeros([num_segments]); + let segment_domain_offset = psi.mod_pow(NUM_SEGMENTS.try_into().unwrap()); + let segment_domain = ArithmeticDomain::of_length(num_output_rows) + .unwrap() + .with_offset(segment_domain_offset); + + let mut quotient_codewords = Array2::zeros([fri_domain.length, NUM_SEGMENTS]); + let mut quotient_polynomials = Array1::zeros([NUM_SEGMENTS]); Zip::from(quotient_segments.axis_iter(Axis(1))) .and(quotient_codewords.axis_iter_mut(Axis(1))) .and(quotient_polynomials.axis_iter_mut(Axis(0))) - .par_for_each(|segment, codeword, polynomial| { - let psi_exponent = num_segments.try_into().unwrap(); - let segment_domain_offset = psi.mod_pow(psi_exponent); - let segment_domain = randomized_trace_domain.with_offset(segment_domain_offset); - - // `.to_vec()` is necessary because `segment` is a column of `quotient_segments`, - // which is in row-major order + .par_for_each(|segment, target_codeword, target_polynomial| { + // `quotient_segments` is in row-major order, requiring `segment.to_vec()` let interpolant = segment_domain.interpolate(&segment.to_vec()); let lde_codeword = fri_domain.evaluate(&interpolant); - Array1::from(lde_codeword).move_into(codeword); - Array0::from_elem((), interpolant).move_into(polynomial); + Array1::from(lde_codeword).move_into(target_codeword); + Array0::from_elem((), interpolant).move_into(target_polynomial); }); (quotient_codewords, quotient_polynomials) @@ -1247,7 +1341,6 @@ pub(crate) mod tests { use isa::instruction::Instruction; use isa::op_stack::OpStackElement; use itertools::izip; - use num_traits::Zero; use proptest::collection::vec; use proptest::prelude::*; use proptest_arbitrary_interop::arb; @@ -1386,8 +1479,7 @@ pub(crate) mod tests { debug_assert!(main.fri_domain_table().is_none()); debug_assert!(aux.fri_domain_table().is_none()); - let jit_segments = - Stark::compute_quotient_segments(&main, &aux, fri_dom, quot_dom, &ch, &weights); + let jit_segments = Stark::compute_quotient_segments(&main, &aux, quot_dom, &ch, &weights); debug_assert!(main.fri_domain_table().is_none()); main.maybe_low_degree_extend_all_columns(); @@ -1397,8 +1489,7 @@ pub(crate) mod tests { aux.maybe_low_degree_extend_all_columns(); debug_assert!(aux.fri_domain_table().is_some()); - let cache_segments = - Stark::compute_quotient_segments(&main, &aux, fri_dom, quot_dom, &ch, &weights); + let cache_segments = Stark::compute_quotient_segments(&main, &aux, quot_dom, &ch, &weights); assert_eq!(jit_segments, cache_segments); } @@ -2486,157 +2577,49 @@ pub(crate) mod tests { assert_polynomial_equals_recomposed_segments(&f, &segments_7, x); } - #[proptest] - fn quotient_segments_of_old_and_new_methods_are_identical( - #[strategy(2_usize..8)] _log_trace_length: usize, - #[strategy(Just(1 << #_log_trace_length))] trace_length: usize, - #[strategy(Just(2 * #trace_length))] randomized_trace_length: usize, - #[strategy(arb())] - #[filter(!#offset.is_zero())] - offset: BFieldElement, - #[strategy(arb())] main_polynomials: [Polynomial; - MasterMainTable::NUM_COLUMNS], - #[strategy(arb())] aux_polynomials: [Polynomial; - MasterAuxTable::NUM_COLUMNS], - #[strategy(arb())] challenges: Challenges, - #[strategy(arb())] quotient_weights: [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS], - ) { - // truncate polynomials to randomized trace domain length many coefficients - fn truncate_coefficients( - mut polynomials: Vec>, - num_coefficients: usize, - ) -> Array1> { - polynomials - .par_iter_mut() - .for_each(|p| p.coefficients.truncate(num_coefficients)); - Array1::from_vec(polynomials) - } - - let main_polynomials = - truncate_coefficients(main_polynomials.to_vec(), randomized_trace_length); - let aux_polynomials = - truncate_coefficients(aux_polynomials.to_vec(), randomized_trace_length); - - let trace_domain = ArithmeticDomain::of_length(trace_length)?; - let randomized_trace_domain = ArithmeticDomain::of_length(randomized_trace_length)?; - let fri_domain = - ArithmeticDomain::of_length(4 * randomized_trace_length)?.with_offset(offset); - let quotient_domain = - ArithmeticDomain::of_length(4 * randomized_trace_length)?.with_offset(offset); - - let (quotient_segment_codewords_old, quotient_segment_polynomials_old) = - compute_quotient_segments_old( - main_polynomials.view(), - aux_polynomials.view(), - trace_domain, - quotient_domain, - fri_domain, - &challenges, - "ient_weights, - ); - - let (quotient_segment_codewords_new, quotient_segment_polynomials_new) = - Stark::compute_quotient_segments_with_jit_lde( - main_polynomials, - aux_polynomials, - trace_domain, - randomized_trace_domain, - fri_domain, - &challenges, - "ient_weights, - ); - - prop_assert_eq!( - quotient_segment_codewords_old, - quotient_segment_codewords_new - ); - prop_assert_eq!( - quotient_segment_polynomials_old, - quotient_segment_polynomials_new - ); - } - - fn compute_quotient_segments_old( - main_polynomials: ArrayView1>, - aux_polynomials: ArrayView1>, - trace_domain: ArithmeticDomain, - quotient_domain: ArithmeticDomain, - fri_domain: ArithmeticDomain, - challenges: &Challenges, - quotient_weights: &[XFieldElement], - ) -> (Array2, Array1>) { - let mut main_quotient_domain_codewords = - Array2::::zeros([quotient_domain.length, MasterMainTable::NUM_COLUMNS]); - Zip::from(main_quotient_domain_codewords.axis_iter_mut(Axis(1))) - .and(main_polynomials.axis_iter(Axis(0))) - .for_each(|codeword, polynomial| { - Array1::from_vec(quotient_domain.evaluate(&polynomial[()])).move_into(codeword); - }); - let mut aux_quotient_domain_codewords = - Array2::::zeros([quotient_domain.length, MasterAuxTable::NUM_COLUMNS]); - Zip::from(aux_quotient_domain_codewords.axis_iter_mut(Axis(1))) - .and(aux_polynomials.axis_iter(Axis(0))) - .for_each(|codeword, polynomial| { - Array1::from_vec(quotient_domain.evaluate(&polynomial[()])).move_into(codeword); - }); - - let quotient_codeword = all_quotients_combined( - main_quotient_domain_codewords.view(), - aux_quotient_domain_codewords.view(), - trace_domain, - quotient_domain, - challenges, - quotient_weights, - ); - let quotient_codeword = Array1::from(quotient_codeword); - let quotient_segment_polynomials = - Stark::interpolate_quotient_segments(quotient_codeword, fri_domain); - let quotient_segment_codewords = - Stark::fri_domain_segment_polynomials(quotient_segment_polynomials.view(), fri_domain); - - (quotient_segment_codewords, quotient_segment_polynomials) - } - #[proptest] fn polynomial_segments_cohere_with_originating_polynomial( #[strategy(2_usize..8)] log_trace_length: usize, - #[strategy(1_usize..#log_trace_length.min(3))] log_num_segments: usize, + #[strategy(2_usize..=#log_trace_length.min(4))] log_num_cosets: usize, #[strategy(1_usize..6)] log_expansion_factor: usize, - #[strategy(vec(arb(), (1 << #log_num_segments) * (1 << #log_trace_length)))] + #[strategy(vec(arb(), (1 << #log_num_cosets) * (1 << #log_trace_length)))] coefficients: Vec, #[strategy(arb())] random_point: XFieldElement, ) { - let polynomial = Polynomial::new(coefficients); + const NUM_SEGMENTS: usize = 4; - let num_segments = 1 << log_num_segments; + let num_cosets = 1 << log_num_cosets; let trace_length = 1 << log_trace_length; let expansion_factor = 1 << log_expansion_factor; + let polynomial = Polynomial::new(coefficients); let iota = - BFieldElement::primitive_root_of_unity((trace_length * num_segments) as u64).unwrap(); + BFieldElement::primitive_root_of_unity((trace_length * num_cosets) as u64).unwrap(); let psi = bfe!(7); - let trace_domain = ArithmeticDomain::of_length(trace_length).unwrap(); - let fri_domain = ArithmeticDomain::of_length(trace_length * expansion_factor) - .unwrap() - .with_offset(psi); + let trace_domain = ArithmeticDomain::of_length(trace_length)?; + let fri_domain = + ArithmeticDomain::of_length(trace_length * expansion_factor)?.with_offset(psi); - let multi_coset_values = (0..u32::try_from(num_segments).unwrap()) + let coset_evaluations = (0..u32::try_from(num_cosets)?) .flat_map(|i| { let coset = trace_domain.with_offset(iota.mod_pow_u32(i) * psi); coset.evaluate(&polynomial) }) - .collect_vec(); - let multi_coset_values = - Array2::from_shape_vec((trace_length, num_segments).f(), multi_coset_values).unwrap(); + .collect(); + let coset_evaluations = + Array2::from_shape_vec((trace_length, num_cosets).f(), coset_evaluations)?; let (actual_segment_codewords, segment_polynomials) = - Stark::segmentify(multi_coset_values, psi, iota, trace_domain, fri_domain); + Stark::segmentify::(coset_evaluations, psi, iota, fri_domain); + + assert_eq!(NUM_SEGMENTS, actual_segment_codewords.ncols()); + assert_eq!(NUM_SEGMENTS, segment_polynomials.len()); let segments_evaluated = (0..) .zip(&segment_polynomials) .map(|(segment_index, segment_polynomial)| -> XFieldElement { let point_to_the_seg_idx = random_point.mod_pow_u32(segment_index); - let point_to_the_num_seg = random_point.mod_pow_u32(num_segments as u32); + let point_to_the_num_seg = random_point.mod_pow_u32(NUM_SEGMENTS as u32); let segment_in_point_to_the_num_seg = segment_polynomial.evaluate_in_same_field(point_to_the_num_seg); point_to_the_seg_idx * segment_in_point_to_the_num_seg @@ -2650,8 +2633,7 @@ pub(crate) mod tests { .flat_map(|polynomial| Array1::from(fri_domain.evaluate(polynomial))) .collect_vec(); let segments_codewords = - Array2::from_shape_vec((fri_domain.length, num_segments).f(), segments_codewords) - .unwrap(); + Array2::from_shape_vec((fri_domain.length, NUM_SEGMENTS).f(), segments_codewords)?; prop_assert_eq!(segments_codewords, actual_segment_codewords); }