Skip to content

Commit

Permalink
refactor: Use IsEqual instead of IsEqualVec in core air
Browse files Browse the repository at this point in the history
  • Loading branch information
zlangley committed Sep 24, 2024
1 parent 7c06737 commit 8367b40
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 43 deletions.
12 changes: 9 additions & 3 deletions primitives/src/is_equal/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@ use afs_derive::AlignedBorrow;
pub const NUM_COLS: usize = 4;

#[repr(C)]
#[derive(AlignedBorrow)]
#[derive(AlignedBorrow, Clone, Debug, PartialEq, Eq)]
pub struct IsEqualCols<T> {
pub io: IsEqualIoCols<T>,
pub aux: IsEqualAuxCols<T>,
}

#[derive(Clone, Copy)]
#[repr(C)]
#[derive(AlignedBorrow, Clone, Debug, PartialEq, Eq)]
pub struct IsEqualIoCols<T> {
pub x: T,
pub y: T,
pub is_equal: T,
}

#[derive(Debug, Clone)]
#[repr(C)]
#[derive(AlignedBorrow, Clone, Debug, PartialEq, Eq)]
pub struct IsEqualAuxCols<T> {
pub inv: T,
}
Expand All @@ -31,6 +33,10 @@ impl<T: Clone> IsEqualAuxCols<T> {
pub fn flatten(&self) -> Vec<T> {
vec![self.inv.clone()]
}

pub fn width() -> usize {
1
}
}
impl<T: Clone> IsEqualCols<T> {
pub const fn new(x: T, y: T, is_equal: T, inv: T) -> IsEqualCols<T> {
Expand Down
19 changes: 7 additions & 12 deletions vm/src/core/air.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{borrow::Borrow, iter::zip};

use afs_primitives::{
is_equal_vec::{columns::IsEqualVecIoCols, IsEqualVecAir},
is_equal::{columns::IsEqualIoCols, IsEqualAir},
sub_chip::SubAir,
};
use afs_stark_backend::interaction::InteractionBuilder;
Expand All @@ -11,7 +11,7 @@ use p3_matrix::Matrix;

use super::{
columns::{CoreAuxCols, CoreCols, CoreIoCols},
CoreOptions, INST_WIDTH, WORD_SIZE,
CoreOptions, INST_WIDTH,
};
use crate::{
arch::{bridge::ExecutionBridge, instructions::Opcode::*},
Expand Down Expand Up @@ -67,7 +67,7 @@ impl<AB: AirBuilderWithPublicValues + InteractionBuilder> Air<AB> for CoreAir {
reads,
writes,
read0_equals_read1,
is_equal_vec_aux,
is_equal_aux,
reads_aux_cols,
writes_aux_cols,
next_pc,
Expand Down Expand Up @@ -333,17 +333,12 @@ impl<AB: AirBuilderWithPublicValues + InteractionBuilder> Air<AB> for CoreAir {

// evaluate equality between read1 and read2

let is_equal_vec_io_cols = IsEqualVecIoCols {
x: vec![read1.value],
y: vec![read2.value],
let is_equal_io_cols = IsEqualIoCols {
x: read1.value,
y: read2.value,
is_equal: read0_equals_read1,
};
SubAir::eval(
&IsEqualVecAir::new(WORD_SIZE),
builder,
is_equal_vec_io_cols,
is_equal_vec_aux,
);
SubAir::eval(&IsEqualAir, builder, is_equal_io_cols, is_equal_aux);

// make sure program terminates or shards with NOP
builder.when_last_row().assert_zero(
Expand Down
30 changes: 13 additions & 17 deletions vm/src/core/columns.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use std::{array, collections::BTreeMap};

use afs_primitives::{
is_equal_vec::{columns::IsEqualVecAuxCols, IsEqualVecAir},
is_equal::{columns::IsEqualAuxCols, IsEqualAir},
sub_chip::LocalTraceInstructions,
};
use itertools::Itertools;
use p3_field::{Field, PrimeField32};

use super::{
CoreAir, CoreChip, Opcode, CORE_MAX_READS_PER_CYCLE, CORE_MAX_WRITES_PER_CYCLE, WORD_SIZE,
};
use super::{CoreAir, CoreChip, Opcode, CORE_MAX_READS_PER_CYCLE, CORE_MAX_WRITES_PER_CYCLE};
use crate::{
arch::instructions::CORE_INSTRUCTIONS,
memory::{
Expand Down Expand Up @@ -162,7 +160,7 @@ pub struct CoreAuxCols<T> {
pub reads: [CoreMemoryAccessCols<T>; CORE_MAX_READS_PER_CYCLE],
pub writes: [CoreMemoryAccessCols<T>; CORE_MAX_WRITES_PER_CYCLE],
pub read0_equals_read1: T,
pub is_equal_vec_aux: IsEqualVecAuxCols<T>,
pub is_equal_aux: IsEqualAuxCols<T>,
pub reads_aux_cols: [MemoryReadOrImmediateAuxCols<T>; CORE_MAX_READS_PER_CYCLE],
pub writes_aux_cols: [MemoryWriteAuxCols<T, 1>; CORE_MAX_WRITES_PER_CYCLE],

Expand Down Expand Up @@ -199,8 +197,8 @@ impl<T: Clone> CoreAuxCols<T> {
let beq_check = slc[start].clone();

start = end;
end += IsEqualVecAuxCols::<T>::width(WORD_SIZE);
let is_equal_vec_aux = IsEqualVecAuxCols::from_slice(&slc[start..end], WORD_SIZE);
end += IsEqualAuxCols::<T>::width();
let is_equal_aux = IsEqualAuxCols::from_slice(&slc[start..end]);

let reads_aux_cols = array::from_fn(|_| {
start = end;
Expand All @@ -209,7 +207,7 @@ impl<T: Clone> CoreAuxCols<T> {
});
let writes_aux_cols = array::from_fn(|_| {
start = end;
end += MemoryWriteAuxCols::<T, WORD_SIZE>::width();
end += MemoryWriteAuxCols::<T, 1>::width();
MemoryWriteAuxCols::from_slice(&slc[start..end])
});
let next_pc = slc[end].clone();
Expand All @@ -220,7 +218,7 @@ impl<T: Clone> CoreAuxCols<T> {
reads,
writes,
read0_equals_read1: beq_check,
is_equal_vec_aux,
is_equal_aux,
reads_aux_cols,
writes_aux_cols,
next_pc,
Expand All @@ -246,7 +244,7 @@ impl<T: Clone> CoreAuxCols<T> {
.flat_map(CoreMemoryAccessCols::<T>::flatten),
);
flattened.push(self.read0_equals_read1.clone());
flattened.extend(self.is_equal_vec_aux.flatten());
flattened.extend(self.is_equal_aux.flatten());
flattened.extend(
self.reads_aux_cols
.iter()
Expand All @@ -269,9 +267,9 @@ impl<T: Clone> CoreAuxCols<T> {
+ CORE_MAX_READS_PER_CYCLE
* (CoreMemoryAccessCols::<T>::width() + MemoryReadOrImmediateAuxCols::<T>::width())
+ CORE_MAX_WRITES_PER_CYCLE
* (CoreMemoryAccessCols::<T>::width() + MemoryWriteAuxCols::<T, WORD_SIZE>::width())
* (CoreMemoryAccessCols::<T>::width() + MemoryWriteAuxCols::<T, 1>::width())
+ 1
+ IsEqualVecAuxCols::<T>::width(WORD_SIZE)
+ IsEqualAuxCols::<T>::width()
+ 1
}
}
Expand All @@ -283,17 +281,15 @@ impl<F: PrimeField32> CoreAuxCols<F> {
operation_flags.insert(opcode, F::from_bool(opcode == Opcode::NOP));
}

let is_equal_vec_cols = LocalTraceInstructions::generate_trace_row(
&IsEqualVecAir::new(WORD_SIZE),
(vec![F::zero()], vec![F::zero()]),
);
let is_equal_cols =
LocalTraceInstructions::generate_trace_row(&IsEqualAir, (F::zero(), F::zero()));
Self {
operation_flags,
public_value_flags: vec![F::zero(); chip.air.options.num_public_values],
reads: array::from_fn(|_| CoreMemoryAccessCols::disabled()),
writes: array::from_fn(|_| CoreMemoryAccessCols::disabled()),
read0_equals_read1: F::one(),
is_equal_vec_aux: is_equal_vec_cols.aux,
is_equal_aux: is_equal_cols.aux,
reads_aux_cols: array::from_fn(|_| MemoryReadOrImmediateAuxCols::disabled()),
writes_aux_cols: array::from_fn(|_| MemoryWriteAuxCols::disabled()),
next_pc: F::from_canonical_usize(chip.state.pc),
Expand Down
16 changes: 8 additions & 8 deletions vm/src/core/execute.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::{array, collections::BTreeMap};

use afs_primitives::{is_equal_vec::IsEqualVecAir, sub_chip::LocalTraceInstructions};
use afs_primitives::{is_equal::IsEqualAir, sub_chip::LocalTraceInstructions};
use p3_field::PrimeField32;

use super::{timestamp_delta, CoreChip, CoreState, WORD_SIZE};
use super::{timestamp_delta, CoreChip, CoreState};
use crate::{
arch::{
chips::InstructionExecutor,
Expand Down Expand Up @@ -260,21 +260,21 @@ impl<F: PrimeField32> InstructionExecutor<F> for CoreChip<F> {
operation_flags.insert(other_opcode, F::from_bool(other_opcode == opcode));
}

let is_equal_vec_cols = LocalTraceInstructions::generate_trace_row(
&IsEqualVecAir::new(WORD_SIZE),
(vec![read_cols[0].value], vec![read_cols[1].value]),
let is_equal_cols = LocalTraceInstructions::generate_trace_row(
&IsEqualAir,
(read_cols[0].value, read_cols[1].value),
);

let read0_equals_read1 = is_equal_vec_cols.io.is_equal;
let is_equal_vec_aux = is_equal_vec_cols.aux;
let read0_equals_read1 = is_equal_cols.io.is_equal;
let is_equal_aux = is_equal_cols.aux;

let aux = CoreAuxCols {
operation_flags,
public_value_flags,
reads: read_cols,
writes: write_cols,
read0_equals_read1,
is_equal_vec_aux,
is_equal_aux,
reads_aux_cols,
writes_aux_cols,
next_pc,
Expand Down
3 changes: 0 additions & 3 deletions vm/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ pub const CORE_MAX_READS_PER_CYCLE: usize = 3;
pub const CORE_MAX_WRITES_PER_CYCLE: usize = 1;
pub const CORE_MAX_ACCESSES_PER_CYCLE: usize = CORE_MAX_READS_PER_CYCLE + CORE_MAX_WRITES_PER_CYCLE;

// [jpw] Temporary, we are going to remove cpu anyways
const WORD_SIZE: usize = 1;

fn timestamp_delta(opcode: Opcode) -> usize {
match opcode {
LOADW | STOREW => 3,
Expand Down

0 comments on commit 8367b40

Please sign in to comment.