Skip to content

Commit

Permalink
feat: ALU chip implementation (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenh-axiom-xyz authored Sep 24, 2024
1 parent 5cf8f03 commit ef3c6fc
Show file tree
Hide file tree
Showing 18 changed files with 1,437 additions and 1,554 deletions.
8 changes: 8 additions & 0 deletions primitives/src/xor/lookup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,12 @@ impl<const M: usize> XorLookupChip<M> {

self.calc_xor(x, y)
}

/// Resets every counter in the `count` table to zero.
///
/// Visits all `(1 << M) x (1 << M)` entries and stores `0` into each one with
/// relaxed atomic ordering.
pub fn clear(&self) {
    use std::sync::atomic::Ordering::Relaxed;

    // Both dimensions of the table have 2^M entries.
    let side: usize = 1 << M;
    for row in 0..side {
        for col in 0..side {
            self.count[row][col].store(0, Relaxed);
        }
    }
}
}
145 changes: 145 additions & 0 deletions vm/src/alu/air.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use std::{array, borrow::Borrow};

use afs_primitives::{utils, xor::bus::XorBus};
use afs_stark_backend::interaction::InteractionBuilder;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, Field};
use p3_matrix::Matrix;

use super::columns::ArithmeticLogicCols;
use crate::{
arch::{bridge::ExecutionBridge, instructions::ALU_256_INSTRUCTIONS},
memory::offline_checker::MemoryBridge,
};

/// AIR definition for the arithmetic/logic (ALU) chip.
///
/// `NUM_LIMBS` is the number of limbs in each operand and `LIMB_BITS` is the bit
/// width of a single limb; the names match the ones used by the trait impls and
/// by the column structs for this chip.
#[derive(Copy, Clone, Debug)]
pub struct ArithmeticLogicAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bridge used to emit the "execute and increment pc" interaction.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used to emit the memory read/write interactions.
    pub(super) memory_bridge: MemoryBridge,
    /// XOR lookup bus used for bitwise operations, limb checks and sign checks.
    pub bus: XorBus,
}

impl<F, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for ArithmeticLogicAir<NUM_LIMBS, LIMB_BITS>
where
    F: Field,
{
    /// Number of columns in a trace row, as reported by
    /// [`ArithmeticLogicCols::width`] for the same limb parameters.
    fn width(&self) -> usize {
        ArithmeticLogicCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}

/// Constraint definitions for the ALU chip. A single row handles one of the eight
/// opcodes selected by the flags in the auxiliary columns.
impl<AB: InteractionBuilder, const NUM_LIMBS: usize, const LIMB_BITS: usize> Air<AB>
for ArithmeticLogicAir<NUM_LIMBS, LIMB_BITS>
{
/// Adds the per-row constraints for ADD, SUB, LT, SLT and EQ, computes the opcode
/// expected from the selector flags, and then delegates all bus/memory
/// interactions to `eval_interactions`.
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);

// View the raw row as typed IO and auxiliary columns.
let ArithmeticLogicCols::<_, NUM_LIMBS, LIMB_BITS> { io, aux } = (*local).borrow();
builder.assert_bool(aux.is_valid);

// Opcode selector flags. NOTE(review): the order here is zipped positionally
// with ALU_256_INSTRUCTIONS below, so it presumably has to match that
// constant's order — confirm before reordering.
let flags = [
aux.opcode_add_flag,
aux.opcode_sub_flag,
aux.opcode_lt_flag,
aux.opcode_eq_flag,
aux.opcode_xor_flag,
aux.opcode_and_flag,
aux.opcode_or_flag,
aux.opcode_slt_flag,
];
for flag in flags {
builder.assert_bool(flag);
}

// Every flag is boolean and their sum equals the boolean is_valid, so a valid
// row has exactly one flag set and an invalid row has none.
builder.assert_eq(
aux.is_valid,
flags
.iter()
.fold(AB::Expr::zero(), |acc, &flag| acc + flag.into()),
);

let x_limbs = &io.x.data;
let y_limbs = &io.y.data;
let z_limbs = &io.z.data;

// For ADD, define carry[i] = (x[i] + y[i] + carry[i - 1] - z[i]) / 2^LIMB_BITS. If
// each carry[i] is boolean and 0 <= z[i] < 2^LIMB_BITS, it can be proven that
// z[i] = (x[i] + y[i] + carry[i - 1]) % 2^LIMB_BITS as necessary. The same holds
// for SUB when carry[i] is (z[i] + y[i] - x[i] + carry[i - 1]) / 2^LIMB_BITS.
let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::zero());
let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::zero());
// Division by 2^LIMB_BITS is done by multiplying with its field inverse.
let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse();

for i in 0..NUM_LIMBS {
// We explicitly separate the constraints for ADD and SUB in order to keep the
// degree cubic. Because we constrain that the carry (which is an arbitrary
// expression) is boolean, if the carry had degree larger than 1 the max-degree
// constraint would be at least 4.
carry_add[i] = AB::Expr::from(carry_divide)
* (x_limbs[i] + y_limbs[i] - z_limbs[i]
+ if i > 0 {
carry_add[i - 1].clone()
} else {
AB::Expr::zero()
});
builder
.when(aux.opcode_add_flag)
.assert_bool(carry_add[i].clone());
carry_sub[i] = AB::Expr::from(carry_divide)
* (z_limbs[i] + y_limbs[i] - x_limbs[i]
+ if i > 0 {
carry_sub[i - 1].clone()
} else {
AB::Expr::zero()
});
// LT and SLT reuse the subtraction carries, so they share this constraint.
builder
.when(aux.opcode_sub_flag + aux.opcode_lt_flag + aux.opcode_slt_flag)
.assert_bool(carry_sub[i].clone());
}

// For LT, cmp_result must be equal to the last carry. For SLT, cmp_result ^ x_sign ^ y_sign
// must be equal to the last carry. To keep all constraints at most cubic, aux.x_sign and
// aux.y_sign are forced to 0 whenever the row is not computing an SLT.
builder.assert_bool(aux.x_sign);
builder.assert_bool(aux.y_sign);
builder
.when(utils::not(aux.opcode_slt_flag))
.assert_zero(aux.x_sign);
builder
.when(utils::not(aux.opcode_slt_flag))
.assert_zero(aux.y_sign);

// Arithmetic form of a three-way XOR on bits: a ^ b ^ c =
// a + b + c - 2(ab + ac + bc) + 4abc, with a = cmp_result, b = x_sign, c = y_sign.
// Only the leading cmp_result term carries the (lt + slt) selector; the remaining
// terms vanish on non-SLT rows because both sign columns are zero there.
let slt_xor =
(aux.opcode_lt_flag + aux.opcode_slt_flag) * io.cmp_result + aux.x_sign + aux.y_sign
- AB::Expr::from_canonical_u32(2)
* (io.cmp_result * aux.x_sign
+ io.cmp_result * aux.y_sign
+ aux.x_sign * aux.y_sign)
+ AB::Expr::from_canonical_u32(4) * (io.cmp_result * aux.x_sign * aux.y_sign);
builder.assert_eq(
slt_xor,
(aux.opcode_lt_flag + aux.opcode_slt_flag) * carry_sub[NUM_LIMBS - 1].clone(),
);

// For EQ, z is filled with 0 except at the lowest index i such that x[i] != y[i]. If
// such an i exists z[i] is the inverse of x[i] - y[i], meaning sum_eq should be 1.
// If no such i exists, every product below is 0 and cmp_result itself must be 1.
let mut sum_eq: AB::Expr = io.cmp_result.into();
for i in 0..NUM_LIMBS {
sum_eq += (x_limbs[i] - y_limbs[i]) * z_limbs[i];
// cmp_result may only be non-zero when every limb of x equals that of y.
builder
.when(aux.opcode_eq_flag)
.assert_zero(io.cmp_result * (x_limbs[i] - y_limbs[i]));
}
builder
.when(aux.opcode_eq_flag)
.assert_zero(sum_eq - AB::Expr::one());

// Select the opcode value that corresponds to whichever flag is set (0 if none).
let expected_opcode = flags
.iter()
.zip(ALU_256_INSTRUCTIONS)
.fold(AB::Expr::zero(), |acc, (flag, opcode)| {
acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8)
});

self.eval_interactions(builder, io, aux, expected_opcode);
}
}
146 changes: 146 additions & 0 deletions vm/src/alu/bridge.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use afs_stark_backend::interaction::InteractionBuilder;
use itertools::izip;
use p3_field::AbstractField;

use super::{
air::ArithmeticLogicAir,
columns::{ArithmeticLogicAuxCols, ArithmeticLogicIoCols},
};
use crate::memory::MemoryAddress;

impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> ArithmeticLogicAir<NUM_LIMBS, LIMB_BITS> {
/// Emits every bus interaction for one ALU row: the three pointer reads, the two
/// operand reads, the result write (either the full `z` data or the single
/// `cmp_result` cell), the execution-bridge interaction, and the XOR-bus sends.
///
/// `expected_opcode` is the opcode expression derived from the selector flags by
/// the caller; it is forwarded unchanged to the execution bridge.
pub fn eval_interactions<AB: InteractionBuilder>(
&self,
builder: &mut AB,
io: &ArithmeticLogicIoCols<AB::Var, NUM_LIMBS, LIMB_BITS>,
aux: &ArithmeticLogicAuxCols<AB::Var, NUM_LIMBS, LIMB_BITS>,
expected_opcode: AB::Expr,
) {
let timestamp: AB::Var = io.from_state.timestamp;
let mut timestamp_delta: usize = 0;
// Post-increment helper: yields `timestamp + timestamp_delta` for the current
// access and then advances `timestamp_delta` by one.
let mut timestamp_pp = || {
timestamp_delta += 1;
timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
};

// Selector for opcodes whose z limbs are sent to the XOR bus as (z, z, 0) below.
let range_check =
aux.opcode_add_flag + aux.opcode_sub_flag + aux.opcode_lt_flag + aux.opcode_slt_flag;
// Selector for opcodes that send (x, y, x ^ y) to the XOR bus below.
let bitwise = aux.opcode_xor_flag + aux.opcode_and_flag + aux.opcode_or_flag;

// Read the operand pointer's values, which are themselves pointers
// for the actual IO data.
for (ptr, value, mem_aux) in izip!(
[
io.z.ptr_to_address,
io.x.ptr_to_address,
io.y.ptr_to_address
],
[io.z.address, io.x.address, io.y.address],
&aux.read_ptr_aux_cols
) {
self.memory_bridge
.read(
MemoryAddress::new(io.ptr_as, ptr),
[value],
timestamp_pp(),
mem_aux,
)
.eval(builder, aux.is_valid);
}

// Read the x operand limbs from the data address space.
self.memory_bridge
.read(
MemoryAddress::new(io.address_as, io.x.address),
io.x.data,
timestamp_pp(),
&aux.read_x_aux_cols,
)
.eval(builder, aux.is_valid);

// Read the y operand limbs from the data address space.
self.memory_bridge
.read(
MemoryAddress::new(io.address_as, io.y.address),
io.y.data,
timestamp_pp(),
&aux.read_y_aux_cols,
)
.eval(builder, aux.is_valid);

// Special handling for writing output z data:
// only ADD, SUB and the bitwise opcodes write the full limb array.
self.memory_bridge
.write(
MemoryAddress::new(io.address_as, io.z.address),
io.z.data,
timestamp + AB::F::from_canonical_usize(timestamp_delta),
&aux.write_z_aux_cols,
)
.eval(
builder,
aux.opcode_add_flag + aux.opcode_sub_flag + bitwise.clone(),
);

// Special handling for writing output cmp data:
// LT, EQ and SLT write the single cmp_result cell instead. This write uses the
// same address and timestamp as the z-data write above; the two selectors cover
// disjoint sets of opcode flags.
self.memory_bridge
.write(
MemoryAddress::new(io.address_as, io.z.address),
[io.cmp_result],
timestamp + AB::F::from_canonical_usize(timestamp_delta),
&aux.write_cmp_aux_cols,
)
.eval(
builder,
aux.opcode_lt_flag + aux.opcode_eq_flag + aux.opcode_slt_flag,
);
// Account for the (single) write performed by whichever branch above is active.
timestamp_delta += 1;

// Operands are passed in instruction order [z, x, y], followed by the two
// address spaces; the last argument is the total number of timestamps consumed.
self.execution_bridge
.execute_and_increment_pc(
expected_opcode,
[
io.z.ptr_to_address,
io.x.ptr_to_address,
io.y.ptr_to_address,
io.ptr_as,
io.address_as,
],
io.from_state,
AB::F::from_canonical_usize(timestamp_delta),
)
.eval(builder, aux.is_valid);

// Check (x_sign << (LIMB_BITS - 1)) & x[NUM_LIMBS - 1] == (x_sign << (LIMB_BITS - 1))
// using XOR: for a = sign << (LIMB_BITS - 1) and b = the most significant limb,
// a ^ b == b - a holds exactly when every bit set in a is also set in b. The same
// check is applied to y.
// NOTE(review): with x_sign = 0 this sends (0, x, x), which is a valid XOR triple for
// any limb, so this interaction alone does not appear to force x_sign = 1 when the
// top bit is set — confirm that direction is enforced elsewhere.
let x_sign_shifted = aux.x_sign * AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
let y_sign_shifted = aux.y_sign * AB::F::from_canonical_u32(1 << (LIMB_BITS - 1));
self.bus
.send(
x_sign_shifted.clone(),
io.x.data[NUM_LIMBS - 1],
io.x.data[NUM_LIMBS - 1] - x_sign_shifted,
)
.eval(builder, aux.opcode_slt_flag);
self.bus
.send(
y_sign_shifted.clone(),
io.y.data[NUM_LIMBS - 1],
io.y.data[NUM_LIMBS - 1] - y_sign_shifted,
)
.eval(builder, aux.opcode_slt_flag);

// Chip-specific interactions
// One XOR-bus send per limb:
//  - ADD/SUB/LT/SLT send (z, z, 0), presumably to range check each z limb via the
//    lookup table — TODO confirm the table's bit width matches LIMB_BITS.
//  - XOR sends (x, y, z).
//  - AND sends (x, y, x + y - 2z), which equals x ^ y when z = x & y.
//  - OR sends (x, y, 2z - x - y), which equals x ^ y when z = x | y.
// EQ rows send nothing here.
for i in 0..NUM_LIMBS {
let x = range_check.clone() * io.z.data[i] + bitwise.clone() * io.x.data[i];
let y = range_check.clone() * io.z.data[i] + bitwise.clone() * io.y.data[i];
let xor_res = aux.opcode_xor_flag * io.z.data[i]
+ aux.opcode_and_flag
* (io.x.data[i] + io.y.data[i]
- (AB::Expr::from_canonical_u32(2) * io.z.data[i]))
+ aux.opcode_or_flag
* ((AB::Expr::from_canonical_u32(2) * io.z.data[i])
- io.x.data[i]
- io.y.data[i]);
self.bus
.send(x, y, xor_res)
.eval(builder, range_check.clone() + bitwise.clone());
}
}
}
79 changes: 79 additions & 0 deletions vm/src/alu/columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use std::mem::size_of;

use afs_derive::AlignedBorrow;

use crate::{
arch::columns::ExecutionState,
memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols},
uint_multiplication::MemoryData,
};

/// One full trace row of the ALU chip: the IO columns followed by the auxiliary
/// columns. The layout is `#[repr(C)]`, so field order is the column order.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct ArithmeticLogicCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
/// Instruction inputs/outputs (execution state, operands, result, address spaces).
pub io: ArithmeticLogicIoCols<T, NUM_LIMBS, LIMB_BITS>,
/// Selector flags, sign bits and memory-access auxiliary columns.
pub aux: ArithmeticLogicAuxCols<T, NUM_LIMBS, LIMB_BITS>,
}

impl<T, const NUM_LIMBS: usize, const LIMB_BITS: usize>
    ArithmeticLogicCols<T, NUM_LIMBS, LIMB_BITS>
{
    /// Total number of columns in a row: the sum of the auxiliary-column width
    /// and the IO-column width.
    pub fn width() -> usize {
        let io_width = ArithmeticLogicIoCols::<T, NUM_LIMBS, LIMB_BITS>::width();
        let aux_width = ArithmeticLogicAuxCols::<T, NUM_LIMBS, LIMB_BITS>::width();
        aux_width + io_width
    }
}

/// IO columns of an ALU row. The layout is `#[repr(C)]`, so field order is the
/// column order.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct ArithmeticLogicIoCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
/// Execution state the instruction starts from; its timestamp is the base for
/// all memory accesses of the row.
pub from_state: ExecutionState<T>,
/// First operand (pointer, resolved address and limb data).
pub x: MemoryData<T, NUM_LIMBS, LIMB_BITS>,
/// Second operand (pointer, resolved address and limb data).
pub y: MemoryData<T, NUM_LIMBS, LIMB_BITS>,
/// Result operand (pointer, resolved address and limb data).
pub z: MemoryData<T, NUM_LIMBS, LIMB_BITS>,
/// Single-cell result written instead of `z.data` for the LT, EQ and SLT opcodes.
pub cmp_result: T,
/// Address space in which the operand pointers are read.
pub ptr_as: T,
/// Address space in which the operand data is read and the result is written.
pub address_as: T,
}

impl<T, const NUM_LIMBS: usize, const LIMB_BITS: usize>
    ArithmeticLogicIoCols<T, NUM_LIMBS, LIMB_BITS>
{
    /// Number of IO columns, measured as the byte size of this struct when every
    /// cell is a `u8`.
    pub fn width() -> usize {
        // Byte-celled instantiation of the IO columns, used purely for sizing.
        type ByteCols<const N: usize, const B: usize> = ArithmeticLogicIoCols<u8, N, B>;
        size_of::<ByteCols<NUM_LIMBS, LIMB_BITS>>()
    }
}

/// Auxiliary columns of an ALU row. The layout is `#[repr(C)]`, so field order
/// is the column order.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct ArithmeticLogicAuxCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
/// Boolean: 1 when the row holds a real instruction; equals the sum of the
/// opcode flags below.
pub is_valid: T,
/// Sign bit claimed for `x`; constrained to 0 unless the row is an SLT.
pub x_sign: T,
/// Sign bit claimed for `y`; constrained to 0 unless the row is an SLT.
pub y_sign: T,

// Opcode flags for different operations
pub opcode_add_flag: T,
pub opcode_sub_flag: T,
pub opcode_lt_flag: T,
pub opcode_eq_flag: T,
pub opcode_xor_flag: T,
pub opcode_and_flag: T,
pub opcode_or_flag: T,
pub opcode_slt_flag: T,

/// Pointer read auxiliary columns for [z, x, y].
/// **Note** the ordering, which is designed to match the instruction order.
pub read_ptr_aux_cols: [MemoryReadAuxCols<T, 1>; 3],
/// Auxiliary columns for reading the `x` operand limbs.
pub read_x_aux_cols: MemoryReadAuxCols<T, NUM_LIMBS>,
/// Auxiliary columns for reading the `y` operand limbs.
pub read_y_aux_cols: MemoryReadAuxCols<T, NUM_LIMBS>,
/// Auxiliary columns for writing the full `z` limb array (ADD, SUB, XOR, AND, OR).
pub write_z_aux_cols: MemoryWriteAuxCols<T, NUM_LIMBS>,
/// Auxiliary columns for writing the single `cmp_result` cell (LT, EQ, SLT).
pub write_cmp_aux_cols: MemoryWriteAuxCols<T, 1>,
}

impl<T, const NUM_LIMBS: usize, const LIMB_BITS: usize>
    ArithmeticLogicAuxCols<T, NUM_LIMBS, LIMB_BITS>
{
    /// Number of auxiliary columns, measured as the byte size of this struct when
    /// every cell is a `u8`.
    pub fn width() -> usize {
        // Byte-celled instantiation of the auxiliary columns, used purely for sizing.
        type ByteCols<const N: usize, const B: usize> = ArithmeticLogicAuxCols<u8, N, B>;
        size_of::<ByteCols<NUM_LIMBS, LIMB_BITS>>()
    }
}
Loading

0 comments on commit ef3c6fc

Please sign in to comment.