-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: ALU chip implementation (#459)
- Loading branch information
1 parent
5cf8f03
commit ef3c6fc
Showing
18 changed files
with
1,437 additions
and
1,554 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
use std::{array, borrow::Borrow}; | ||
|
||
use afs_primitives::{utils, xor::bus::XorBus}; | ||
use afs_stark_backend::interaction::InteractionBuilder; | ||
use p3_air::{Air, AirBuilder, BaseAir}; | ||
use p3_field::{AbstractField, Field}; | ||
use p3_matrix::Matrix; | ||
|
||
use super::columns::ArithmeticLogicCols; | ||
use crate::{ | ||
arch::{bridge::ExecutionBridge, instructions::ALU_256_INSTRUCTIONS}, | ||
memory::offline_checker::MemoryBridge, | ||
}; | ||
|
||
#[derive(Copy, Clone, Debug)] | ||
pub struct ArithmeticLogicAir<const ARG_SIZE: usize, const LIMB_SIZE: usize> { | ||
pub(super) execution_bridge: ExecutionBridge, | ||
pub(super) memory_bridge: MemoryBridge, | ||
pub bus: XorBus, | ||
} | ||
|
||
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F> | ||
for ArithmeticLogicAir<NUM_LIMBS, LIMB_BITS> | ||
{ | ||
fn width(&self) -> usize { | ||
ArithmeticLogicCols::<F, NUM_LIMBS, LIMB_BITS>::width() | ||
} | ||
} | ||
|
||
impl<AB: InteractionBuilder, const NUM_LIMBS: usize, const LIMB_BITS: usize> Air<AB> | ||
for ArithmeticLogicAir<NUM_LIMBS, LIMB_BITS> | ||
{ | ||
fn eval(&self, builder: &mut AB) { | ||
let main = builder.main(); | ||
let local = main.row_slice(0); | ||
|
||
let ArithmeticLogicCols::<_, NUM_LIMBS, LIMB_BITS> { io, aux } = (*local).borrow(); | ||
builder.assert_bool(aux.is_valid); | ||
|
||
let flags = [ | ||
aux.opcode_add_flag, | ||
aux.opcode_sub_flag, | ||
aux.opcode_lt_flag, | ||
aux.opcode_eq_flag, | ||
aux.opcode_xor_flag, | ||
aux.opcode_and_flag, | ||
aux.opcode_or_flag, | ||
aux.opcode_slt_flag, | ||
]; | ||
for flag in flags { | ||
builder.assert_bool(flag); | ||
} | ||
|
||
builder.assert_eq( | ||
aux.is_valid, | ||
flags | ||
.iter() | ||
.fold(AB::Expr::zero(), |acc, &flag| acc + flag.into()), | ||
); | ||
|
||
let x_limbs = &io.x.data; | ||
let y_limbs = &io.y.data; | ||
let z_limbs = &io.z.data; | ||
|
||
// For ADD, define carry[i] = (x[i] + y[i] + carry[i - 1] - z[i]) / 2^LIMB_BITS. If | ||
// each carry[i] is boolean and 0 <= z[i] < 2^NUM_LIMBS, it can be proven that | ||
// z[i] = (x[i] + y[i]) % 256 as necessary. The same holds for SUB when carry[i] is | ||
// (z[i] + y[i] - x[i] + carry[i - 1]) / 2^LIMB_BITS. | ||
let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::zero()); | ||
let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::zero()); | ||
let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse(); | ||
|
||
for i in 0..NUM_LIMBS { | ||
// We explicitly separate the constraints for ADD and SUB in order to keep degree | ||
// cubic. Because we constrain that the carry (which is arbitrary) is bool, if | ||
// carry has degree larger than 1 the max-degree constrain could be at least 4. | ||
carry_add[i] = AB::Expr::from(carry_divide) | ||
* (x_limbs[i] + y_limbs[i] - z_limbs[i] | ||
+ if i > 0 { | ||
carry_add[i - 1].clone() | ||
} else { | ||
AB::Expr::zero() | ||
}); | ||
builder | ||
.when(aux.opcode_add_flag) | ||
.assert_bool(carry_add[i].clone()); | ||
carry_sub[i] = AB::Expr::from(carry_divide) | ||
* (z_limbs[i] + y_limbs[i] - x_limbs[i] | ||
+ if i > 0 { | ||
carry_sub[i - 1].clone() | ||
} else { | ||
AB::Expr::zero() | ||
}); | ||
builder | ||
.when(aux.opcode_sub_flag + aux.opcode_lt_flag + aux.opcode_slt_flag) | ||
.assert_bool(carry_sub[i].clone()); | ||
} | ||
|
||
// For LT, cmp_result must be equal to the last carry. For SLT, cmp_result ^ x_sign ^ y_sign must | ||
// be equal to the last carry. To ensure maximum cubic degree constraints, we set aux.x_sign and | ||
// aux.y_sign are 0 when not computing an SLT. | ||
builder.assert_bool(aux.x_sign); | ||
builder.assert_bool(aux.y_sign); | ||
builder | ||
.when(utils::not(aux.opcode_slt_flag)) | ||
.assert_zero(aux.x_sign); | ||
builder | ||
.when(utils::not(aux.opcode_slt_flag)) | ||
.assert_zero(aux.y_sign); | ||
|
||
let slt_xor = | ||
(aux.opcode_lt_flag + aux.opcode_slt_flag) * io.cmp_result + aux.x_sign + aux.y_sign | ||
- AB::Expr::from_canonical_u32(2) | ||
* (io.cmp_result * aux.x_sign | ||
+ io.cmp_result * aux.y_sign | ||
+ aux.x_sign * aux.y_sign) | ||
+ AB::Expr::from_canonical_u32(4) * (io.cmp_result * aux.x_sign * aux.y_sign); | ||
builder.assert_eq( | ||
slt_xor, | ||
(aux.opcode_lt_flag + aux.opcode_slt_flag) * carry_sub[NUM_LIMBS - 1].clone(), | ||
); | ||
|
||
// For EQ, z is filled with 0 except at the lowest index i such that x[i] != y[i]. If | ||
// such an i exists z[i] is the inverse of x[i] - y[i], meaning sum_eq should be 1. | ||
let mut sum_eq: AB::Expr = io.cmp_result.into(); | ||
for i in 0..NUM_LIMBS { | ||
sum_eq += (x_limbs[i] - y_limbs[i]) * z_limbs[i]; | ||
builder | ||
.when(aux.opcode_eq_flag) | ||
.assert_zero(io.cmp_result * (x_limbs[i] - y_limbs[i])); | ||
} | ||
builder | ||
.when(aux.opcode_eq_flag) | ||
.assert_zero(sum_eq - AB::Expr::one()); | ||
|
||
let expected_opcode = flags | ||
.iter() | ||
.zip(ALU_256_INSTRUCTIONS) | ||
.fold(AB::Expr::zero(), |acc, (flag, opcode)| { | ||
acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8) | ||
}); | ||
|
||
self.eval_interactions(builder, io, aux, expected_opcode); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
use afs_stark_backend::interaction::InteractionBuilder; | ||
use itertools::izip; | ||
use p3_field::AbstractField; | ||
|
||
use super::{ | ||
air::ArithmeticLogicAir, | ||
columns::{ArithmeticLogicAuxCols, ArithmeticLogicIoCols}, | ||
}; | ||
use crate::memory::MemoryAddress; | ||
|
||
impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> ArithmeticLogicAir<NUM_LIMBS, LIMB_BITS> { | ||
pub fn eval_interactions<AB: InteractionBuilder>( | ||
&self, | ||
builder: &mut AB, | ||
io: &ArithmeticLogicIoCols<AB::Var, NUM_LIMBS, LIMB_BITS>, | ||
aux: &ArithmeticLogicAuxCols<AB::Var, NUM_LIMBS, LIMB_BITS>, | ||
expected_opcode: AB::Expr, | ||
) { | ||
let timestamp: AB::Var = io.from_state.timestamp; | ||
let mut timestamp_delta: usize = 0; | ||
let mut timestamp_pp = || { | ||
timestamp_delta += 1; | ||
timestamp + AB::F::from_canonical_usize(timestamp_delta - 1) | ||
}; | ||
|
||
let range_check = | ||
aux.opcode_add_flag + aux.opcode_sub_flag + aux.opcode_lt_flag + aux.opcode_slt_flag; | ||
let bitwise = aux.opcode_xor_flag + aux.opcode_and_flag + aux.opcode_or_flag; | ||
|
||
// Read the operand pointer's values, which are themselves pointers | ||
// for the actual IO data. | ||
for (ptr, value, mem_aux) in izip!( | ||
[ | ||
io.z.ptr_to_address, | ||
io.x.ptr_to_address, | ||
io.y.ptr_to_address | ||
], | ||
[io.z.address, io.x.address, io.y.address], | ||
&aux.read_ptr_aux_cols | ||
) { | ||
self.memory_bridge | ||
.read( | ||
MemoryAddress::new(io.ptr_as, ptr), | ||
[value], | ||
timestamp_pp(), | ||
mem_aux, | ||
) | ||
.eval(builder, aux.is_valid); | ||
} | ||
|
||
self.memory_bridge | ||
.read( | ||
MemoryAddress::new(io.address_as, io.x.address), | ||
io.x.data, | ||
timestamp_pp(), | ||
&aux.read_x_aux_cols, | ||
) | ||
.eval(builder, aux.is_valid); | ||
|
||
self.memory_bridge | ||
.read( | ||
MemoryAddress::new(io.address_as, io.y.address), | ||
io.y.data, | ||
timestamp_pp(), | ||
&aux.read_y_aux_cols, | ||
) | ||
.eval(builder, aux.is_valid); | ||
|
||
// Special handling for writing output z data: | ||
self.memory_bridge | ||
.write( | ||
MemoryAddress::new(io.address_as, io.z.address), | ||
io.z.data, | ||
timestamp + AB::F::from_canonical_usize(timestamp_delta), | ||
&aux.write_z_aux_cols, | ||
) | ||
.eval( | ||
builder, | ||
aux.opcode_add_flag + aux.opcode_sub_flag + bitwise.clone(), | ||
); | ||
|
||
// Special handling for writing output cmp data: | ||
self.memory_bridge | ||
.write( | ||
MemoryAddress::new(io.address_as, io.z.address), | ||
[io.cmp_result], | ||
timestamp + AB::F::from_canonical_usize(timestamp_delta), | ||
&aux.write_cmp_aux_cols, | ||
) | ||
.eval( | ||
builder, | ||
aux.opcode_lt_flag + aux.opcode_eq_flag + aux.opcode_slt_flag, | ||
); | ||
timestamp_delta += 1; | ||
|
||
self.execution_bridge | ||
.execute_and_increment_pc( | ||
expected_opcode, | ||
[ | ||
io.z.ptr_to_address, | ||
io.x.ptr_to_address, | ||
io.y.ptr_to_address, | ||
io.ptr_as, | ||
io.address_as, | ||
], | ||
io.from_state, | ||
AB::F::from_canonical_usize(timestamp_delta), | ||
) | ||
.eval(builder, aux.is_valid); | ||
|
||
// Check x_sign & x[NUM_LIMBS - 1] == x_sign using XOR | ||
let x_sign_shifted = aux.x_sign * AB::F::from_canonical_u32(1 << (LIMB_BITS - 1)); | ||
let y_sign_shifted = aux.y_sign * AB::F::from_canonical_u32(1 << (LIMB_BITS - 1)); | ||
self.bus | ||
.send( | ||
x_sign_shifted.clone(), | ||
io.x.data[NUM_LIMBS - 1], | ||
io.x.data[NUM_LIMBS - 1] - x_sign_shifted, | ||
) | ||
.eval(builder, aux.opcode_slt_flag); | ||
self.bus | ||
.send( | ||
y_sign_shifted.clone(), | ||
io.y.data[NUM_LIMBS - 1], | ||
io.y.data[NUM_LIMBS - 1] - y_sign_shifted, | ||
) | ||
.eval(builder, aux.opcode_slt_flag); | ||
|
||
// Chip-specific interactions | ||
for i in 0..NUM_LIMBS { | ||
let x = range_check.clone() * io.z.data[i] + bitwise.clone() * io.x.data[i]; | ||
let y = range_check.clone() * io.z.data[i] + bitwise.clone() * io.y.data[i]; | ||
let xor_res = aux.opcode_xor_flag * io.z.data[i] | ||
+ aux.opcode_and_flag | ||
* (io.x.data[i] + io.y.data[i] | ||
- (AB::Expr::from_canonical_u32(2) * io.z.data[i])) | ||
+ aux.opcode_or_flag | ||
* ((AB::Expr::from_canonical_u32(2) * io.z.data[i]) | ||
- io.x.data[i] | ||
- io.y.data[i]); | ||
self.bus | ||
.send(x, y, xor_res) | ||
.eval(builder, range_check.clone() + bitwise.clone()); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
use std::mem::size_of; | ||
|
||
use afs_derive::AlignedBorrow; | ||
|
||
use crate::{ | ||
arch::columns::ExecutionState, | ||
memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, | ||
uint_multiplication::MemoryData, | ||
}; | ||
|
||
#[repr(C)] | ||
#[derive(AlignedBorrow)] | ||
pub struct ArithmeticLogicCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> { | ||
pub io: ArithmeticLogicIoCols<T, NUM_LIMBS, LIMB_BITS>, | ||
pub aux: ArithmeticLogicAuxCols<T, NUM_LIMBS, LIMB_BITS>, | ||
} | ||
|
||
impl<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> | ||
ArithmeticLogicCols<T, NUM_LIMBS, LIMB_BITS> | ||
{ | ||
pub fn width() -> usize { | ||
ArithmeticLogicAuxCols::<T, NUM_LIMBS, LIMB_BITS>::width() | ||
+ ArithmeticLogicIoCols::<T, NUM_LIMBS, LIMB_BITS>::width() | ||
} | ||
} | ||
|
||
#[repr(C)] | ||
#[derive(AlignedBorrow)] | ||
pub struct ArithmeticLogicIoCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> { | ||
pub from_state: ExecutionState<T>, | ||
pub x: MemoryData<T, NUM_LIMBS, LIMB_BITS>, | ||
pub y: MemoryData<T, NUM_LIMBS, LIMB_BITS>, | ||
pub z: MemoryData<T, NUM_LIMBS, LIMB_BITS>, | ||
pub cmp_result: T, | ||
pub ptr_as: T, | ||
pub address_as: T, | ||
} | ||
|
||
impl<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> | ||
ArithmeticLogicIoCols<T, NUM_LIMBS, LIMB_BITS> | ||
{ | ||
pub fn width() -> usize { | ||
size_of::<ArithmeticLogicIoCols<u8, NUM_LIMBS, LIMB_BITS>>() | ||
} | ||
} | ||
|
||
#[repr(C)] | ||
#[derive(AlignedBorrow)] | ||
pub struct ArithmeticLogicAuxCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> { | ||
pub is_valid: T, | ||
pub x_sign: T, | ||
pub y_sign: T, | ||
|
||
// Opcode flags for different operations | ||
pub opcode_add_flag: T, | ||
pub opcode_sub_flag: T, | ||
pub opcode_lt_flag: T, | ||
pub opcode_eq_flag: T, | ||
pub opcode_xor_flag: T, | ||
pub opcode_and_flag: T, | ||
pub opcode_or_flag: T, | ||
pub opcode_slt_flag: T, | ||
|
||
/// Pointer read auxiliary columns for [z, x, y]. | ||
/// **Note** the ordering, which is designed to match the instruction order. | ||
pub read_ptr_aux_cols: [MemoryReadAuxCols<T, 1>; 3], | ||
pub read_x_aux_cols: MemoryReadAuxCols<T, NUM_LIMBS>, | ||
pub read_y_aux_cols: MemoryReadAuxCols<T, NUM_LIMBS>, | ||
pub write_z_aux_cols: MemoryWriteAuxCols<T, NUM_LIMBS>, | ||
pub write_cmp_aux_cols: MemoryWriteAuxCols<T, 1>, | ||
} | ||
|
||
impl<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> | ||
ArithmeticLogicAuxCols<T, NUM_LIMBS, LIMB_BITS> | ||
{ | ||
pub fn width() -> usize { | ||
size_of::<ArithmeticLogicAuxCols<u8, NUM_LIMBS, LIMB_BITS>>() | ||
} | ||
} |
Oops, something went wrong.