From 28a4a608495f90a7320f6135956765782f6aed3d Mon Sep 17 00:00:00 2001 From: Zach Langley Date: Tue, 24 Sep 2024 14:05:18 -0400 Subject: [PATCH] refactor: Use IsEqual instead of IsEqualVec in core air --- primitives/src/is_equal/columns.rs | 12 +++++++++--- vm/src/core/air.rs | 19 +++++++------------ vm/src/core/columns.rs | 30 +++++++++++++----------------- vm/src/core/execute.rs | 16 ++++++++-------- vm/src/core/mod.rs | 3 --- 5 files changed, 37 insertions(+), 43 deletions(-) diff --git a/primitives/src/is_equal/columns.rs b/primitives/src/is_equal/columns.rs index bf77a7ae99..ca7143c0ad 100644 --- a/primitives/src/is_equal/columns.rs +++ b/primitives/src/is_equal/columns.rs @@ -3,20 +3,22 @@ use afs_derive::AlignedBorrow; pub const NUM_COLS: usize = 4; #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Clone, Debug, PartialEq, Eq)] pub struct IsEqualCols { pub io: IsEqualIoCols, pub aux: IsEqualAuxCols, } -#[derive(Clone, Copy)] +#[repr(C)] +#[derive(AlignedBorrow, Clone, Debug, PartialEq, Eq)] pub struct IsEqualIoCols { pub x: T, pub y: T, pub is_equal: T, } -#[derive(Debug, Clone)] +#[repr(C)] +#[derive(AlignedBorrow, Clone, Debug, PartialEq, Eq)] pub struct IsEqualAuxCols { pub inv: T, } @@ -31,6 +33,10 @@ impl IsEqualAuxCols { pub fn flatten(&self) -> Vec { vec![self.inv.clone()] } + + pub fn width() -> usize { + 1 + } } impl IsEqualCols { pub const fn new(x: T, y: T, is_equal: T, inv: T) -> IsEqualCols { diff --git a/vm/src/core/air.rs b/vm/src/core/air.rs index 566e01fb17..7b25b0ea22 100644 --- a/vm/src/core/air.rs +++ b/vm/src/core/air.rs @@ -1,7 +1,7 @@ use std::borrow::Borrow; use afs_primitives::{ - is_equal_vec::{columns::IsEqualVecIoCols, IsEqualVecAir}, + is_equal::{columns::IsEqualIoCols, IsEqualAir}, sub_chip::SubAir, }; use afs_stark_backend::interaction::InteractionBuilder; @@ -12,7 +12,7 @@ use p3_matrix::Matrix; use super::{ columns::{CoreAuxCols, CoreCols, CoreIoCols}, - CoreOptions, INST_WIDTH, WORD_SIZE, + CoreOptions, INST_WIDTH, }; use crate::{ arch::{bridge::ExecutionBridge, instructions::Opcode::*}, @@ -68,7 +68,7 @@ impl Air for CoreAir { reads, writes, read0_equals_read1, - is_equal_vec_aux, + is_equal_aux, reads_aux_cols, writes_aux_cols, next_pc, @@ -335,17 +335,12 @@ impl Air for CoreAir { // evaluate equality between read1 and read2 - let is_equal_vec_io_cols = IsEqualVecIoCols { - x: vec![read1.value], - y: vec![read2.value], + let is_equal_io_cols = IsEqualIoCols { + x: read1.value, + y: read2.value, is_equal: read0_equals_read1, }; - SubAir::eval( - &IsEqualVecAir::new(WORD_SIZE), - builder, - is_equal_vec_io_cols, - is_equal_vec_aux, - ); + SubAir::eval(&IsEqualAir, builder, is_equal_io_cols, is_equal_aux); // make sure program terminates or shards with NOP builder.when_last_row().assert_zero( diff --git a/vm/src/core/columns.rs b/vm/src/core/columns.rs index f9618bd84b..0193809e54 100644 --- a/vm/src/core/columns.rs +++ b/vm/src/core/columns.rs @@ -1,15 +1,13 @@ use std::{array, collections::BTreeMap}; use afs_primitives::{ - is_equal_vec::{columns::IsEqualVecAuxCols, IsEqualVecAir}, + is_equal::{columns::IsEqualAuxCols, IsEqualAir}, sub_chip::LocalTraceInstructions, }; use itertools::Itertools; use p3_field::{Field, PrimeField32}; -use super::{ - CoreAir, CoreChip, Opcode, CORE_MAX_READS_PER_CYCLE, CORE_MAX_WRITES_PER_CYCLE, WORD_SIZE, -}; +use super::{CoreAir, CoreChip, Opcode, CORE_MAX_READS_PER_CYCLE, CORE_MAX_WRITES_PER_CYCLE}; use crate::{ arch::instructions::CORE_INSTRUCTIONS, memory::{ @@ -146,7 +144,7 @@ pub struct CoreAuxCols { pub reads: [CoreMemoryAccessCols; CORE_MAX_READS_PER_CYCLE], pub writes: [CoreMemoryAccessCols; CORE_MAX_WRITES_PER_CYCLE], pub read0_equals_read1: T, - pub is_equal_vec_aux: IsEqualVecAuxCols, + pub is_equal_aux: IsEqualAuxCols, pub reads_aux_cols: [MemoryReadOrImmediateAuxCols; CORE_MAX_READS_PER_CYCLE], pub writes_aux_cols: [MemoryWriteAuxCols; CORE_MAX_WRITES_PER_CYCLE], @@ -183,8 +181,8 @@ impl CoreAuxCols { let beq_check = slc[start].clone(); start = end; - end += IsEqualVecAuxCols::::width(WORD_SIZE); - let is_equal_vec_aux = IsEqualVecAuxCols::from_slice(&slc[start..end], WORD_SIZE); + end += IsEqualAuxCols::::width(); + let is_equal_aux = IsEqualAuxCols::from_slice(&slc[start..end]); let reads_aux_cols = array::from_fn(|_| { start = end; @@ -193,7 +191,7 @@ impl CoreAuxCols { }); let writes_aux_cols = array::from_fn(|_| { start = end; - end += MemoryWriteAuxCols::::width(); + end += MemoryWriteAuxCols::::width(); MemoryWriteAuxCols::from_slice(&slc[start..end]) }); let next_pc = slc[end].clone(); @@ -204,7 +202,7 @@ impl CoreAuxCols { reads, writes, read0_equals_read1: beq_check, - is_equal_vec_aux, + is_equal_aux, reads_aux_cols, writes_aux_cols, next_pc, @@ -230,7 +228,7 @@ impl CoreAuxCols { .flat_map(CoreMemoryAccessCols::::flatten), ); flattened.push(self.read0_equals_read1.clone()); - flattened.extend(self.is_equal_vec_aux.flatten()); + flattened.extend(self.is_equal_aux.flatten()); flattened.extend( self.reads_aux_cols .iter() @@ -253,9 +251,9 @@ impl CoreAuxCols { + CORE_MAX_READS_PER_CYCLE * (CoreMemoryAccessCols::::width() + MemoryReadOrImmediateAuxCols::::width()) + CORE_MAX_WRITES_PER_CYCLE - * (CoreMemoryAccessCols::::width() + MemoryWriteAuxCols::::width()) + * (CoreMemoryAccessCols::::width() + MemoryWriteAuxCols::::width()) + 1 - + IsEqualVecAuxCols::::width(WORD_SIZE) + + IsEqualAuxCols::::width() + 1 } } @@ -267,17 +265,15 @@ impl CoreAuxCols { operation_flags.insert(opcode, F::from_bool(opcode == Opcode::NOP)); } - let is_equal_vec_cols = LocalTraceInstructions::generate_trace_row( - &IsEqualVecAir::new(WORD_SIZE), - (vec![F::zero()], vec![F::zero()]), - ); + let is_equal_cols = + LocalTraceInstructions::generate_trace_row(&IsEqualAir, (F::zero(), F::zero())); Self { operation_flags, public_value_flags: vec![F::zero(); chip.air.options.num_public_values], reads: array::from_fn(|_| CoreMemoryAccessCols::disabled()), writes: array::from_fn(|_| CoreMemoryAccessCols::disabled()), read0_equals_read1: F::one(), - is_equal_vec_aux: is_equal_vec_cols.aux, + is_equal_aux: is_equal_cols.aux, reads_aux_cols: array::from_fn(|_| MemoryReadOrImmediateAuxCols::disabled()), writes_aux_cols: array::from_fn(|_| MemoryWriteAuxCols::disabled()), next_pc: F::from_canonical_usize(chip.state.pc), diff --git a/vm/src/core/execute.rs b/vm/src/core/execute.rs index bb7fda81a0..b45485ec52 100644 --- a/vm/src/core/execute.rs +++ b/vm/src/core/execute.rs @@ -1,9 +1,9 @@ use std::{array, collections::BTreeMap}; -use afs_primitives::{is_equal_vec::IsEqualVecAir, sub_chip::LocalTraceInstructions}; +use afs_primitives::{is_equal::IsEqualAir, sub_chip::LocalTraceInstructions}; use p3_field::PrimeField32; -use super::{timestamp_delta, CoreChip, CoreState, WORD_SIZE}; +use super::{timestamp_delta, CoreChip, CoreState}; use crate::{ arch::{ chips::InstructionExecutor, @@ -260,13 +260,13 @@ impl InstructionExecutor for CoreChip { operation_flags.insert(other_opcode, F::from_bool(other_opcode == opcode)); } - let is_equal_vec_cols = LocalTraceInstructions::generate_trace_row( - &IsEqualVecAir::new(WORD_SIZE), - (vec![read_cols[0].value], vec![read_cols[1].value]), + let is_equal_cols = LocalTraceInstructions::generate_trace_row( + &IsEqualAir, + (read_cols[0].value, read_cols[1].value), ); - let read0_equals_read1 = is_equal_vec_cols.io.is_equal; - let is_equal_vec_aux = is_equal_vec_cols.aux; + let read0_equals_read1 = is_equal_cols.io.is_equal; + let is_equal_aux = is_equal_cols.aux; let aux = CoreAuxCols { operation_flags, @@ -274,7 +274,7 @@ impl InstructionExecutor for CoreChip { reads: read_cols, writes: write_cols, read0_equals_read1, - is_equal_vec_aux, + is_equal_aux, reads_aux_cols, writes_aux_cols, next_pc, diff --git a/vm/src/core/mod.rs b/vm/src/core/mod.rs index 787a1a58ec..e0dc83078b 100644 --- a/vm/src/core/mod.rs +++ b/vm/src/core/mod.rs @@ -35,9 +35,6 @@ pub const CORE_MAX_READS_PER_CYCLE: usize = 3; pub const CORE_MAX_WRITES_PER_CYCLE: usize = 1; pub const CORE_MAX_ACCESSES_PER_CYCLE: usize = CORE_MAX_READS_PER_CYCLE + CORE_MAX_WRITES_PER_CYCLE; -// [jpw] Temporary, we are going to remove cpu anyways -const WORD_SIZE: usize = 1; - fn timestamp_delta(opcode: Opcode) -> usize { match opcode { LOADW | STOREW => 3,