Skip to content

Commit

Permalink
Use PreProcessedColumn Trait Instead of Enum
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Jan 9, 2025
1 parent 6ba7258 commit 05e2eb6
Show file tree
Hide file tree
Showing 17 changed files with 152 additions and 124 deletions.
24 changes: 16 additions & 8 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::iter::zip;
use std::ops::Deref;
use std::rc::Rc;

use itertools::Itertools;
#[cfg(feature = "parallel")]
Expand Down Expand Up @@ -49,7 +50,7 @@ pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
/// Mapping of preprocessed columns to their index.
preprocessed_columns: HashMap<PreprocessedColumn, usize>,
preprocessed_columns: HashMap<Rc<dyn PreprocessedColumn>, usize>,
/// Controls whether the preprocessed columns are dynamic or static (default=Dynamic).
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode,
}
Expand Down Expand Up @@ -81,30 +82,37 @@ impl TraceLocationAllocator {
}

/// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup.
pub fn new_with_preproccessed_columns(preprocessed_columns: &[PreprocessedColumn]) -> Self {
pub fn new_with_preproccessed_columns(
preprocessed_columns: &[Rc<dyn PreprocessedColumn>],
) -> Self {
Self {
next_tree_offsets: Default::default(),
preprocessed_columns: preprocessed_columns
.iter()
.enumerate()
.map(|(i, &col)| (col, i))
.map(|(i, col)| (col.clone(), i))
.collect(),
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static,
}
}

pub const fn preprocessed_columns(&self) -> &HashMap<PreprocessedColumn, usize> {
pub const fn preprocessed_columns(&self) -> &HashMap<Rc<dyn PreprocessedColumn>, usize> {
&self.preprocessed_columns
}

// validates that `self.preprocessed_columns` is consistent with
// `preprocessed_columns`.
// I.e. preprocessed_columns[i] == self.preprocessed_columns[i].
pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[PreprocessedColumn]) {
pub fn validate_preprocessed_columns(
&self,
preprocessed_columns: &[Rc<dyn PreprocessedColumn>],
) {
assert_eq!(preprocessed_columns.len(), self.preprocessed_columns.len());

for (column, idx) in self.preprocessed_columns.iter() {
assert_eq!(Some(column), preprocessed_columns.get(*idx));
let preprocessed_column = preprocessed_columns
.get(*idx)
.expect("Preprocessed column is missing from preprocessed_columns");
assert_eq!(column.id(), preprocessed_column.id());
}
}
}
Expand Down Expand Up @@ -146,7 +154,7 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
let next_column = location_allocator.preprocessed_columns.len();
*location_allocator
.preprocessed_columns
.entry(*col)
.entry(col.clone())
.or_insert_with(|| {
if matches!(
location_allocator.preprocessed_columns_allocation_mode,
Expand Down
4 changes: 3 additions & 1 deletion crates/prover/src/constraint_framework/expr/evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::rc::Rc;

use num_traits::Zero;

use super::{BaseExpr, ExtExpr};
Expand Down Expand Up @@ -174,7 +176,7 @@ impl EvalAtRow for ExprEvaluator {
intermediate
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, column: Rc<dyn PreprocessedColumn>) -> Self::F {
BaseExpr::Param(column.name().to_string())
}

Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ use crate::core::pcs::TreeVec;
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<PreprocessedColumn>,
pub preprocessed_columns: Vec<Rc<dyn PreprocessedColumn>>,
pub logup: LogupAtRow<Self>,
pub arithmetic_counts: ArithmeticCounts,
}
impl InfoEvaluator {
pub fn new(
log_size: u32,
preprocessed_columns: Vec<PreprocessedColumn>,
preprocessed_columns: Vec<Rc<dyn PreprocessedColumn>>,
logup_sums: LogupSums,
) -> Self {
Self {
Expand Down Expand Up @@ -70,7 +70,7 @@ impl EvalAtRow for InfoEvaluator {
array::from_fn(|_| FieldCounter::one())
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, column: Rc<dyn PreprocessedColumn>) -> Self::F {
self.preprocessed_columns.push(column);
FieldCounter::one()
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub struct LogupAtRow<E: EvalAtRow> {
pub fracs: Vec<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::gen_is_first()`].
/// See [`super::preprocessed_columns::IsFirst`].
pub is_first: E::F,
pub log_size: u32,
}
Expand Down
9 changes: 5 additions & 4 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod simd_domain;
use std::array;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};
use std::rc::Rc;

pub use assert::{assert_constraints, AssertEvaluator};
pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
Expand Down Expand Up @@ -87,7 +88,7 @@ pub trait EvalAtRow {
mask_item
}

fn get_preprocessed_column(&mut self, _column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, _column: Rc<dyn PreprocessedColumn>) -> Self::F {
let [mask_item] = self.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
mask_item
}
Expand Down Expand Up @@ -172,11 +173,11 @@ macro_rules! logup_proxy {
() => {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst(
self.logup.is_first = self.get_preprocessed_column(std::rc::Rc::new(
crate::constraint_framework::preprocessed_columns::IsFirst::new(
self.logup.log_size,
),
);
));
self.logup.is_finalized = false;
}
self.logup.fracs.push(fraction.clone());
Expand Down
65 changes: 31 additions & 34 deletions crates/prover/src/constraint_framework/preprocessed_columns.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt::Debug;
use std::hash::Hash;
use std::rc::Rc;
use std::simd::Simd;

use num_traits::{One, Zero};
Expand All @@ -17,31 +18,29 @@ const SIMD_ENUMERATION_0: PackedM31 = unsafe {
]))
};

// TODO(Gali): Rename to PrerocessedColumn.
pub trait PreprocessedColumnTrait: Debug {
pub trait PreprocessedColumn: Debug {
fn name(&self) -> &'static str;
/// Used for hashing and comparison of preprocessed columns.
/// The id should be unique for each preprocessed column - one naive
/// implementation is: "PreProcessedColumnName(PreProcessedColumnsVariables)".
fn id(&self) -> String;
fn log_size(&self) -> u32;
}
impl PartialEq for dyn PreprocessedColumnTrait {
impl PartialEq for dyn PreprocessedColumn {
fn eq(&self, other: &Self) -> bool {
self.id() == other.id()
}
}
impl Eq for dyn PreprocessedColumnTrait {}
impl Hash for dyn PreprocessedColumnTrait {
impl Eq for dyn PreprocessedColumn {}
impl Hash for dyn PreprocessedColumn {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.id().hash(state);
}
}

// TODO(ilya): Where should this enum be placed?
// TODO(Gali): Add documentation for the rest of the variants.
// TODO(Gali): Remove Enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PreprocessedColumn {
pub enum PreprocessedColumnEnum {
/// A column with `1` at the first position, and `0` elsewhere.
IsFirst(u32),
Plonk(usize),
Expand All @@ -50,29 +49,29 @@ pub enum PreprocessedColumn {
XorTable(u32, u32, usize),
}

impl PreprocessedColumn {
impl PreprocessedColumnEnum {
pub const fn name(&self) -> &'static str {
match self {
PreprocessedColumn::IsFirst(_) => "preprocessed_is_first",
PreprocessedColumn::Plonk(_) => "preprocessed_plonk",
PreprocessedColumn::Seq(_) => "preprocessed_seq",
PreprocessedColumn::XorTable(..) => "preprocessed_xor_table",
PreprocessedColumnEnum::IsFirst(_) => "preprocessed_is_first",
PreprocessedColumnEnum::Plonk(_) => "preprocessed_plonk",
PreprocessedColumnEnum::Seq(_) => "preprocessed_seq",
PreprocessedColumnEnum::XorTable(..) => "preprocessed_xor_table",
}
}

pub fn log_size(&self) -> u32 {
match self {
PreprocessedColumn::IsFirst(log_size) => *log_size,
PreprocessedColumn::Seq(log_size) => *log_size,
PreprocessedColumn::XorTable(log_size, ..) => *log_size,
PreprocessedColumn::Plonk(_) => unimplemented!(),
PreprocessedColumnEnum::IsFirst(log_size) => *log_size,
PreprocessedColumnEnum::Seq(log_size) => *log_size,
PreprocessedColumnEnum::XorTable(log_size, ..) => *log_size,
PreprocessedColumnEnum::Plonk(_) => unimplemented!(),
}
}

/// Returns the values of the column at the given row.
pub fn packed_at(&self, vec_row: usize) -> PackedM31 {
match self {
PreprocessedColumn::IsFirst(log_size) => {
PreprocessedColumnEnum::IsFirst(log_size) => {
assert!(vec_row < (1 << log_size) / N_LANES);
if vec_row == 0 {
unsafe {
Expand All @@ -88,7 +87,7 @@ impl PreprocessedColumn {
PackedM31::zero()
}
}
PreprocessedColumn::Seq(log_size) => {
PreprocessedColumnEnum::Seq(log_size) => {
assert!(vec_row < (1 << log_size) / N_LANES);
PackedM31::broadcast(M31::from(vec_row * N_LANES)) + SIMD_ENUMERATION_0
}
Expand All @@ -99,14 +98,14 @@ impl PreprocessedColumn {

/// Generates a column according to the preprocessed column chosen.
pub fn gen_preprocessed_column<B: Backend>(
preprocessed_column: &PreprocessedColumn,
preprocessed_column: &PreprocessedColumnEnum,
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
match preprocessed_column {
PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size),
PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => {
PreprocessedColumnEnum::IsFirst(log_size) => gen_is_first(*log_size),
PreprocessedColumnEnum::Plonk(_) | PreprocessedColumnEnum::XorTable(..) => {
unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.")
}
PreprocessedColumn::Seq(log_size) => gen_seq(*log_size),
PreprocessedColumnEnum::Seq(log_size) => gen_seq(*log_size),
}
}
}
Expand Down Expand Up @@ -147,25 +146,23 @@ pub fn gen_seq<B: Backend>(log_size: u32) -> CircleEvaluation<B, BaseField, BitR
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}

pub fn gen_preprocessed_columns<'a, B: Backend>(
columns: impl Iterator<Item = &'a PreprocessedColumn>,
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>> {
columns
.map(PreprocessedColumn::gen_preprocessed_column)
.collect()
pub fn gen_preprocessed_columns<'a>(
columns: impl Iterator<Item = &'a Rc<dyn PreprocessedColumn>>,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
columns.map(|col| col.gen_column_simd()).collect()
}

#[cfg(test)]
mod tests {
use super::{IsFirst, PreprocessedColumn, Seq};
use crate::core::backend::simd::m31::N_LANES;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::fields::m31::{BaseField, M31};
const LOG_SIZE: u32 = 8;

#[test]
fn test_gen_seq() {
let seq = super::gen_seq::<SimdBackend>(LOG_SIZE);
let seq = Seq::new(LOG_SIZE).gen_column_simd();

for i in 0..(1 << LOG_SIZE) {
assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32));
Expand All @@ -175,8 +172,8 @@ mod tests {
// TODO(Gali): Add packed_at tests for xor_table and plonk.
#[test]
fn test_packed_at_is_first() {
let is_first = super::PreprocessedColumn::IsFirst(LOG_SIZE);
let expected_is_first = super::gen_is_first::<SimdBackend>(LOG_SIZE).to_cpu();
let is_first = IsFirst::new(LOG_SIZE);
let expected_is_first = is_first.gen_column_simd().to_cpu();

for i in 0..(1 << LOG_SIZE) / N_LANES {
assert_eq!(
Expand All @@ -188,7 +185,7 @@ mod tests {

#[test]
fn test_packed_at_seq() {
let seq = super::PreprocessedColumn::Seq(LOG_SIZE);
let seq = Seq::new(LOG_SIZE);
let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32));

let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| {
Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/constraint_framework/relation_tracker.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::fmt::Debug;
use std::rc::Rc;

use itertools::Itertools;
use num_traits::Zero;
Expand Down Expand Up @@ -146,7 +147,7 @@ impl EvalAtRow for RelationTrackerEvaluator<'_> {
})
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, column: Rc<dyn PreprocessedColumn>) -> Self::F {
column.packed_at(self.vec_row)
}

Expand Down
Loading

0 comments on commit 05e2eb6

Please sign in to comment.