From 8d41b00cbc744233d0e7bea2d0e1e217bbbcd5b7 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Wed, 8 Jan 2025 11:39:44 +0200 Subject: [PATCH] Use PreProcessedColumn Trait Instead of Enum --- .../src/constraint_framework/component.rs | 24 ++++-- .../constraint_framework/expr/evaluator.rs | 4 +- .../prover/src/constraint_framework/info.rs | 6 +- .../prover/src/constraint_framework/logup.rs | 2 +- crates/prover/src/constraint_framework/mod.rs | 9 ++- .../preprocessed_columns.rs | 75 +++++++++---------- .../constraint_framework/relation_tracker.rs | 3 +- crates/prover/src/examples/blake/air.rs | 73 +++++++++--------- crates/prover/src/examples/blake/round/mod.rs | 8 +- .../src/examples/blake/scheduler/mod.rs | 8 +- .../examples/blake/xor_table/constraints.rs | 12 +-- .../src/examples/blake/xor_table/gen.rs | 3 +- .../src/examples/blake/xor_table/mod.rs | 2 +- crates/prover/src/examples/plonk/mod.rs | 14 ++-- crates/prover/src/examples/poseidon/mod.rs | 12 ++- .../prover/src/examples/state_machine/mod.rs | 14 ++-- .../src/examples/xor/gkr_lookups/mle_eval.rs | 17 +++-- 17 files changed, 157 insertions(+), 129 deletions(-) diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 86f00609c..dee71d761 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::fmt::{self, Display, Formatter}; use std::iter::zip; use std::ops::Deref; +use std::rc::Rc; use itertools::Itertools; #[cfg(feature = "parallel")] @@ -49,7 +50,7 @@ pub struct TraceLocationAllocator { /// Mapping of tree index to next available column offset. next_tree_offsets: TreeVec, /// Mapping of preprocessed columns to their index. - preprocessed_columns: HashMap, + preprocessed_columns: HashMap, usize>, /// Controls whether the preprocessed columns are dynamic or static (default=Dynamic). preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode, } @@ -81,30 +82,37 @@ impl TraceLocationAllocator { } /// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup. - pub fn new_with_preproccessed_columns(preprocessed_columns: &[PreprocessedColumn]) -> Self { + pub fn new_with_preproccessed_columns( + preprocessed_columns: &[Rc], + ) -> Self { Self { next_tree_offsets: Default::default(), preprocessed_columns: preprocessed_columns .iter() .enumerate() - .map(|(i, &col)| (col, i)) + .map(|(i, col)| (col.clone(), i)) .collect(), preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static, } } - pub const fn preprocessed_columns(&self) -> &HashMap { + pub const fn preprocessed_columns(&self) -> &HashMap, usize> { &self.preprocessed_columns } // validates that `self.preprocessed_columns` is consistent with // `preprocessed_columns`. // I.e. preprocessed_columns[i] == self.preprocessed_columns[i]. - pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[PreprocessedColumn]) { + pub fn validate_preprocessed_columns( + &self, + preprocessed_columns: &[Rc], + ) { assert_eq!(preprocessed_columns.len(), self.preprocessed_columns.len()); - for (column, idx) in self.preprocessed_columns.iter() { - assert_eq!(Some(column), preprocessed_columns.get(*idx)); + let preprocessed_column = preprocessed_columns + .get(*idx) + .expect("Preprocessed column is missing from preprocessed_columns"); + assert_eq!(column.id(), preprocessed_column.id()); } } } @@ -146,7 +154,7 @@ impl FrameworkComponent { let next_column = location_allocator.preprocessed_columns.len(); *location_allocator .preprocessed_columns - .entry(*col) + .entry(col.clone()) .or_insert_with(|| { if matches!( location_allocator.preprocessed_columns_allocation_mode, diff --git a/crates/prover/src/constraint_framework/expr/evaluator.rs b/crates/prover/src/constraint_framework/expr/evaluator.rs index 6b9238d23..33605b7c7 100644 --- a/crates/prover/src/constraint_framework/expr/evaluator.rs +++ b/crates/prover/src/constraint_framework/expr/evaluator.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use num_traits::Zero; use super::{BaseExpr, ExtExpr}; @@ -174,7 +176,7 @@ impl EvalAtRow for ExprEvaluator { intermediate } - fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + fn get_preprocessed_column(&mut self, column: Rc) -> Self::F { BaseExpr::Param(column.name().to_string()) } diff --git a/crates/prover/src/constraint_framework/info.rs b/crates/prover/src/constraint_framework/info.rs index f8a6257e3..6c960ed46 100644 --- a/crates/prover/src/constraint_framework/info.rs +++ b/crates/prover/src/constraint_framework/info.rs @@ -22,14 +22,14 @@ use crate::core::pcs::TreeVec; pub struct InfoEvaluator { pub mask_offsets: TreeVec>>, pub n_constraints: usize, - pub preprocessed_columns: Vec, + pub preprocessed_columns: Vec>, pub logup: LogupAtRow, pub arithmetic_counts: ArithmeticCounts, } impl InfoEvaluator { pub fn new( log_size: u32, - preprocessed_columns: Vec, + preprocessed_columns: Vec>, logup_sums: LogupSums, ) -> Self { Self { @@ -70,7 +70,7 @@ impl EvalAtRow for InfoEvaluator { array::from_fn(|_| FieldCounter::one()) } - fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + fn get_preprocessed_column(&mut self, column: Rc) -> Self::F { self.preprocessed_columns.push(column); FieldCounter::one() } diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 370987e4c..014312995 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -52,7 +52,7 @@ pub struct LogupAtRow { pub fracs: Vec>, pub is_finalized: bool, /// The value of the `is_first` constant column at current row. - /// See [`super::preprocessed_columns::gen_is_first()`]. + /// See [`super::preprocessed_columns::IsFirst`]. pub is_first: E::F, pub log_size: u32, } diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 22809152d..900375c6c 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -13,6 +13,7 @@ mod simd_domain; use std::array; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; +use std::rc::Rc; pub use assert::{assert_constraints, AssertEvaluator}; pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator}; @@ -87,7 +88,7 @@ pub trait EvalAtRow { mask_item } - fn get_preprocessed_column(&mut self, _column: PreprocessedColumn) -> Self::F { + fn get_preprocessed_column(&mut self, _column: Rc) -> Self::F { let [mask_item] = self.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]); mask_item } @@ -172,11 +173,11 @@ macro_rules! logup_proxy { () => { fn write_logup_frac(&mut self, fraction: Fraction) { if self.logup.fracs.is_empty() { - self.logup.is_first = self.get_preprocessed_column( - crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst( + self.logup.is_first = self.get_preprocessed_column(std::rc::Rc::new( + crate::constraint_framework::preprocessed_columns::IsFirst::new( self.logup.log_size, ), - ); + )); self.logup.is_finalized = false; } self.logup.fracs.push(fraction.clone()); diff --git a/crates/prover/src/constraint_framework/preprocessed_columns.rs b/crates/prover/src/constraint_framework/preprocessed_columns.rs index 853b124b7..a6df63bf6 100644 --- a/crates/prover/src/constraint_framework/preprocessed_columns.rs +++ b/crates/prover/src/constraint_framework/preprocessed_columns.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; use std::hash::Hash; +use std::rc::Rc; use std::simd::Simd; use num_traits::{One, Zero}; @@ -18,8 +19,7 @@ const SIMD_ENUMERATION_0: PackedM31 = unsafe { ])) }; -// TODO(Gali): Rename to PrerocessedColumn. -pub trait PreprocessedColumnTrait: Debug { +pub trait PreprocessedColumn: Debug { fn name(&self) -> &'static str; fn id(&self) -> String; fn log_size(&self) -> u32; @@ -29,13 +29,13 @@ pub trait PreprocessedColumnTrait: Debug { /// Generates a column according to the preprocessed column chosen. fn gen_column_simd(&self) -> CircleEvaluation; } -impl PartialEq for dyn PreprocessedColumnTrait { +impl PartialEq for dyn PreprocessedColumn { fn eq(&self, other: &Self) -> bool { self.id() == other.id() } } -impl Eq for dyn PreprocessedColumnTrait {} -impl Hash for dyn PreprocessedColumnTrait { +impl Eq for dyn PreprocessedColumn {} +impl Hash for dyn PreprocessedColumn { fn hash(&self, state: &mut H) { self.id().hash(state); } @@ -51,7 +51,7 @@ impl IsFirst { Self { log_size } } } -impl PreprocessedColumnTrait for IsFirst { +impl PreprocessedColumn for IsFirst { fn name(&self) -> &'static str { "preprocessed_is_first" } @@ -97,7 +97,7 @@ impl Seq { Self { log_size } } } -impl PreprocessedColumnTrait for Seq { +impl PreprocessedColumn for Seq { fn name(&self) -> &'static str { "preprocessed_seq" } @@ -139,7 +139,7 @@ impl XorTable { } } } -impl PreprocessedColumnTrait for XorTable { +impl PreprocessedColumn for XorTable { fn name(&self) -> &'static str { "preprocessed_xor_table" } @@ -174,7 +174,7 @@ impl Plonk { Self { wire } } } -impl PreprocessedColumnTrait for Plonk { +impl PreprocessedColumn for Plonk { fn name(&self) -> &'static str { "preprocessed_plonk" } @@ -212,7 +212,7 @@ impl IsStepWithOffset { } } } -impl PreprocessedColumnTrait for IsStepWithOffset { +impl PreprocessedColumn for IsStepWithOffset { fn name(&self) -> &'static str { "preprocessed_is_step_with_offset" } @@ -251,10 +251,9 @@ impl PreprocessedColumnTrait for IsStepWithOffset { } } -// TODO(ilya): Where should this enum be placed? -// TODO(Gali): Add documentation for the rest of the variants. +// TODO(Gali): Remove Enum. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum PreprocessedColumn { +pub enum PreprocessedColumnEnum { /// A column with `1` at the first position, and `0` elsewhere. IsFirst(u32), Plonk(usize), @@ -263,29 +262,29 @@ pub enum PreprocessedColumn { XorTable(u32, u32, usize), } -impl PreprocessedColumn { +impl PreprocessedColumnEnum { pub const fn name(&self) -> &'static str { match self { - PreprocessedColumn::IsFirst(_) => "preprocessed_is_first", - PreprocessedColumn::Plonk(_) => "preprocessed_plonk", - PreprocessedColumn::Seq(_) => "preprocessed_seq", - PreprocessedColumn::XorTable(..) => "preprocessed_xor_table", + PreprocessedColumnEnum::IsFirst(_) => "preprocessed_is_first", + PreprocessedColumnEnum::Plonk(_) => "preprocessed_plonk", + PreprocessedColumnEnum::Seq(_) => "preprocessed_seq", + PreprocessedColumnEnum::XorTable(..) => "preprocessed_xor_table", } } pub fn log_size(&self) -> u32 { match self { - PreprocessedColumn::IsFirst(log_size) => *log_size, - PreprocessedColumn::Seq(log_size) => *log_size, - PreprocessedColumn::XorTable(log_size, ..) => *log_size, - PreprocessedColumn::Plonk(_) => unimplemented!(), + PreprocessedColumnEnum::IsFirst(log_size) => *log_size, + PreprocessedColumnEnum::Seq(log_size) => *log_size, + PreprocessedColumnEnum::XorTable(log_size, ..) => *log_size, + PreprocessedColumnEnum::Plonk(_) => unimplemented!(), } } /// Returns the values of the column at the given row. pub fn packed_at(&self, vec_row: usize) -> PackedM31 { match self { - PreprocessedColumn::IsFirst(log_size) => { + PreprocessedColumnEnum::IsFirst(log_size) => { assert!(vec_row < (1 << log_size) / N_LANES); if vec_row == 0 { unsafe { @@ -301,7 +300,7 @@ impl PreprocessedColumn { PackedM31::zero() } } - PreprocessedColumn::Seq(log_size) => { + PreprocessedColumnEnum::Seq(log_size) => { assert!(vec_row < (1 << log_size) / N_LANES); PackedM31::broadcast(M31::from(vec_row * N_LANES)) + SIMD_ENUMERATION_0 } @@ -312,14 +311,14 @@ impl PreprocessedColumn { /// Generates a column according to the preprocessed column chosen. pub fn gen_preprocessed_column( - preprocessed_column: &PreprocessedColumn, + preprocessed_column: &PreprocessedColumnEnum, ) -> CircleEvaluation { match preprocessed_column { - PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size), - PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => { + PreprocessedColumnEnum::IsFirst(log_size) => gen_is_first(*log_size), + PreprocessedColumnEnum::Plonk(_) | PreprocessedColumnEnum::XorTable(..) => { unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.") } - PreprocessedColumn::Seq(log_size) => gen_seq(*log_size), + PreprocessedColumnEnum::Seq(log_size) => gen_seq(*log_size), } } } @@ -360,25 +359,23 @@ pub fn gen_seq(log_size: u32) -> CircleEvaluation( - columns: impl Iterator, -) -> Vec> { - columns - .map(PreprocessedColumn::gen_preprocessed_column) - .collect() +pub fn gen_preprocessed_columns<'a>( + columns: impl Iterator>, +) -> Vec> { + columns.map(|col| col.gen_column_simd()).collect() } #[cfg(test)] mod tests { + use super::{IsFirst, PreprocessedColumn, Seq}; use crate::core::backend::simd::m31::N_LANES; - use crate::core::backend::simd::SimdBackend; use crate::core::backend::Column; use crate::core::fields::m31::{BaseField, M31}; const LOG_SIZE: u32 = 8; #[test] fn test_gen_seq() { - let seq = super::gen_seq::(LOG_SIZE); + let seq = Seq::new(LOG_SIZE).gen_column_simd(); for i in 0..(1 << LOG_SIZE) { assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32)); @@ -388,8 +385,8 @@ mod tests { // TODO(Gali): Add packed_at tests for xor_table and plonk. #[test] fn test_packed_at_is_first() { - let is_first = super::PreprocessedColumn::IsFirst(LOG_SIZE); - let expected_is_first = super::gen_is_first::(LOG_SIZE).to_cpu(); + let is_first = IsFirst::new(LOG_SIZE); + let expected_is_first = is_first.gen_column_simd().to_cpu(); for i in 0..(1 << LOG_SIZE) / N_LANES { assert_eq!( @@ -401,7 +398,7 @@ mod tests { #[test] fn test_packed_at_seq() { - let seq = super::PreprocessedColumn::Seq(LOG_SIZE); + let seq = Seq::new(LOG_SIZE); let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32)); let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| { diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index e4606e7bf..d89cb7420 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::fmt::Debug; +use std::rc::Rc; use itertools::Itertools; use num_traits::Zero; @@ -146,7 +147,7 @@ impl EvalAtRow for RelationTrackerEvaluator<'_> { }) } - fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + fn get_preprocessed_column(&mut self, column: Rc) -> Self::F { column.packed_at(self.vec_row) } diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 424e34f13..985813b71 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -1,3 +1,4 @@ +use std::rc::Rc; use std::simd::u32x16; use itertools::{chain, multiunzip, Itertools}; @@ -8,7 +9,7 @@ use tracing::{span, Level}; use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval}; use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval}; use super::xor_table::{xor12, xor4, xor7, xor8, xor9}; -use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn}; +use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn, XorTable}; use crate::constraint_framework::{TraceLocationAllocator, PREPROCESSED_TRACE_IDX}; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::m31::LOG_N_LANES; @@ -26,28 +27,30 @@ use crate::examples::blake::{ round, xor_table, BlakeXorElements, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, }; -const PREPROCESSED_XOR_COLUMNS: [PreprocessedColumn; 20] = [ - PreprocessedColumn::XorTable(12, 4, 0), - PreprocessedColumn::XorTable(12, 4, 1), - PreprocessedColumn::XorTable(12, 4, 2), - PreprocessedColumn::IsFirst(xor12::column_bits::<12, 4>()), - PreprocessedColumn::XorTable(9, 2, 0), - PreprocessedColumn::XorTable(9, 2, 1), - PreprocessedColumn::XorTable(9, 2, 2), - PreprocessedColumn::IsFirst(xor9::column_bits::<9, 2>()), - PreprocessedColumn::XorTable(8, 2, 0), - PreprocessedColumn::XorTable(8, 2, 1), - PreprocessedColumn::XorTable(8, 2, 2), - PreprocessedColumn::IsFirst(xor8::column_bits::<8, 2>()), - PreprocessedColumn::XorTable(7, 2, 0), - PreprocessedColumn::XorTable(7, 2, 1), - PreprocessedColumn::XorTable(7, 2, 2), - PreprocessedColumn::IsFirst(xor7::column_bits::<7, 2>()), - PreprocessedColumn::XorTable(4, 0, 0), - PreprocessedColumn::XorTable(4, 0, 1), - PreprocessedColumn::XorTable(4, 0, 2), - PreprocessedColumn::IsFirst(xor4::column_bits::<4, 0>()), -]; +fn preprocessed_xor_columns() -> [Rc; 20] { + [ + Rc::new(XorTable::new(12, 4, 0)), + Rc::new(XorTable::new(12, 4, 1)), + Rc::new(XorTable::new(12, 4, 2)), + Rc::new(IsFirst::new(xor12::column_bits::<12, 4>())), + Rc::new(XorTable::new(9, 2, 0)), + Rc::new(XorTable::new(9, 2, 1)), + Rc::new(XorTable::new(9, 2, 2)), + Rc::new(IsFirst::new(xor9::column_bits::<9, 2>())), + Rc::new(XorTable::new(8, 2, 0)), + Rc::new(XorTable::new(8, 2, 1)), + Rc::new(XorTable::new(8, 2, 2)), + Rc::new(IsFirst::new(xor8::column_bits::<8, 2>())), + Rc::new(XorTable::new(7, 2, 0)), + Rc::new(XorTable::new(7, 2, 1)), + Rc::new(XorTable::new(7, 2, 2)), + Rc::new(IsFirst::new(xor7::column_bits::<7, 2>())), + Rc::new(XorTable::new(4, 0, 0)), + Rc::new(XorTable::new(4, 0, 1)), + Rc::new(XorTable::new(4, 0, 2)), + Rc::new(IsFirst::new(xor4::column_bits::<4, 0>())), + ] +} #[derive(Serialize)] pub struct BlakeStatement0 { @@ -86,12 +89,7 @@ impl BlakeStatement0 { log_sizes[PREPROCESSED_TRACE_IDX] = chain!( [scheduler_is_first_column_log_size], blake_round_is_first_column_log_sizes, - PREPROCESSED_XOR_COLUMNS.map(|column| match column { - PreprocessedColumn::XorTable(elem_bits, expand_bits, _) => - 2 * (elem_bits - expand_bits), - PreprocessedColumn::IsFirst(log_size) => log_size, - _ => panic!("Unexpected column"), - }), + preprocessed_xor_columns().map(|column| column.log_size()), ) .collect_vec(); @@ -164,16 +162,17 @@ impl BlakeComponents { fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self { let log_size = stmt0.log_size; - let scheduler_is_first_column = PreprocessedColumn::IsFirst(log_size); - let blake_round_is_first_columns_iter = ROUND_LOG_SPLIT - .iter() - .map(|l| PreprocessedColumn::IsFirst(log_size + l)); + let scheduler_is_first_column: Rc = Rc::new(IsFirst::new(log_size)); + let blake_round_is_first_columns_iter = ROUND_LOG_SPLIT.iter().map(|l| { + let column: Rc = Rc::new(IsFirst::new(log_size + l)); + column + }); let tree_span_provider = &mut TraceLocationAllocator::new_with_preproccessed_columns( &chain!( [scheduler_is_first_column], blake_round_is_first_columns_iter, - PREPROCESSED_XOR_COLUMNS, + preprocessed_xor_columns(), ) .collect_vec()[..], ); @@ -322,8 +321,10 @@ where let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals( chain![ - vec![gen_is_first(log_size)], - ROUND_LOG_SPLIT.iter().map(|l| gen_is_first(log_size + l)), + vec![IsFirst::new(log_size).gen_column_simd()], + ROUND_LOG_SPLIT + .iter() + .map(|l| IsFirst::new(log_size + l).gen_column_simd()), xor_table::xor12::generate_constant_trace::<12, 4>(), xor_table::xor9::generate_constant_trace::<9, 2>(), xor_table::xor8::generate_constant_trace::<8, 2>(), diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index 4926fe218..20caad03b 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -56,7 +56,7 @@ mod tests { use itertools::Itertools; - use crate::constraint_framework::preprocessed_columns::gen_is_first; + use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn}; use crate::constraint_framework::FrameworkEval; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::round::r#gen::{ @@ -92,7 +92,11 @@ mod tests { &round_lookup_elements, ); - let trace = TreeVec::new(vec![vec![gen_is_first(LOG_SIZE)], trace, interaction_trace]); + let trace = TreeVec::new(vec![ + vec![IsFirst::new(LOG_SIZE).gen_column_simd()], + trace, + interaction_trace, + ]); let trace_polys = trace.map_cols(|c| c.interpolate()); let component = BlakeRoundEval { diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index c998ed61b..5150b418f 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -57,7 +57,7 @@ mod tests { use itertools::Itertools; - use crate::constraint_framework::preprocessed_columns::gen_is_first; + use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn}; use crate::constraint_framework::FrameworkEval; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::round::RoundElements; @@ -89,7 +89,11 @@ mod tests { &blake_lookup_elements, ); - let trace = TreeVec::new(vec![vec![gen_is_first(LOG_SIZE)], trace, interaction_trace]); + let trace = TreeVec::new(vec![ + vec![IsFirst::new(LOG_SIZE).gen_column_simd()], + trace, + interaction_trace, + ]); let trace_polys = trace.map_cols(|c| c.interpolate()); let component = BlakeSchedulerEval { diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index 60fef8bfe..0bbe097e7 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -18,27 +18,27 @@ macro_rules! xor_table_eval { // cl is the constant column for the xor: al ^ bl. let al = self .eval - .get_preprocessed_column(PreprocessedColumn::XorTable( + .get_preprocessed_column(std::rc::Rc::new(XorTable::new( ELEM_BITS, EXPAND_BITS, 0, - )); + ))); let bl = self .eval - .get_preprocessed_column(PreprocessedColumn::XorTable( + .get_preprocessed_column(std::rc::Rc::new(XorTable::new( ELEM_BITS, EXPAND_BITS, 1, - )); + ))); let cl = self .eval - .get_preprocessed_column(PreprocessedColumn::XorTable( + .get_preprocessed_column(std::rc::Rc::new(XorTable::new( ELEM_BITS, EXPAND_BITS, 2, - )); + ))); for i in (0..(1 << (2 * EXPAND_BITS))) { let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); diff --git a/crates/prover/src/examples/blake/xor_table/gen.rs b/crates/prover/src/examples/blake/xor_table/gen.rs index 55877d02e..519c80434 100644 --- a/crates/prover/src/examples/blake/xor_table/gen.rs +++ b/crates/prover/src/examples/blake/xor_table/gen.rs @@ -162,7 +162,8 @@ macro_rules! xor_table_gen { ) }) .to_vec(); - constant_trace.push(gen_is_first(column_bits::())); + constant_trace + .push(IsFirst::new(column_bits::()).gen_column_simd()); constant_trace } }; diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index e653e0bb9..66f11182e 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -20,7 +20,7 @@ use num_traits::Zero; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator}; -use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn}; +use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn, XorTable}; use crate::constraint_framework::{ relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, Relation, RelationEntry, INTERACTION_TRACE_IDX, PREPROCESSED_TRACE_IDX, diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index a1e0362c9..2f0af8693 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -1,9 +1,11 @@ +use std::rc::Rc; + use itertools::Itertools; use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn}; +use crate::constraint_framework::preprocessed_columns::{IsFirst, Plonk, PreprocessedColumn}; use crate::constraint_framework::{ assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, TraceLocationAllocator, @@ -49,12 +51,12 @@ impl FrameworkEval for PlonkEval { } fn evaluate(&self, mut eval: E) -> E { - let a_wire = eval.get_preprocessed_column(PreprocessedColumn::Plonk(0)); - let b_wire = eval.get_preprocessed_column(PreprocessedColumn::Plonk(1)); + let a_wire = eval.get_preprocessed_column(Rc::new(Plonk::new(0))); + let b_wire = eval.get_preprocessed_column(Rc::new(Plonk::new(1))); // Note: c_wire could also be implicit: (self.eval.point() - M31_CIRCLE_GEN.into_ef()).x. // A constant column is easier though. - let c_wire = eval.get_preprocessed_column(PreprocessedColumn::Plonk(2)); - let op = eval.get_preprocessed_column(PreprocessedColumn::Plonk(3)); + let c_wire = eval.get_preprocessed_column(Rc::new(Plonk::new(2))); + let op = eval.get_preprocessed_column(Rc::new(Plonk::new(3))); let mult = eval.next_trace_mask(); let a_val = eval.next_trace_mask(); @@ -195,7 +197,7 @@ pub fn prove_fibonacci_plonk( // Preprocessed trace. let span = span!(Level::INFO, "Constant").entered(); let mut tree_builder = commitment_scheme.tree_builder(); - let is_first = gen_is_first(log_n_rows); + let is_first = IsFirst::new(log_n_rows).gen_column_simd(); let mut constant_trace = [ circuit.a_wire.clone(), circuit.b_wire.clone(), diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 481d30fd6..4ea60e714 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -7,7 +7,7 @@ use num_traits::One; use tracing::{info, span, Level}; use crate::constraint_framework::logup::LogupTraceGenerator; -use crate::constraint_framework::preprocessed_columns::gen_is_first; +use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn}; use crate::constraint_framework::{ relation, EvalAtRow, FrameworkComponent, FrameworkEval, Relation, RelationEntry, TraceLocationAllocator, @@ -349,7 +349,7 @@ pub fn prove_poseidon( // Preprocessed trace. let span = span!(Level::INFO, "Constant").entered(); let mut tree_builder = commitment_scheme.tree_builder(); - let constant_trace = vec![gen_is_first(log_n_rows)]; + let constant_trace = vec![IsFirst::new(log_n_rows).gen_column_simd()]; tree_builder.extend_evals(constant_trace); tree_builder.commit(channel); span.exit(); @@ -397,7 +397,7 @@ mod tests { use num_traits::One; use crate::constraint_framework::assert_constraints; - use crate::constraint_framework::preprocessed_columns::gen_is_first; + use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn}; use crate::core::air::Component; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; @@ -472,7 +472,11 @@ mod tests { let (trace1, total_sum) = gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements); - let traces = TreeVec::new(vec![vec![gen_is_first(LOG_N_ROWS)], trace0, trace1]); + let traces = TreeVec::new(vec![ + vec![IsFirst::new(LOG_N_ROWS).gen_column_simd()], + trace0, + trace1, + ]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); assert_constraints( diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 5de104188..6b4f91614 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use crate::constraint_framework::relation_tracker::RelationSummary; use crate::constraint_framework::Relation; pub mod components; @@ -12,7 +14,7 @@ use gen::{gen_interaction_trace, gen_trace}; use itertools::{chain, Itertools}; use crate::constraint_framework::preprocessed_columns::{ - gen_preprocessed_columns, PreprocessedColumn, + gen_preprocessed_columns, IsFirst, PreprocessedColumn, }; use crate::constraint_framework::TraceLocationAllocator; use crate::core::backend::simd::m31::LOG_N_LANES; @@ -59,9 +61,9 @@ pub fn prove_state_machine( let mut commitment_scheme = CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); - let preprocessed_columns = [ - PreprocessedColumn::IsFirst(x_axis_log_rows), - PreprocessedColumn::IsFirst(y_axis_log_rows), + let preprocessed_columns: [Rc; 2] = [ + Rc::new(IsFirst::new(x_axis_log_rows)), + Rc::new(IsFirst::new(y_axis_log_rows)), ]; // Preprocessed trace. @@ -209,7 +211,7 @@ mod tests { use super::gen::{gen_interaction_trace, gen_trace}; use super::{prove_state_machine, verify_state_machine}; use crate::constraint_framework::expr::ExprEvaluator; - use crate::constraint_framework::preprocessed_columns::gen_is_first; + use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn}; use crate::constraint_framework::{ assert_constraints, FrameworkEval, Relation, TraceLocationAllocator, }; @@ -245,7 +247,7 @@ mod tests { ); let trace = TreeVec::new(vec![ - vec![gen_is_first(log_n_rows)], + vec![IsFirst::new(log_n_rows).gen_column_simd()], trace, interaction_trace, ]); diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 3c0c17d0c..e1ec5dd1b 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -8,7 +8,7 @@ use itertools::{chain, zip_eq, Itertools}; use num_traits::{One, Zero}; use tracing::{span, Level}; -use crate::constraint_framework::preprocessed_columns::gen_is_first; +use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn}; use crate::constraint_framework::{ EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, TraceLocationAllocator, }; @@ -210,7 +210,8 @@ impl ComponentProver for MleEvalProverComp .interpolate_with_twiddles(self.twiddles) .evaluate_with_twiddles(eval_domain, self.twiddles) .into_coordinate_evals(); - let is_first_lde = gen_is_first::(self.log_size()) + let is_first_lde = IsFirst::new(self.log_size()) + .gen_column_simd() .interpolate_with_twiddles(self.twiddles) .evaluate_with_twiddles(eval_domain, self.twiddles); let aux_interaction = component_trace.len(); @@ -745,7 +746,7 @@ mod tests { MleEvalVerifierComponent, }; use crate::constraint_framework::preprocessed_columns::{ - gen_is_first, gen_is_step_with_offset, + IsFirst, IsStepWithOffset, PreprocessedColumn, }; use crate::constraint_framework::{assert_constraints, EvalAtRow, TraceLocationAllocator}; use crate::core::air::{Component, ComponentProver, Components}; @@ -950,7 +951,7 @@ mod tests { let mle_coeffs_col_trace = mle_coeff_column::build_trace(&mle); let claim_shift = claim / BaseField::from(size); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); - let is_first_col = [gen_is_first(log_size)]; + let is_first_col = [IsFirst::new(log_size).gen_column_simd()]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![mle_coeffs_col_trace, mle_eval_trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); @@ -992,7 +993,7 @@ mod tests { let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); - let is_first_col = [gen_is_first(N_VARIABLES as u32)]; + let is_first_col = [IsFirst::new(N_VARIABLES as u32).gen_column_simd()]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); @@ -1029,7 +1030,7 @@ mod tests { let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); - let is_first_col = [gen_is_first(N_VARIABLES as u32)]; + let is_first_col = [IsFirst::new(N_VARIABLES as u32).gen_column_simd()]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); @@ -1066,7 +1067,7 @@ mod tests { let mle_eval_point = MleEvalPoint::new(&eval_point); let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); - let is_first_col = [gen_is_first(N_VARIABLES as u32)]; + let is_first_col = [IsFirst::new(N_VARIABLES as u32).gen_column_simd()]; let aux_trace = chain![carry_quotients_col, is_first_col].collect(); let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); @@ -1120,7 +1121,7 @@ mod tests { const OFFSET: usize = 1; const LOG_STEP: u32 = 2; let coset = CanonicCoset::new(LOG_SIZE).coset(); - let col_eval = gen_is_step_with_offset::(LOG_SIZE, LOG_STEP, OFFSET); + let col_eval = IsStepWithOffset::new(LOG_SIZE, LOG_STEP, OFFSET).gen_column_simd(); let col_poly = col_eval.interpolate(); let p = SECURE_FIELD_CIRCLE_GEN;