diff --git a/primitives/src/xor/lookup/mod.rs b/primitives/src/xor/lookup/mod.rs index 939f1aebca..464f51bcca 100644 --- a/primitives/src/xor/lookup/mod.rs +++ b/primitives/src/xor/lookup/mod.rs @@ -48,4 +48,12 @@ impl XorLookupChip { self.calc_xor(x, y) } + + pub fn clear(&self) { + for i in 0..(1 << M) { + for j in 0..(1 << M) { + self.count[i][j].store(0, std::sync::atomic::Ordering::Relaxed); + } + } + } } diff --git a/vm/src/alu/air.rs b/vm/src/alu/air.rs new file mode 100644 index 0000000000..01351445ec --- /dev/null +++ b/vm/src/alu/air.rs @@ -0,0 +1,145 @@ +use std::{array, borrow::Borrow}; + +use afs_primitives::{utils, xor::bus::XorBus}; +use afs_stark_backend::interaction::InteractionBuilder; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{AbstractField, Field}; +use p3_matrix::Matrix; + +use super::columns::ArithmeticLogicCols; +use crate::{ + arch::{bridge::ExecutionBridge, instructions::ALU_256_INSTRUCTIONS}, + memory::offline_checker::MemoryBridge, +}; + +#[derive(Copy, Clone, Debug)] +pub struct ArithmeticLogicAir { + pub(super) execution_bridge: ExecutionBridge, + pub(super) memory_bridge: MemoryBridge, + pub bus: XorBus, +} + +impl BaseAir + for ArithmeticLogicAir +{ + fn width(&self) -> usize { + ArithmeticLogicCols::::width() + } +} + +impl Air + for ArithmeticLogicAir +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + + let ArithmeticLogicCols::<_, NUM_LIMBS, LIMB_BITS> { io, aux } = (*local).borrow(); + builder.assert_bool(aux.is_valid); + + let flags = [ + aux.opcode_add_flag, + aux.opcode_sub_flag, + aux.opcode_lt_flag, + aux.opcode_eq_flag, + aux.opcode_xor_flag, + aux.opcode_and_flag, + aux.opcode_or_flag, + aux.opcode_slt_flag, + ]; + for flag in flags { + builder.assert_bool(flag); + } + + builder.assert_eq( + aux.is_valid, + flags + .iter() + .fold(AB::Expr::zero(), |acc, &flag| acc + flag.into()), + ); + + let x_limbs = &io.x.data; + let y_limbs = &io.y.data; + let z_limbs = 
&io.z.data; + + // For ADD, define carry[i] = (x[i] + y[i] + carry[i - 1] - z[i]) / 2^LIMB_BITS. If + // each carry[i] is boolean and 0 <= z[i] < 2^LIMB_BITS, it can be proven that + // z[i] = (x[i] + y[i] + carry[i - 1]) % 2^LIMB_BITS as necessary. The same holds for SUB when + // carry[i] is (z[i] + y[i] - x[i] + carry[i - 1]) / 2^LIMB_BITS. + let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::zero()); + let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::zero()); + let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse(); + + for i in 0..NUM_LIMBS { + // We explicitly separate the constraints for ADD and SUB in order to keep degree + // cubic. Because we constrain that the carry (which is arbitrary) is bool, if + // carry has degree larger than 1 the max-degree constraint could be at least 4. + carry_add[i] = AB::Expr::from(carry_divide) + * (x_limbs[i] + y_limbs[i] - z_limbs[i] + + if i > 0 { + carry_add[i - 1].clone() + } else { + AB::Expr::zero() + }); + builder + .when(aux.opcode_add_flag) + .assert_bool(carry_add[i].clone()); + carry_sub[i] = AB::Expr::from(carry_divide) + * (z_limbs[i] + y_limbs[i] - x_limbs[i] + + if i > 0 { + carry_sub[i - 1].clone() + } else { + AB::Expr::zero() + }); + builder + .when(aux.opcode_sub_flag + aux.opcode_lt_flag + aux.opcode_slt_flag) + .assert_bool(carry_sub[i].clone()); + } + + // For LT, cmp_result must be equal to the last carry. For SLT, cmp_result ^ x_sign ^ y_sign must + // be equal to the last carry. To keep every constraint at most cubic, aux.x_sign and + // aux.y_sign are constrained to be 0 when not computing an SLT. 
+ builder.assert_bool(aux.x_sign); + builder.assert_bool(aux.y_sign); + builder + .when(utils::not(aux.opcode_slt_flag)) + .assert_zero(aux.x_sign); + builder + .when(utils::not(aux.opcode_slt_flag)) + .assert_zero(aux.y_sign); + + let slt_xor = + (aux.opcode_lt_flag + aux.opcode_slt_flag) * io.cmp_result + aux.x_sign + aux.y_sign + - AB::Expr::from_canonical_u32(2) + * (io.cmp_result * aux.x_sign + + io.cmp_result * aux.y_sign + + aux.x_sign * aux.y_sign) + + AB::Expr::from_canonical_u32(4) * (io.cmp_result * aux.x_sign * aux.y_sign); + builder.assert_eq( + slt_xor, + (aux.opcode_lt_flag + aux.opcode_slt_flag) * carry_sub[NUM_LIMBS - 1].clone(), + ); + + // For EQ, z is filled with 0 except at the lowest index i such that x[i] != y[i]. If such an + // i exists z[i] is the inverse of x[i] - y[i], so sum_eq = cmp_result + sum_i (x[i] - y[i]) * z[i] must be 1. + let mut sum_eq: AB::Expr = io.cmp_result.into(); + for i in 0..NUM_LIMBS { + sum_eq += (x_limbs[i] - y_limbs[i]) * z_limbs[i]; + builder + .when(aux.opcode_eq_flag) + .assert_zero(io.cmp_result * (x_limbs[i] - y_limbs[i])); + } + builder + .when(aux.opcode_eq_flag) + .assert_zero(sum_eq - AB::Expr::one()); + + let expected_opcode = flags + .iter() + .zip(ALU_256_INSTRUCTIONS) + .fold(AB::Expr::zero(), |acc, (flag, opcode)| { + acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8) + }); + + self.eval_interactions(builder, io, aux, expected_opcode); + } +} diff --git a/vm/src/alu/bridge.rs b/vm/src/alu/bridge.rs new file mode 100644 index 0000000000..b41e90f6c4 --- /dev/null +++ b/vm/src/alu/bridge.rs @@ -0,0 +1,146 @@ +use afs_stark_backend::interaction::InteractionBuilder; +use itertools::izip; +use p3_field::AbstractField; + +use super::{ + air::ArithmeticLogicAir, + columns::{ArithmeticLogicAuxCols, ArithmeticLogicIoCols}, +}; +use crate::memory::MemoryAddress; + +impl ArithmeticLogicAir { + pub fn eval_interactions( + &self, + builder: &mut AB, + io: &ArithmeticLogicIoCols, + aux: &ArithmeticLogicAuxCols, + expected_opcode: 
AB::Expr, + ) { + let timestamp: AB::Var = io.from_state.timestamp; + let mut timestamp_delta: usize = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + timestamp + AB::F::from_canonical_usize(timestamp_delta - 1) + }; + + let range_check = + aux.opcode_add_flag + aux.opcode_sub_flag + aux.opcode_lt_flag + aux.opcode_slt_flag; + let bitwise = aux.opcode_xor_flag + aux.opcode_and_flag + aux.opcode_or_flag; + + // Read the values stored at the three operand pointers [z, x, y]; those values are + // themselves the addresses of the actual IO data. + for (ptr, value, mem_aux) in izip!( + [ + io.z.ptr_to_address, + io.x.ptr_to_address, + io.y.ptr_to_address + ], + [io.z.address, io.x.address, io.y.address], + &aux.read_ptr_aux_cols + ) { + self.memory_bridge + .read( + MemoryAddress::new(io.ptr_as, ptr), + [value], + timestamp_pp(), + mem_aux, + ) + .eval(builder, aux.is_valid); + } + + self.memory_bridge + .read( + MemoryAddress::new(io.address_as, io.x.address), + io.x.data, + timestamp_pp(), + &aux.read_x_aux_cols, + ) + .eval(builder, aux.is_valid); + + self.memory_bridge + .read( + MemoryAddress::new(io.address_as, io.y.address), + io.y.data, + timestamp_pp(), + &aux.read_y_aux_cols, + ) + .eval(builder, aux.is_valid); + + // Write the full z limbs; this interaction is enabled only for ADD, SUB and the bitwise opcodes: + self.memory_bridge + .write( + MemoryAddress::new(io.address_as, io.z.address), + io.z.data, + timestamp + AB::F::from_canonical_usize(timestamp_delta), + &aux.write_z_aux_cols, + ) + .eval( + builder, + aux.opcode_add_flag + aux.opcode_sub_flag + bitwise.clone(), + ); + + // Write the single cmp_result cell for LT, EQ and SLT, at the same address and timestamp as the z write above: + self.memory_bridge + .write( + MemoryAddress::new(io.address_as, io.z.address), + [io.cmp_result], + timestamp + AB::F::from_canonical_usize(timestamp_delta), + &aux.write_cmp_aux_cols, + ) + .eval( + builder, + aux.opcode_lt_flag + aux.opcode_eq_flag + aux.opcode_slt_flag, + ); + timestamp_delta += 1; + + self.execution_bridge + .execute_and_increment_pc( + expected_opcode, + [ + io.z.ptr_to_address, 
io.x.ptr_to_address, + io.y.ptr_to_address, + io.ptr_as, + io.address_as, + ], + io.from_state, + AB::F::from_canonical_usize(timestamp_delta), + ) + .eval(builder, aux.is_valid); + + // SLT only: send (sign << (LIMB_BITS - 1), top limb, top limb - (sign << (LIMB_BITS - 1))) on the XOR bus. NOTE(review): sign = 0 sends (0, limb, limb), which holds for any limb - confirm the top bit is forced elsewhere. + let x_sign_shifted = aux.x_sign * AB::F::from_canonical_u32(1 << (LIMB_BITS - 1)); + let y_sign_shifted = aux.y_sign * AB::F::from_canonical_u32(1 << (LIMB_BITS - 1)); + self.bus + .send( + x_sign_shifted.clone(), + io.x.data[NUM_LIMBS - 1], + io.x.data[NUM_LIMBS - 1] - x_sign_shifted, + ) + .eval(builder, aux.opcode_slt_flag); + self.bus + .send( + y_sign_shifted.clone(), + io.y.data[NUM_LIMBS - 1], + io.y.data[NUM_LIMBS - 1] - y_sign_shifted, + ) + .eval(builder, aux.opcode_slt_flag); + + // Per-limb XOR bus sends, enabled for every opcode except EQ: (z, z, 0) for ADD/SUB/LT/SLT, and (x, y, x ^ y expressed through z) for XOR/AND/OR + for i in 0..NUM_LIMBS { + let x = range_check.clone() * io.z.data[i] + bitwise.clone() * io.x.data[i]; + let y = range_check.clone() * io.z.data[i] + bitwise.clone() * io.y.data[i]; + let xor_res = aux.opcode_xor_flag * io.z.data[i] + + aux.opcode_and_flag + * (io.x.data[i] + io.y.data[i] + - (AB::Expr::from_canonical_u32(2) * io.z.data[i])) + + aux.opcode_or_flag + * ((AB::Expr::from_canonical_u32(2) * io.z.data[i]) + - io.x.data[i] + - io.y.data[i]); + self.bus + .send(x, y, xor_res) + .eval(builder, range_check.clone() + bitwise.clone()); + } + } +} diff --git a/vm/src/alu/columns.rs b/vm/src/alu/columns.rs new file mode 100644 index 0000000000..1bb964c7fd --- /dev/null +++ b/vm/src/alu/columns.rs @@ -0,0 +1,79 @@ +use std::mem::size_of; + +use afs_derive::AlignedBorrow; + +use crate::{ + arch::columns::ExecutionState, + memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, + uint_multiplication::MemoryData, +}; + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ArithmeticLogicCols { + pub io: ArithmeticLogicIoCols, + pub aux: ArithmeticLogicAuxCols, +} + +impl + ArithmeticLogicCols +{ + pub fn width() -> usize { + ArithmeticLogicAuxCols::::width() + + ArithmeticLogicIoCols::::width() + } +} + 
+#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ArithmeticLogicIoCols { + pub from_state: ExecutionState, + pub x: MemoryData, + pub y: MemoryData, + pub z: MemoryData, + pub cmp_result: T, + pub ptr_as: T, + pub address_as: T, +} + +impl + ArithmeticLogicIoCols +{ + pub fn width() -> usize { + size_of::>() + } +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ArithmeticLogicAuxCols { + pub is_valid: T, + pub x_sign: T, + pub y_sign: T, + + // Opcode flags for different operations + pub opcode_add_flag: T, + pub opcode_sub_flag: T, + pub opcode_lt_flag: T, + pub opcode_eq_flag: T, + pub opcode_xor_flag: T, + pub opcode_and_flag: T, + pub opcode_or_flag: T, + pub opcode_slt_flag: T, + + /// Pointer read auxiliary columns for [z, x, y]. + /// **Note** the ordering, which is designed to match the instruction order. + pub read_ptr_aux_cols: [MemoryReadAuxCols; 3], + pub read_x_aux_cols: MemoryReadAuxCols, + pub read_y_aux_cols: MemoryReadAuxCols, + pub write_z_aux_cols: MemoryWriteAuxCols, + pub write_cmp_aux_cols: MemoryWriteAuxCols, +} + +impl + ArithmeticLogicAuxCols +{ + pub fn width() -> usize { + size_of::>() + } +} diff --git a/vm/src/alu/mod.rs b/vm/src/alu/mod.rs new file mode 100644 index 0000000000..7903914ec0 --- /dev/null +++ b/vm/src/alu/mod.rs @@ -0,0 +1,286 @@ +use std::sync::Arc; + +use afs_primitives::xor::lookup::XorLookupChip; +use air::ArithmeticLogicAir; +use p3_field::PrimeField32; + +use crate::{ + arch::{ + bridge::ExecutionBridge, + bus::ExecutionBus, + chips::InstructionExecutor, + columns::ExecutionState, + instructions::{Opcode, ALU_256_INSTRUCTIONS}, + }, + memory::{MemoryChipRef, MemoryReadRecord, MemoryWriteRecord}, + program::{bridge::ProgramBus, ExecutionError, Instruction}, +}; + +mod air; +mod bridge; +mod columns; +mod trace; + +// pub use air::*; +pub use columns::*; + +#[cfg(test)] +mod tests; + +pub const ALU_CMP_INSTRUCTIONS: [Opcode; 3] = [Opcode::LT256, Opcode::EQ256, Opcode::SLT256]; +pub const 
ALU_ARITHMETIC_INSTRUCTIONS: [Opcode; 2] = [Opcode::ADD256, Opcode::SUB256]; +pub const ALU_BITWISE_INSTRUCTIONS: [Opcode; 3] = [Opcode::XOR256, Opcode::AND256, Opcode::OR256]; + +#[derive(Debug)] +pub enum WriteRecord { + Long(MemoryWriteRecord), + Bool(MemoryWriteRecord), +} + +#[derive(Debug)] +pub struct ArithmeticLogicRecord { + pub from_state: ExecutionState, + pub instruction: Instruction, + + pub x_ptr_read: MemoryReadRecord, + pub y_ptr_read: MemoryReadRecord, + pub z_ptr_read: MemoryReadRecord, + + pub x_read: MemoryReadRecord, + pub y_read: MemoryReadRecord, + pub z_write: WriteRecord, + + // sign of x and y if SLT, else should be 0 + pub x_sign: T, + pub y_sign: T, + + // empty if not bool instruction, else contents of this vector will be stored in z + pub cmp_buffer: Vec, +} + +#[derive(Debug)] +pub struct ArithmeticLogicChip { + pub air: ArithmeticLogicAir, + data: Vec>, + memory_chip: MemoryChipRef, + pub xor_lookup_chip: Arc>, +} + +impl + ArithmeticLogicChip +{ + pub fn new( + execution_bus: ExecutionBus, + program_bus: ProgramBus, + memory_chip: MemoryChipRef, + xor_lookup_chip: Arc>, + ) -> Self { + let memory_bridge = memory_chip.borrow().memory_bridge(); + Self { + air: ArithmeticLogicAir { + execution_bridge: ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bus: xor_lookup_chip.bus(), + }, + data: vec![], + memory_chip, + xor_lookup_chip, + } + } +} + +impl InstructionExecutor + for ArithmeticLogicChip +{ + fn execute( + &mut self, + instruction: Instruction, + from_state: ExecutionState, + ) -> Result, ExecutionError> { + let Instruction { + opcode, + op_a: a, + op_b: b, + op_c: c, + d, + e, + .. 
+ } = instruction.clone(); + assert!(ALU_256_INSTRUCTIONS.contains(&opcode)); + + let mut memory_chip = self.memory_chip.borrow_mut(); + debug_assert_eq!( + from_state.timestamp, + memory_chip.timestamp().as_canonical_u32() as usize + ); + + let [z_ptr_read, x_ptr_read, y_ptr_read] = + [a, b, c].map(|ptr_of_ptr| memory_chip.read_cell(d, ptr_of_ptr)); + let x_read = memory_chip.read::(e, x_ptr_read.value()); + let y_read = memory_chip.read::(e, y_ptr_read.value()); + + let x = x_read.data.map(|x| x.as_canonical_u32()); + let y = y_read.data.map(|x| x.as_canonical_u32()); + let (z, cmp) = solve_alu::(opcode, &x, &y); + + let z_write = if ALU_CMP_INSTRUCTIONS.contains(&opcode) { + WriteRecord::Bool(memory_chip.write_cell(e, z_ptr_read.value(), T::from_bool(cmp))) + } else { + WriteRecord::Long( + memory_chip.write::( + e, + z_ptr_read.value(), + z.clone() + .into_iter() + .map(T::from_canonical_u32) + .collect::>() + .try_into() + .unwrap(), + ), + ) + }; + + let mut x_sign = 0; + let mut y_sign = 0; + + if opcode == Opcode::SLT256 { + x_sign = x[NUM_LIMBS - 1] >> (LIMB_BITS - 1); + y_sign = y[NUM_LIMBS - 1] >> (LIMB_BITS - 1); + self.xor_lookup_chip + .request(x_sign * (1 << (LIMB_BITS - 1)), x[NUM_LIMBS - 1]); + self.xor_lookup_chip + .request(y_sign * (1 << (LIMB_BITS - 1)), y[NUM_LIMBS - 1]); + } + + if ALU_BITWISE_INSTRUCTIONS.contains(&opcode) { + for i in 0..NUM_LIMBS { + self.xor_lookup_chip.request(x[i], y[i]); + } + } else if opcode != Opcode::EQ256 { + for z_val in &z { + self.xor_lookup_chip.request(*z_val, *z_val); + } + } + + self.data + .push(ArithmeticLogicRecord:: { + from_state, + instruction: instruction.clone(), + x_ptr_read, + y_ptr_read, + z_ptr_read, + x_read, + y_read, + z_write, + x_sign: T::from_canonical_u32(x_sign), + y_sign: T::from_canonical_u32(y_sign), + cmp_buffer: if ALU_CMP_INSTRUCTIONS.contains(&opcode) { + z.into_iter().map(T::from_canonical_u32).collect() + } else { + vec![] + }, + }); + + Ok(ExecutionState { + pc: from_state.pc + 
1, + timestamp: memory_chip.timestamp().as_canonical_u32() as usize, + }) + } +} + +fn solve_alu( + opcode: Opcode, + x: &[u32], + y: &[u32], +) -> (Vec, bool) { + match opcode { + Opcode::ADD256 => solve_add::(x, y), + Opcode::SUB256 | Opcode::LT256 => solve_subtract::(x, y), + Opcode::EQ256 => solve_eq::(x, y), + Opcode::XOR256 => solve_xor::(x, y), + Opcode::AND256 => solve_and::(x, y), + Opcode::OR256 => solve_or::(x, y), + Opcode::SLT256 => { + let (z, cmp) = solve_subtract::(x, y); + ( + z, + cmp ^ (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) != 0) + ^ (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) != 0), + ) + } + _ => unreachable!(), + } +} + +fn solve_add( + x: &[u32], + y: &[u32], +) -> (Vec, bool) { + let mut z = vec![0u32; NUM_LIMBS]; + let mut carry = vec![0u32; NUM_LIMBS]; + for i in 0..NUM_LIMBS { + z[i] = x[i] + y[i] + if i > 0 { carry[i - 1] } else { 0 }; + carry[i] = z[i] >> LIMB_BITS; + z[i] &= (1 << LIMB_BITS) - 1; + } + (z, false) +} + +fn solve_subtract( + x: &[u32], + y: &[u32], +) -> (Vec, bool) { + let mut z = vec![0u32; NUM_LIMBS]; + let mut carry = vec![0u32; NUM_LIMBS]; + for i in 0..NUM_LIMBS { + let rhs = y[i] + if i > 0 { carry[i - 1] } else { 0 }; + if x[i] >= rhs { + z[i] = x[i] - rhs; + carry[i] = 0; + } else { + z[i] = x[i] + (1 << LIMB_BITS) - rhs; + carry[i] = 1; + } + } + (z, carry[NUM_LIMBS - 1] != 0) +} + +fn solve_eq( + x: &[u32], + y: &[u32], +) -> (Vec, bool) { + let mut z = vec![0u32; NUM_LIMBS]; + for i in 0..NUM_LIMBS { + if x[i] != y[i] { + z[i] = (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])) + .inverse() + .as_canonical_u32(); + return (z, false); + } + } + (z, true) +} + +fn solve_xor( + x: &[u32], + y: &[u32], +) -> (Vec, bool) { + let z = (0..NUM_LIMBS).map(|i| x[i] ^ y[i]).collect(); + (z, false) +} + +fn solve_and( + x: &[u32], + y: &[u32], +) -> (Vec, bool) { + let z = (0..NUM_LIMBS).map(|i| x[i] & y[i]).collect(); + (z, false) +} + +fn solve_or( + x: &[u32], + y: &[u32], +) -> (Vec, bool) { + let z = 
(0..NUM_LIMBS).map(|i| x[i] | y[i]).collect(); + (z, false) +} diff --git a/vm/src/alu/tests.rs b/vm/src/alu/tests.rs new file mode 100644 index 0000000000..2875520cee --- /dev/null +++ b/vm/src/alu/tests.rs @@ -0,0 +1,628 @@ +use std::{array, borrow::BorrowMut, iter, sync::Arc}; + +use afs_primitives::xor::lookup::XorLookupChip; +use afs_stark_backend::{utils::disable_debug_builder, verifier::VerificationError}; +use ax_sdk::{ + any_rap_vec, config::baby_bear_poseidon2::BabyBearPoseidon2Engine, engine::StarkFriEngine, + utils::create_seeded_rng, +}; +use p3_baby_bear::BabyBear; +use p3_field::{AbstractField, PrimeField32}; +use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use rand::{rngs::StdRng, Rng}; + +use super::{ + columns::ArithmeticLogicCols, solve_subtract, ArithmeticLogicChip, ALU_CMP_INSTRUCTIONS, +}; +use crate::{ + alu::solve_alu, + arch::{chips::MachineChip, instructions::Opcode, testing::MachineChipTestBuilder}, + core::BYTE_XOR_BUS, + program::Instruction, +}; + +type F = BabyBear; + +const NUM_LIMBS: usize = 32; +const LIMB_BITS: usize = 8; + +fn generate_long_number( + rng: &mut StdRng, +) -> Vec { + (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..1 << LIMB_BITS)) + .collect() +} + +#[allow(clippy::too_many_arguments)] +fn run_alu_rand_write_execute( + tester: &mut MachineChipTestBuilder, + chip: &mut ArithmeticLogicChip, + opcode: Opcode, + x: Vec, + y: Vec, + rng: &mut StdRng, +) { + let address_space_range = || 1usize..=2; + let address_range = || 0usize..1 << 29; + + let d = rng.gen_range(address_space_range()); + let e = rng.gen_range(address_space_range()); + + let x_address = rng.gen_range(address_range()); + let y_address = rng.gen_range(address_range()); + let res_address = rng.gen_range(address_range()); + let x_ptr_to_address = rng.gen_range(address_range()); + let y_ptr_to_address = rng.gen_range(address_range()); + let res_ptr_to_address = rng.gen_range(address_range()); + + let x_f = x + .clone() + .into_iter() + 
.map(F::from_canonical_u32) + .collect::>(); + let y_f = y + .clone() + .into_iter() + .map(F::from_canonical_u32) + .collect::>(); + + tester.write_cell(d, x_ptr_to_address, F::from_canonical_usize(x_address)); + tester.write_cell(d, y_ptr_to_address, F::from_canonical_usize(y_address)); + tester.write_cell(d, res_ptr_to_address, F::from_canonical_usize(res_address)); + tester.write::(e, x_address, x_f.as_slice().try_into().unwrap()); + tester.write::(e, y_address, y_f.as_slice().try_into().unwrap()); + + let (z, cmp) = solve_alu::(opcode, &x, &y); + tester.execute( + chip, + Instruction::from_usize( + opcode, + [res_ptr_to_address, x_ptr_to_address, y_ptr_to_address, d, e], + ), + ); + + if ALU_CMP_INSTRUCTIONS.contains(&opcode) { + assert_eq!([F::from_bool(cmp)], tester.read::<1>(e, res_address)) + } else { + assert_eq!( + z.into_iter().map(F::from_canonical_u32).collect::>(), + tester.read::(e, res_address) + ) + } +} + +/// Given a fake trace of a single operation, setup a chip and run the test. +/// We replace the "output" part of the trace, and we _may_ replace the interactions +/// based on the desired output. We check that it produces the error we expect. 
+#[allow(clippy::too_many_arguments)] +fn run_alu_negative_test( + opcode: Opcode, + x: Vec, + y: Vec, + z: Vec, + cmp_result: bool, + x_sign: u32, + y_sign: u32, + expected_error: VerificationError, +) { + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester: MachineChipTestBuilder = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + let mut rng = create_seeded_rng(); + run_alu_rand_write_execute( + &mut tester, + &mut chip, + opcode, + x.clone(), + y.clone(), + &mut rng, + ); + + let alu_air = chip.air; + let alu_trace = chip.generate_trace(); + + let mut alu_trace_row = alu_trace.row_slice(0).to_vec(); + let alu_trace_cols: &mut ArithmeticLogicCols = (*alu_trace_row).borrow_mut(); + alu_trace_cols.io.z.data = array::from_fn(|i| F::from_canonical_u32(z[i])); + alu_trace_cols.io.cmp_result = F::from_bool(cmp_result); + alu_trace_cols.aux.x_sign = F::from_canonical_u32(x_sign); + alu_trace_cols.aux.y_sign = F::from_canonical_u32(y_sign); + let alu_trace: p3_matrix::dense::DenseMatrix<_> = RowMajorMatrix::new( + alu_trace_row, + ArithmeticLogicCols::::width(), + ); + + let xor_lookup_air = xor_lookup_chip.air; + let xor_lookup_trace = xor_lookup_chip.generate_trace(); + + disable_debug_builder(); + let msg = format!( + "Expected verification to fail with {:?}, but it didn't", + &expected_error + ); + let result = BabyBearPoseidon2Engine::run_simple_test_no_pis( + &any_rap_vec![&alu_air, &xor_lookup_air], + vec![alu_trace, xor_lookup_trace], + ); + assert_eq!(result.err(), Some(expected_error), "{}", msg); +} + +#[test] +fn alu_add_rand_test() { + let num_ops: usize = 10; + let mut rng = create_seeded_rng(); + + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + 
tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + for _ in 0..num_ops { + let x = generate_long_number::(&mut rng); + let y = generate_long_number::(&mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::ADD256, x, y, &mut rng); + } + + let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn alu_add_out_of_range_negative_test() { + run_alu_negative_test( + Opcode::ADD256, + iter::once(250) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + iter::once(250) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + iter::once(500) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + false, + 0, + 0, + VerificationError::NonZeroCumulativeSum, + ); +} + +#[test] +fn alu_add_wrong_negative_test() { + run_alu_negative_test( + Opcode::ADD256, + iter::once(250) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + iter::once(250) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + iter::once(500 - (1 << 8)) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + false, + 0, + 0, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_sub_rand_test() { + let num_ops: usize = 10; + let mut rng = create_seeded_rng(); + + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + for _ in 0..num_ops { + let x = generate_long_number::(&mut rng); + let y = generate_long_number::(&mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::SUB256, x, y, &mut rng); + } + + let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn 
alu_sub_out_of_range_negative_test() { + run_alu_negative_test( + Opcode::SUB256, + iter::once(1) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + iter::once(2) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + iter::once(F::neg_one().as_canonical_u32()) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + false, + 0, + 0, + VerificationError::NonZeroCumulativeSum, + ); +} + +#[test] +fn alu_sub_wrong_negative_test() { + run_alu_negative_test( + Opcode::SUB256, + iter::once(1) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + iter::once(2) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + iter::once((1 << 8) - 1) + .chain(iter::repeat(0).take(NUM_LIMBS - 1)) + .collect(), + false, + 0, + 0, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_lt_rand_test() { + let num_ops: usize = 10; + let mut rng = create_seeded_rng(); + + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + for _ in 0..num_ops { + let x = generate_long_number::(&mut rng); + let y = generate_long_number::(&mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::LT256, x, y, &mut rng); + } + + let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn alu_lt_wrong_subtraction_test() { + run_alu_negative_test( + Opcode::LT256, + iter::once(65_000).chain(iter::repeat(0).take(31)).collect(), + iter::once(65_000).chain(iter::repeat(0).take(31)).collect(), + std::iter::once(1).chain(iter::repeat(0).take(31)).collect(), + false, + 0, + 0, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_lt_wrong_negative_test() { + run_alu_negative_test( + Opcode::LT256, + 
iter::once(1).chain(iter::repeat(0).take(31)).collect(), + iter::once(1).chain(iter::repeat(0).take(31)).collect(), + iter::once(0).chain(iter::repeat(0).take(31)).collect(), + true, + 0, + 0, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_lt_non_zero_sign_negative_test() { + run_alu_negative_test( + Opcode::LT256, + vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], + vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], + vec![0; NUM_LIMBS], + false, + 1, + 1, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_eq_rand_test() { + let num_ops: usize = 10; + let mut rng = create_seeded_rng(); + + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + for _ in 0..num_ops { + let x = generate_long_number::(&mut rng); + let y = generate_long_number::(&mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::EQ256, x, y, &mut rng); + } + + let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn alu_eq_wrong_negative_test() { + run_alu_negative_test( + Opcode::EQ256, + vec![0; 31].into_iter().chain(iter::once(123)).collect(), + vec![0; 31].into_iter().chain(iter::once(456)).collect(), + vec![0; 31].into_iter().chain(iter::once(0)).collect(), + true, + 0, + 0, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_xor_rand_test() { + let num_ops: usize = 10; + let mut rng = create_seeded_rng(); + + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + for _ in 0..num_ops { + let x = 
generate_long_number::(&mut rng); + let y = generate_long_number::(&mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::XOR256, x, y, &mut rng); + } + + let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn alu_xor_wrong_negative_test() { + run_alu_negative_test( + Opcode::XOR256, + vec![0; 31].into_iter().chain(iter::once(1)).collect(), + vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], + vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], + true, + 0, + 0, + VerificationError::NonZeroCumulativeSum, + ); +} + +#[test] +fn alu_and_rand_test() { + let num_ops: usize = 10; + let mut rng = create_seeded_rng(); + + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + for _ in 0..num_ops { + let x = generate_long_number::(&mut rng); + let y = generate_long_number::(&mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::AND256, x, y, &mut rng); + } + + let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn alu_and_wrong_negative_test() { + run_alu_negative_test( + Opcode::AND256, + vec![0; 31].into_iter().chain(iter::once(1)).collect(), + vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], + vec![0; NUM_LIMBS], + true, + 0, + 0, + VerificationError::NonZeroCumulativeSum, + ); +} + +#[test] +fn alu_or_rand_test() { + let num_ops: usize = 10; + let mut rng = create_seeded_rng(); + + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + for _ in 0..num_ops { 
+ let x = generate_long_number::(&mut rng); + let y = generate_long_number::(&mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::OR256, x, y, &mut rng); + } + + let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn alu_or_wrong_negative_test() { + run_alu_negative_test( + Opcode::OR256, + vec![0; NUM_LIMBS], + vec![(1 << LIMB_BITS) - 1; NUM_LIMBS - 1] + .into_iter() + .chain(iter::once((1 << LIMB_BITS) - 2)) + .collect(), + vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], + true, + 0, + 0, + VerificationError::NonZeroCumulativeSum, + ); +} + +#[test] +fn alu_slt_rand_test() { + let num_ops: usize = 10; + let mut rng = create_seeded_rng(); + + let xor_lookup_chip = Arc::new(XorLookupChip::::new(BYTE_XOR_BUS)); + let mut tester = MachineChipTestBuilder::default(); + let mut chip = ArithmeticLogicChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_chip(), + xor_lookup_chip.clone(), + ); + + for _ in 0..num_ops { + let x = generate_long_number::(&mut rng); + let y = generate_long_number::(&mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::SLT256, x, y, &mut rng); + } + + let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn alu_slt_pos_neg_sign_negative_test() { + let x = [0; NUM_LIMBS]; + let y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; + run_alu_negative_test( + Opcode::SLT256, + x.to_vec(), + y.to_vec(), + solve_subtract::(&x, &y).0, + true, + 0, + 1, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_slt_neg_pos_sign_negative_test() { + let x = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; + let y = [0; NUM_LIMBS]; + run_alu_negative_test( + Opcode::SLT256, + x.to_vec(), + y.to_vec(), + solve_subtract::(&x, &y).0, + false, + 1, + 0, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn 
alu_slt_both_pos_sign_negative_test() { + let x = [0; NUM_LIMBS]; + let mut y = [0; NUM_LIMBS]; + y[0] = 1; + run_alu_negative_test( + Opcode::SLT256, + x.to_vec(), + y.to_vec(), + solve_subtract::(&x, &y).0, + false, + 0, + 0, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_slt_both_neg_sign_negative_test() { + let x = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; + let mut y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; + y[0] = 1; + run_alu_negative_test( + Opcode::SLT256, + x.to_vec(), + y.to_vec(), + solve_subtract::(&x, &y).0, + true, + 1, + 1, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn alu_slt_wrong_sign_negative_test() { + let x = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; + let mut y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; + y[0] = 1; + run_alu_negative_test( + Opcode::SLT256, + x.to_vec(), + y.to_vec(), + solve_subtract::(&x, &y).0, + true, + 0, + 1, + VerificationError::NonZeroCumulativeSum, + ); +} + +#[test] +fn alu_slt_non_boolean_sign_negative_test() { + let x = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; + let mut y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; + y[0] = 1; + run_alu_negative_test( + Opcode::SLT256, + x.to_vec(), + y.to_vec(), + solve_subtract::(&x, &y).0, + false, + 2, + 1, + VerificationError::OodEvaluationMismatch, + ); +} diff --git a/vm/src/alu/trace.rs b/vm/src/alu/trace.rs new file mode 100644 index 0000000000..e86720c3d5 --- /dev/null +++ b/vm/src/alu/trace.rs @@ -0,0 +1,122 @@ +use std::{array, borrow::BorrowMut}; + +use afs_stark_backend::{config::StarkGenericConfig, rap::AnyRap}; +use p3_commit::PolynomialSpace; +use p3_field::PrimeField32; +use p3_matrix::dense::RowMajorMatrix; +use p3_uni_stark::Domain; + +use super::{ + columns::{ArithmeticLogicAuxCols, ArithmeticLogicCols, ArithmeticLogicIoCols}, + ArithmeticLogicChip, ArithmeticLogicRecord, WriteRecord, +}; +use crate::{ + arch::{chips::MachineChip, instructions::Opcode}, + memory::offline_checker::MemoryWriteAuxCols, + uint_multiplication::MemoryData, +}; + +impl 
MachineChip + for ArithmeticLogicChip +{ + fn generate_trace(self) -> RowMajorMatrix { + let aux_cols_factory = self.memory_chip.borrow().aux_cols_factory(); + + let width = self.trace_width(); + let height = self.data.len(); + let padded_height = height.next_power_of_two(); + let mut rows = vec![F::zero(); width * padded_height]; + + for (row, operation) in rows.chunks_mut(width).zip(self.data) { + let ArithmeticLogicRecord:: { + from_state, + instruction, + x_ptr_read, + y_ptr_read, + z_ptr_read, + x_read, + y_read, + z_write, + x_sign, + y_sign, + cmp_buffer, + } = operation; + + let row: &mut ArithmeticLogicCols = row.borrow_mut(); + + row.io = ArithmeticLogicIoCols { + from_state: from_state.map(F::from_canonical_usize), + x: MemoryData:: { + data: x_read.data, + address: x_read.pointer, + ptr_to_address: x_ptr_read.pointer, + }, + y: MemoryData:: { + data: y_read.data, + address: y_read.pointer, + ptr_to_address: y_ptr_read.pointer, + }, + z: match &z_write { + WriteRecord::Long(z) => MemoryData { + data: z.data, + address: z.pointer, + ptr_to_address: z_ptr_read.pointer, + }, + WriteRecord::Bool(z) => MemoryData { + data: array::from_fn(|i| cmp_buffer[i]), + address: z.pointer, + ptr_to_address: z_ptr_read.pointer, + }, + }, + cmp_result: match &z_write { + WriteRecord::Long(_) => F::zero(), + WriteRecord::Bool(z) => z.data[0], + }, + ptr_as: instruction.d, + address_as: instruction.e, + }; + + row.aux = ArithmeticLogicAuxCols { + is_valid: F::one(), + x_sign, + y_sign, + opcode_add_flag: F::from_bool(instruction.opcode == Opcode::ADD256), + opcode_sub_flag: F::from_bool(instruction.opcode == Opcode::SUB256), + opcode_lt_flag: F::from_bool(instruction.opcode == Opcode::LT256), + opcode_eq_flag: F::from_bool(instruction.opcode == Opcode::EQ256), + opcode_xor_flag: F::from_bool(instruction.opcode == Opcode::XOR256), + opcode_and_flag: F::from_bool(instruction.opcode == Opcode::AND256), + opcode_or_flag: F::from_bool(instruction.opcode == Opcode::OR256), + 
opcode_slt_flag: F::from_bool(instruction.opcode == Opcode::SLT256), + read_ptr_aux_cols: [z_ptr_read, x_ptr_read, y_ptr_read] + .map(|read| aux_cols_factory.make_read_aux_cols(read.clone())), + read_x_aux_cols: aux_cols_factory.make_read_aux_cols(x_read.clone()), + read_y_aux_cols: aux_cols_factory.make_read_aux_cols(y_read.clone()), + write_z_aux_cols: match &z_write { + WriteRecord::Long(z) => aux_cols_factory.make_write_aux_cols(z.clone()), + WriteRecord::Bool(_) => MemoryWriteAuxCols::disabled(), + }, + write_cmp_aux_cols: match &z_write { + WriteRecord::Long(_) => MemoryWriteAuxCols::disabled(), + WriteRecord::Bool(z) => aux_cols_factory.make_write_aux_cols(z.clone()), + }, + }; + } + RowMajorMatrix::new(rows, width) + } + + fn air(&self) -> Box> + where + Domain: PolynomialSpace, + { + Box::new(self.air) + } + + fn current_trace_height(&self) -> usize { + self.data.len() + } + + fn trace_width(&self) -> usize { + ArithmeticLogicCols::::width() + } +} diff --git a/vm/src/arch/chips.rs b/vm/src/arch/chips.rs index 84b11bd473..e7f93584fa 100644 --- a/vm/src/arch/chips.rs +++ b/vm/src/arch/chips.rs @@ -15,6 +15,7 @@ use p3_uni_stark::{Domain, StarkGenericConfig}; use strum_macros::IntoStaticStr; use crate::{ + alu::ArithmeticLogicChip, arch::columns::ExecutionState, castf::CastFChip, core::CoreChip, @@ -28,7 +29,6 @@ use crate::{ program::{ExecutionError, Instruction, ProgramChip}, shift::ShiftChip, ui::UiChip, - uint_arithmetic::UintArithmeticChip, uint_multiplication::UintMultiplicationChip, }; @@ -129,7 +129,7 @@ pub enum InstructionExecutorVariant { Keccak256(Rc>>), ModularAddSub(Rc>>), ModularMultDiv(Rc>>), - U256Arithmetic(Rc>>), + ArithmeticLogicUnit256(Rc>>), U256Multiplication(Rc>>), Shift256(Rc>>), Ui(Rc>>), @@ -151,7 +151,7 @@ pub enum MachineChipVariant { RangeTupleChecker(Arc), Keccak256(Rc>>), ByteXor(Arc>), - U256Arithmetic(Rc>>), + ArithmeticLogicUnit256(Rc>>), U256Multiplication(Rc>>), Shift256(Rc>>), Ui(Rc>>), diff --git 
a/vm/src/arch/instructions.rs b/vm/src/arch/instructions.rs index 01759784f1..b2a580e1e4 100644 --- a/vm/src/arch/instructions.rs +++ b/vm/src/arch/instructions.rs @@ -66,10 +66,10 @@ pub enum Opcode { MUL256 = 82, LT256 = 83, EQ256 = 84, - // XOR256 = 85, - // AND256 = 86, - // OR256 = 87, - // SLT256 = 88, + XOR256 = 85, + AND256 = 86, + OR256 = 87, + SLT256 = 88, SLL256 = 89, SRL256 = 90, SRA256 = 91, @@ -95,7 +95,8 @@ pub const CORE_INSTRUCTIONS: [Opcode; 17] = [ ]; pub const FIELD_ARITHMETIC_INSTRUCTIONS: [Opcode; 4] = [FADD, FSUB, FMUL, FDIV]; pub const FIELD_EXTENSION_INSTRUCTIONS: [Opcode; 4] = [FE4ADD, FE4SUB, BBE4MUL, BBE4DIV]; -pub const UINT256_ARITHMETIC_INSTRUCTIONS: [Opcode; 4] = [ADD256, SUB256, LT256, EQ256]; +pub const ALU_256_INSTRUCTIONS: [Opcode; 8] = + [ADD256, SUB256, LT256, EQ256, XOR256, AND256, OR256, SLT256]; pub const SHIFT_256_INSTRUCTIONS: [Opcode; 3] = [SLL256, SRL256, SRA256]; pub const UI_32_INSTRUCTIONS: [Opcode; 2] = [LUI, AUIPC]; diff --git a/vm/src/lib.rs b/vm/src/lib.rs index e7a30b5b04..cd230d9ec2 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -1,3 +1,4 @@ +pub mod alu; pub mod arch; pub mod castf; pub mod core; @@ -14,6 +15,5 @@ pub mod program; pub mod sdk; pub mod shift; pub mod ui; -pub mod uint_arithmetic; pub mod uint_multiplication; pub mod vm; diff --git a/vm/src/shift/mod.rs b/vm/src/shift/mod.rs index 2252adbeee..57ff33b282 100644 --- a/vm/src/shift/mod.rs +++ b/vm/src/shift/mod.rs @@ -6,7 +6,7 @@ use crate::{ bus::ExecutionBus, chips::InstructionExecutor, columns::ExecutionState, - instructions::{Opcode, UINT256_ARITHMETIC_INSTRUCTIONS}, + instructions::{Opcode, ALU_256_INSTRUCTIONS}, }, memory::MemoryChipRef, program::{ExecutionError, Instruction}, @@ -65,7 +65,7 @@ impl Instructio e, .. 
} = instruction.clone(); - assert!(UINT256_ARITHMETIC_INSTRUCTIONS.contains(&opcode)); + assert!(ALU_256_INSTRUCTIONS.contains(&opcode)); let mut memory_chip = self.memory_chip.borrow_mut(); debug_assert_eq!( diff --git a/vm/src/uint_arithmetic/air.rs b/vm/src/uint_arithmetic/air.rs deleted file mode 100644 index bca9a45d84..0000000000 --- a/vm/src/uint_arithmetic/air.rs +++ /dev/null @@ -1,142 +0,0 @@ -use std::borrow::Borrow; - -use afs_primitives::{utils, var_range::bus::VariableRangeCheckerBus}; -use afs_stark_backend::interaction::InteractionBuilder; -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; -use p3_matrix::Matrix; - -use super::{columns::UintArithmeticCols, num_limbs}; -use crate::{ - arch::{ - bridge::ExecutionBridge, - instructions::{Opcode, UINT256_ARITHMETIC_INSTRUCTIONS}, - }, - memory::offline_checker::MemoryBridge, -}; - -/// AIR for the uint addition circuit. ARG_SIZE is the size of the arguments in bits, and LIMB_SIZE is the size of the limbs in bits. 
-#[derive(Copy, Clone, Debug)] -pub struct UintArithmeticAir { - pub(super) execution_bridge: ExecutionBridge, - pub(super) memory_bridge: MemoryBridge, - - pub bus: VariableRangeCheckerBus, // to communicate with the range checker that checks that all limbs are < 2^LIMB_SIZE - pub base_op: Opcode, -} - -impl BaseAir - for UintArithmeticAir -{ - fn width(&self) -> usize { - UintArithmeticCols::::width() - } -} - -impl Air - for UintArithmeticAir -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - - let local = main.row_slice(0); - let local: &[AB::Var] = (*local).borrow(); - - let UintArithmeticCols { io, aux } = - UintArithmeticCols::::from_iterator( - local.iter().copied(), - ); - - let num_limbs = num_limbs::(); - - let flags = [ - aux.opcode_add_flag, - aux.opcode_sub_flag, - aux.opcode_lt_flag, - aux.opcode_eq_flag, - ]; - for flag in flags { - builder.assert_bool(flag); - } - - builder.assert_bool(aux.is_valid); - builder.assert_eq( - aux.is_valid, - flags - .iter() - .fold(AB::Expr::zero(), |acc, &flag| acc + flag.into()), - ); - - let x_limbs = &io.x.data; - let y_limbs = &io.y.data; - let z_limbs = &io.z.data; - - for i in 0..num_limbs { - // If we need to perform an arithmetic operation, we will use "buffer" - // as a "carry/borrow" vector. We refer to it as "carry" in this section. - - // For addition, we have the following: - // z[i] + carry[i] * 2^LIMB_SIZE = x[i] + y[i] + carry[i - 1] - // For subtraction, we have the following: - // z[i] = x[i] - y[i] - carry[i - 1] + carry[i] * 2^LIMB_SIZE - // Separating the summands with the same sign from the others, we get: - // z[i] - x[i] = \pm (y[i] + carry[i - 1] - carry[i] * 2^LIMB_SIZE) - - // Or another way to think about it: we essentially either check that - // z = x + y, or that x = z + y; and "carry" is always the carry of - // the addition. So it is natural that x and z are separated from - // everything else. 
- - // lhs = +rhs if opcode_add_flag = 1, - // lhs = -rhs if opcode_sub_flag = 1 or opcode_lt_flag = 1. - let lhs = y_limbs[i] - + if i > 0 { - aux.buffer[i - 1].into() - } else { - AB::Expr::zero() - } - - aux.buffer[i] * AB::Expr::from_canonical_u32(1 << LIMB_SIZE); - let rhs = z_limbs[i] - x_limbs[i]; - builder - .when(aux.opcode_add_flag) - .assert_eq(lhs.clone(), rhs.clone()); - builder - .when(aux.opcode_sub_flag + aux.opcode_lt_flag) - .assert_eq(lhs.clone(), -rhs.clone()); - - builder - .when(utils::not(aux.opcode_eq_flag)) - .assert_bool(aux.buffer[i]); - } - - // If we wanted LT, then cmp_result must equal the last carry. - builder - .when(aux.opcode_lt_flag) - .assert_zero(io.cmp_result - aux.buffer[num_limbs - 1]); - // If we wanted EQ, we will do as we would do for checking a single number, - // but we will use "buffer" vector for inverses. - // Namely, we check that: - // - cmp_result * (x[i] - y[i]) = 0, - // - cmp_result + sum_{i < num_limbs} (x[i] - y[i]) * buffer[i] = 1. 
- let mut sum_eq: AB::Expr = io.cmp_result.into(); - for i in 0..num_limbs { - sum_eq += (x_limbs[i] - y_limbs[i]) * aux.buffer[i]; - - builder - .when(aux.opcode_eq_flag) - .assert_zero(io.cmp_result * (x_limbs[i] - y_limbs[i])); - } - builder - .when(aux.opcode_eq_flag) - .assert_zero(sum_eq - AB::Expr::one()); - - let expected_opcode = flags - .iter() - .zip(UINT256_ARITHMETIC_INSTRUCTIONS) - .fold(AB::Expr::zero(), |acc, (flag, opcode)| { - acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8) - }); - - self.eval_interactions(builder, io, aux, expected_opcode); - } -} diff --git a/vm/src/uint_arithmetic/bridge.rs b/vm/src/uint_arithmetic/bridge.rs deleted file mode 100644 index 47e0e6b3f4..0000000000 --- a/vm/src/uint_arithmetic/bridge.rs +++ /dev/null @@ -1,115 +0,0 @@ -use afs_stark_backend::interaction::InteractionBuilder; -use itertools::izip; -use p3_field::AbstractField; - -use super::{ - air::UintArithmeticAir, - columns::{UintArithmeticAuxCols, UintArithmeticIoCols}, -}; -use crate::memory::MemoryAddress; - -impl UintArithmeticAir { - pub fn eval_interactions( - &self, - builder: &mut AB, - io: UintArithmeticIoCols, - aux: UintArithmeticAuxCols, - expected_opcode: AB::Expr, - ) { - let mut timestamp_delta = AB::Expr::zero(); - - let timestamp: AB::Var = io.from_state.timestamp; - - // Read the operand pointer's values, which are themselves pointers - // for the actual IO data. 
- for (ptr, value, mem_aux) in izip!( - [io.z.ptr, io.x.ptr, io.y.ptr], - [io.z.address, io.x.address, io.y.address], - &aux.read_ptr_aux_cols - ) { - self.memory_bridge - .read( - MemoryAddress::new(io.d, ptr), // all use addr space d - [value], - timestamp + timestamp_delta.clone(), - mem_aux, - ) - .eval(builder, aux.is_valid); - - timestamp_delta += AB::Expr::one(); - } - - // Memory read for x data - self.memory_bridge - .read( - MemoryAddress::new(io.x.address_space, io.x.address), - io.x.data.try_into().unwrap_or_else(|_| unreachable!()), - timestamp + timestamp_delta.clone(), - &aux.read_x_aux_cols, - ) - .eval(builder, aux.is_valid); - timestamp_delta += AB::Expr::one(); - - // Memory read for y data - self.memory_bridge - .read( - MemoryAddress::new(io.y.address_space, io.y.address), - io.y.data.try_into().unwrap_or_else(|_| unreachable!()), - timestamp + timestamp_delta.clone(), - &aux.read_y_aux_cols, - ) - .eval(builder, aux.is_valid); - timestamp_delta += AB::Expr::one(); - - // Special handling for writing output z data: - let enabled = aux.opcode_add_flag + aux.opcode_sub_flag; - self.memory_bridge - .write( - MemoryAddress::new(io.z.address_space, io.z.address), - io.z.data - .clone() - .try_into() - .unwrap_or_else(|_| unreachable!()), - timestamp + timestamp_delta.clone(), - &aux.write_z_aux_cols, - ) - .eval(builder, enabled.clone()); - timestamp_delta += enabled; - - let enabled = aux.opcode_lt_flag + aux.opcode_eq_flag; - self.memory_bridge - .write( - MemoryAddress::new(io.z.address_space, io.z.address), - [io.cmp_result], - timestamp + timestamp_delta.clone(), - &aux.write_cmp_aux_cols, - ) - .eval(builder, enabled.clone()); - timestamp_delta += enabled; - - self.execution_bridge - .execute_and_increment_pc( - expected_opcode, - [ - io.z.ptr, - io.x.ptr, - io.y.ptr, - io.d, - io.z.address_space, - io.x.address_space, - io.y.address_space, - ], - io.from_state, - timestamp_delta, - ) - .eval(builder, aux.is_valid); - - // Chip-specific 
interactions - for z in io.z.data.iter() { - self.bus.range_check(*z, LIMB_SIZE).eval( - builder, - aux.opcode_add_flag + aux.opcode_sub_flag + aux.opcode_lt_flag, - ); - } - } -} diff --git a/vm/src/uint_arithmetic/columns.rs b/vm/src/uint_arithmetic/columns.rs deleted file mode 100644 index 54df20172d..0000000000 --- a/vm/src/uint_arithmetic/columns.rs +++ /dev/null @@ -1,229 +0,0 @@ -use std::iter; - -use super::{num_limbs, NUM_LIMBS}; -use crate::{ - arch::columns::ExecutionState, - memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; - -pub struct UintArithmeticCols { - pub io: UintArithmeticIoCols, - pub aux: UintArithmeticAuxCols, -} - -#[derive(Default)] -pub struct UintArithmeticIoCols { - pub from_state: ExecutionState, - pub x: MemoryData, - pub y: MemoryData, - pub z: MemoryData, - /// The pointer address space - pub d: T, - pub cmp_result: T, -} - -pub struct MemoryData { - pub data: Vec, - pub address_space: T, - pub address: T, - /// Pointer whose value is `address`. All pointers use same address space `d`. 
- pub ptr: T, -} - -impl Default - for MemoryData -{ - fn default() -> Self { - let num_limbs = num_limbs::(); - Self { - data: vec![Default::default(); num_limbs], - address_space: Default::default(), - address: Default::default(), - ptr: Default::default(), - } - } -} - -impl MemoryData { - pub fn from_iterator(mut iter: impl Iterator) -> Self { - let num_limbs = num_limbs::(); - Self { - data: iter.by_ref().take(num_limbs).collect(), - address_space: iter.next().unwrap(), - address: iter.next().unwrap(), - ptr: iter.next().unwrap(), - } - } - - pub fn flatten(&self) -> impl Iterator { - self.data - .iter() - .chain(iter::once(&self.address_space)) - .chain(iter::once(&self.address)) - .chain(iter::once(&self.ptr)) - } -} - -pub struct UintArithmeticAuxCols { - pub is_valid: T, - pub opcode_add_flag: T, // 1 if z_limbs should contain the result of addition - pub opcode_sub_flag: T, // 1 if z_limbs should contain the result of subtraction (means that opcode is SUB or LT) - pub opcode_lt_flag: T, // 1 if opcode is LT - pub opcode_eq_flag: T, // 1 if opcode is EQ - // buffer is the carry of the addition/subtraction, - // or may serve as a single-nonzero-inverse helper vector for EQ256. - // Refer to air.rs for more details. - pub buffer: Vec, - - /// Pointer read auxiliary columns for [z, x, y]. - /// **Note** the ordering, which is designed to match the instruction order. - pub read_ptr_aux_cols: [MemoryReadAuxCols; 3], - pub read_x_aux_cols: MemoryReadAuxCols, - pub read_y_aux_cols: MemoryReadAuxCols, - pub write_z_aux_cols: MemoryWriteAuxCols, - pub write_cmp_aux_cols: MemoryWriteAuxCols, -} - -impl - UintArithmeticCols -{ - pub fn from_iterator(mut iter: impl Iterator) -> Self { - let io = UintArithmeticIoCols::::from_iterator(iter.by_ref()); - let aux = UintArithmeticAuxCols::::from_iterator(iter.by_ref()); - - Self { io, aux } - } - - pub fn flatten(&self) -> Vec { - [self.io.flatten(), self.aux.flatten()].concat() - } - - // TODO get rid of width somehow? 
- pub const fn width() -> usize { - UintArithmeticIoCols::::width() - + UintArithmeticAuxCols::::width() - } -} - -impl - UintArithmeticIoCols -{ - pub const fn width() -> usize { - 3 * num_limbs::() + 9 + 3 + 1 - } - - pub fn from_iterator(mut iter: impl Iterator) -> Self { - let from_state = ExecutionState::from_iter(iter.by_ref()); - let x = MemoryData::from_iterator(iter.by_ref()); - let y = MemoryData::from_iterator(iter.by_ref()); - let z = MemoryData::from_iterator(iter.by_ref()); - let d = iter.next().unwrap(); - let cmp_result = iter.next().unwrap(); - - Self { - from_state, - x, - y, - z, - d, - cmp_result, - } - } - - pub fn flatten(&self) -> Vec { - iter::once(&self.from_state.pc) - .chain(iter::once(&self.from_state.timestamp)) - .chain(self.x.flatten()) - .chain(self.y.flatten()) - .chain(self.z.flatten()) - .chain(iter::once(&self.d)) - .chain(iter::once(&self.cmp_result)) - .cloned() - .collect() - } -} - -impl - UintArithmeticAuxCols -{ - pub const fn width() -> usize { - let num_limbs = num_limbs::(); - 3 * MemoryReadAuxCols::::width() - + MemoryReadAuxCols::::width() - + MemoryReadAuxCols::::width() - + MemoryWriteAuxCols::::width() - + MemoryWriteAuxCols::::width() - + (5 + num_limbs) - } - - pub fn from_iterator(mut iter: impl Iterator) -> Self { - let num_limbs = num_limbs::(); - - let is_valid = iter.next().unwrap(); - let opcode_add_flag = iter.next().unwrap(); - let opcode_sub_flag = iter.next().unwrap(); - let opcode_lt_flag = iter.next().unwrap(); - let opcode_eq_flag = iter.next().unwrap(); - let buffer = iter.by_ref().take(num_limbs).collect(); - - let width_for_cell = MemoryReadAuxCols::::width(); - let read_ptr_aux_cols = [(); 3].map(|_| { - MemoryReadAuxCols::::from_slice( - &iter.by_ref().take(width_for_cell).collect::>(), - ) - }); - let width = MemoryReadAuxCols::::width(); - let read_x_slice = iter.by_ref().take(width).collect::>(); - let read_y_slice = iter.by_ref().take(width).collect::>(); - let write_z_slice = { - let width = 
MemoryWriteAuxCols::::width(); - iter.by_ref().take(width).collect::>() - }; - let write_cmp_slice = { - let width = MemoryWriteAuxCols::::width(); - iter.by_ref().take(width).collect::>() - }; - - let read_x_aux_cols = MemoryReadAuxCols::::from_slice(&read_x_slice); - let read_y_aux_cols = MemoryReadAuxCols::::from_slice(&read_y_slice); - let write_z_aux_cols = MemoryWriteAuxCols::::from_slice(&write_z_slice); - let write_cmp_aux_cols = MemoryWriteAuxCols::::from_slice(&write_cmp_slice); - - Self { - is_valid, - opcode_add_flag, - opcode_sub_flag, - opcode_lt_flag, - opcode_eq_flag, - buffer, - read_ptr_aux_cols, - read_x_aux_cols, - read_y_aux_cols, - write_z_aux_cols, - write_cmp_aux_cols, - } - } - - pub fn flatten(&self) -> Vec { - let our_cols = iter::once(&self.is_valid) - .chain(iter::once(&self.opcode_add_flag)) - .chain(iter::once(&self.opcode_sub_flag)) - .chain(iter::once(&self.opcode_lt_flag)) - .chain(iter::once(&self.opcode_eq_flag)) - .chain(self.buffer.iter()) - .cloned() - .collect::>(); - let memory_aux_cols = [ - self.read_ptr_aux_cols - .iter() - .flat_map(|c| c.clone().flatten()) - .collect::>(), - self.read_x_aux_cols.clone().flatten(), - self.read_y_aux_cols.clone().flatten(), - self.write_z_aux_cols.clone().flatten(), - self.write_cmp_aux_cols.clone().flatten(), - ] - .concat(); - [our_cols, memory_aux_cols].concat() - } -} diff --git a/vm/src/uint_arithmetic/mod.rs b/vm/src/uint_arithmetic/mod.rs deleted file mode 100644 index 3e0140a53e..0000000000 --- a/vm/src/uint_arithmetic/mod.rs +++ /dev/null @@ -1,284 +0,0 @@ -use std::{marker::PhantomData, sync::Arc}; - -use afs_primitives::var_range::VariableRangeCheckerChip; -use air::UintArithmeticAir; -use itertools::Itertools; -use p3_field::PrimeField32; - -use crate::{ - arch::{ - bridge::ExecutionBridge, - bus::ExecutionBus, - chips::InstructionExecutor, - columns::ExecutionState, - instructions::{Opcode, UINT256_ARITHMETIC_INSTRUCTIONS}, - }, - memory::{MemoryChipRef, MemoryReadRecord, 
MemoryWriteRecord}, - program::{bridge::ProgramBus, ExecutionError, Instruction}, -}; - -#[cfg(test)] -pub mod tests; - -pub mod air; -pub mod bridge; -pub mod columns; -pub mod trace; - -pub const NUM_LIMBS: usize = 32; // This is used in some places where const generics are hard to use. - // Of course, TODO make it something normal - -pub const fn num_limbs() -> usize { - (ARG_SIZE + LIMB_SIZE - 1) / LIMB_SIZE -} - -#[derive(Debug)] -pub enum WriteRecord { - Uint(MemoryWriteRecord), - Short(MemoryWriteRecord), -} - -#[derive(Debug)] -pub struct UintArithmeticRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - - pub x_ptr_read: MemoryReadRecord, - pub y_ptr_read: MemoryReadRecord, - pub z_ptr_read: MemoryReadRecord, - - pub x_read: MemoryReadRecord, - pub y_read: MemoryReadRecord, - pub z_write: WriteRecord, - - // this may be redundant because we can extract it from z_write, - // but it's not always the case - pub result: Vec, - - pub buffer: Vec, -} - -#[derive(Debug)] -pub struct UintArithmeticChip { - pub air: UintArithmeticAir, - data: Vec>, - memory_chip: MemoryChipRef, - pub range_checker_chip: Arc, -} - -impl - UintArithmeticChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_chip: MemoryChipRef, - ) -> Self { - let range_checker_chip = memory_chip.borrow().range_checker.clone(); - let memory_bridge = memory_chip.borrow().memory_bridge(); - let bus = range_checker_chip.bus(); - assert!( - bus.range_max_bits >= LIMB_SIZE, - "range_max_bits {} < LIMB_SIZE {}", - bus.range_max_bits, - LIMB_SIZE - ); - Self { - air: UintArithmeticAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus, - base_op: Opcode::ADD256, - }, - data: vec![], - memory_chip, - range_checker_chip, - } - } -} - -impl InstructionExecutor - for UintArithmeticChip -{ - fn execute( - &mut self, - instruction: Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let 
Instruction { - opcode, - op_a: a, - op_b: b, - op_c: c, - d, - e, - op_f: f, - op_g: g, - .. - } = instruction.clone(); - assert!(UINT256_ARITHMETIC_INSTRUCTIONS.contains(&opcode)); - - let mut memory_chip = self.memory_chip.borrow_mut(); - - debug_assert_eq!( - from_state.timestamp, - memory_chip.timestamp().as_canonical_u32() as usize - ); - - let [z_ptr_read, x_ptr_read, y_ptr_read] = - [a, b, c].map(|ptr_of_ptr| memory_chip.read_cell(d, ptr_of_ptr)); - - let x_read = memory_chip.read::(f, x_ptr_read.value()); - let y_read = memory_chip.read::(g, y_ptr_read.value()); - - let x = x_read.data.map(|x| x.as_canonical_u32()); - let y = y_read.data.map(|x| x.as_canonical_u32()); - let (z, residue) = UintArithmetic::::solve(opcode, (&x, &y)); - let CalculationResidue { result, buffer } = residue; - - let z_address_space = e; - let z_write: WriteRecord = match z { - CalculationResult::Uint(limbs) => { - let to_write = limbs - .iter() - .map(|x| T::from_canonical_u32(*x)) - .collect::>(); - WriteRecord::Uint(memory_chip.write::( - z_address_space, - z_ptr_read.value(), - to_write.try_into().unwrap(), - )) - } - CalculationResult::Short(res) => { - WriteRecord::Short(memory_chip.write_cell(e, z_ptr_read.value(), T::from_bool(res))) - } - }; - - for elem in result.iter() { - self.range_checker_chip.add_count(*elem, LIMB_SIZE); - } - - self.data.push(UintArithmeticRecord { - from_state, - instruction: instruction.clone(), - x_ptr_read, - y_ptr_read, - z_ptr_read, - x_read, - y_read, - z_write, - result: result.into_iter().map(T::from_canonical_u32).collect_vec(), - buffer: buffer.into_iter().map(T::from_canonical_u32).collect_vec(), - }); - - Ok(ExecutionState { - pc: from_state.pc + 1, - timestamp: memory_chip.timestamp().as_canonical_u32() as usize, - }) - } -} - -pub enum CalculationResult { - Uint(Vec), - Short(bool), -} - -pub struct CalculationResidue { - pub result: Vec, - pub buffer: Vec, -} - -pub struct UintArithmetic { - _marker: PhantomData, -} -impl - 
UintArithmetic -{ - pub fn solve( - opcode: Opcode, - (x, y): (&[u32], &[u32]), - ) -> (CalculationResult, CalculationResidue) { - match opcode { - Opcode::ADD256 => { - let (result, carry) = Self::add(x, y); - ( - CalculationResult::Uint(result.clone()), - CalculationResidue { - result, - buffer: carry, - }, - ) - } - Opcode::SUB256 => { - let (result, carry) = Self::subtract(x, y); - ( - CalculationResult::Uint(result.clone()), - CalculationResidue { - result, - buffer: carry, - }, - ) - } - Opcode::LT256 => { - let (diff, carry) = Self::subtract(x, y); - let cmp_result = *carry.last().unwrap() == 1; - ( - CalculationResult::Short(cmp_result), - CalculationResidue { - result: diff, - buffer: carry, - }, - ) - } - Opcode::EQ256 => { - let num_limbs = num_limbs::(); - let mut inverse = vec![0u32; num_limbs]; - for i in 0..num_limbs { - if x[i] != y[i] { - inverse[i] = (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])) - .inverse() - .as_canonical_u32(); - break; - } - } - ( - CalculationResult::Short(x == y), - CalculationResidue { - result: Default::default(), - buffer: inverse, - }, - ) - } - _ => unreachable!(), - } - } - - fn add(x: &[u32], y: &[u32]) -> (Vec, Vec) { - let num_limbs = num_limbs::(); - let mut result = vec![0u32; num_limbs]; - let mut carry = vec![0u32; num_limbs]; - for i in 0..num_limbs { - result[i] = x[i] + y[i] + if i > 0 { carry[i - 1] } else { 0 }; - carry[i] = result[i] >> LIMB_SIZE; - result[i] &= (1 << LIMB_SIZE) - 1; - } - (result, carry) - } - - fn subtract(x: &[u32], y: &[u32]) -> (Vec, Vec) { - let num_limbs = num_limbs::(); - let mut result = vec![0u32; num_limbs]; - let mut carry = vec![0u32; num_limbs]; - for i in 0..num_limbs { - let rhs = y[i] + if i > 0 { carry[i - 1] } else { 0 }; - if x[i] >= rhs { - result[i] = x[i] - rhs; - carry[i] = 0; - } else { - result[i] = x[i] + (1 << LIMB_SIZE) - rhs; - carry[i] = 1; - } - } - (result, carry) - } -} diff --git a/vm/src/uint_arithmetic/tests.rs 
b/vm/src/uint_arithmetic/tests.rs deleted file mode 100644 index 051c1a89a5..0000000000 --- a/vm/src/uint_arithmetic/tests.rs +++ /dev/null @@ -1,610 +0,0 @@ -use afs_stark_backend::{utils::disable_debug_builder, verifier::VerificationError}; -use ax_sdk::{ - any_rap_vec, config::baby_bear_poseidon2::BabyBearPoseidon2Engine, engine::StarkFriEngine, - utils::create_seeded_rng, -}; -use p3_baby_bear::BabyBear; -use p3_field::{AbstractField, Field, PrimeField32}; -use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use rand::{rngs::StdRng, Rng}; - -use super::{ - columns::UintArithmeticCols, CalculationResult, UintArithmetic, UintArithmeticChip, NUM_LIMBS, -}; -use crate::{ - arch::{chips::MachineChip, instructions::Opcode, testing::MachineChipTestBuilder}, - program::Instruction, -}; - -type F = BabyBear; - -const OPCODES_ARITH: [Opcode; 2] = [Opcode::ADD256, Opcode::SUB256]; - -fn generate_uint_number( - rng: &mut StdRng, -) -> Vec { - assert_eq!(ARG_SIZE % LIMB_SIZE, 0); - - (0..ARG_SIZE / LIMB_SIZE) - .map(|_| rng.gen_range(0..1 << LIMB_SIZE)) - .collect() -} - -#[test] -fn uint_arithmetic_rand_air_test() { - const ARG_SIZE: usize = 256; - const LIMB_SIZE: usize = 8; - let num_ops: usize = 15; - let address_space_range = || 1usize..=2; - let address_range = || 0usize..1 << 29; - - let mut tester = MachineChipTestBuilder::default(); - let mut chip = UintArithmeticChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_chip(), - ); - - let mut rng = create_seeded_rng(); - - for _ in 0..num_ops { - let opcode = OPCODES_ARITH[rng.gen_range(0..OPCODES_ARITH.len())]; - let operand1 = generate_uint_number::(&mut rng); - let operand2 = generate_uint_number::(&mut rng); - - let ptr_as = rng.gen_range(address_space_range()); // d - let result_as = rng.gen_range(address_space_range()); // e - let as1 = rng.gen_range(address_space_range()); // f - let as2 = rng.gen_range(address_space_range()); // g - let address1 = rng.gen_range(address_range()); - let 
address2 = rng.gen_range(address_range()); - let address1_ptr = rng.gen_range(address_range()); - let address2_ptr = rng.gen_range(address_range()); - let result_ptr = rng.gen_range(address_range()); - let result_address = rng.gen_range(address_range()); - - let operand1_f = operand1 - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - let operand2_f = operand2 - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - - tester.write::(as1, address1, operand1_f.as_slice().try_into().unwrap()); - tester.write_cell(ptr_as, address1_ptr, F::from_canonical_usize(address1)); - tester.write::(as2, address2, operand2_f.as_slice().try_into().unwrap()); - tester.write_cell(ptr_as, address2_ptr, F::from_canonical_usize(address2)); - tester.write_cell(ptr_as, result_ptr, F::from_canonical_usize(result_address)); - - let result = - UintArithmetic::::solve(opcode, (&operand1, &operand2)); - - tester.execute( - &mut chip, - Instruction::from_usize( - opcode, - [ - result_ptr, - address1_ptr, - address2_ptr, - ptr_as, - result_as, - as1, - as2, - ], - ), - ); - match result.0 { - CalculationResult::Uint(result) => { - assert_eq!( - result - .into_iter() - .map(F::from_canonical_u32) - .collect::>(), - tester.read::(result_as, result_address) - ) - } - CalculationResult::Short(_) => unreachable!(), - } - } - - let tester = tester.build().load(chip).finalize(); - - tester.simple_test().expect("Verification failed"); -} - -/// Given a fake trace of a single operation, setup a chip and run the test. -/// We replace the "output" part of the trace, and we _may_ replace the interactions -/// based on the desired output. We check that it produces the error we expect. 
-#[allow(clippy::too_many_arguments)] -fn run_bad_uint_arithmetic_test( - op: Opcode, - x: Vec, - y: Vec, - z: Vec, - buffer: Vec, - cmp_result: bool, - replace_interactions: bool, - expected_error: VerificationError, -) { - let mut tester = MachineChipTestBuilder::default(); - let mut chip = UintArithmeticChip::<256, 8, F>::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_chip(), - ); - - let x_f = x - .iter() - .map(|v| F::from_canonical_u32(*v)) - .collect::>(); - let y_f = y - .iter() - .map(|v| F::from_canonical_u32(*v)) - .collect::>(); - let ptr_as = 1; - let mem_as = 2; - tester.write::(mem_as, 0, x_f.as_slice().try_into().unwrap()); - tester.write_cell(ptr_as, 2 * NUM_LIMBS, F::from_canonical_usize(0)); - tester.write::(mem_as, NUM_LIMBS, y_f.as_slice().try_into().unwrap()); - tester.write_cell( - ptr_as, - 2 * NUM_LIMBS + 1, - F::from_canonical_usize(NUM_LIMBS), - ); - tester.write_cell(ptr_as, 0, F::from_canonical_usize(0)); - - tester.execute( - &mut chip, - Instruction::from_usize( - op, - [ - 0, // result address ptr - 2 * NUM_LIMBS, // x address ptr - 2 * NUM_LIMBS + 1, // y address ptr - ptr_as, - 3, // result as - mem_as, // x as - mem_as, // y as - ], - ), - ); - - if let CalculationResult::Uint(_) = UintArithmetic::<256, 8, F>::solve(op, (&x, &y)).0 { - if replace_interactions { - chip.range_checker_chip.clear(); - for limb in z.iter() { - chip.range_checker_chip.add_count(*limb, 8); - } - } - } - - let air = chip.air; - let range_checker = chip.range_checker_chip.clone(); - let range_air = range_checker.air; - let trace = chip.generate_trace(); - let row = trace.row_slice(0).to_vec(); - let mut cols = UintArithmeticCols::<256, 8, F>::from_iterator(&mut row.into_iter()); - cols.io.z.data = z.into_iter().map(F::from_canonical_u32).collect(); - cols.aux.buffer = buffer.into_iter().map(F::from_canonical_u32).collect(); - cols.io.cmp_result = F::from_bool(cmp_result); - let trace = RowMajorMatrix::new(cols.flatten(), 
UintArithmeticCols::<256, 8, F>::width()); - - let range_trace = range_checker.generate_trace(); - - disable_debug_builder(); - let msg = format!( - "Expected verification to fail with {:?}, but it didn't", - &expected_error - ); - let result = BabyBearPoseidon2Engine::run_simple_test_no_pis( - &any_rap_vec![&air, &range_air], - vec![trace, range_trace], - ); - assert_eq!(result.err(), Some(expected_error), "{}", msg); -} - -#[test] -fn uint_add_wrong_carry_air_test() { - run_bad_uint_arithmetic_test( - Opcode::ADD256, - std::iter::once(1) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(1) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(3) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(1) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - false, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn uint_add_out_of_range_air_test() { - run_bad_uint_arithmetic_test( - Opcode::ADD256, - std::iter::once(250) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(250) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(500) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(0) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - false, - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn uint_add_wrong_addition_air_test() { - run_bad_uint_arithmetic_test( - Opcode::ADD256, - std::iter::once(250) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(250) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(500 - (1 << 8)) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(0) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - false, - VerificationError::OodEvaluationMismatch, - ); -} - -// We 
NEED to check that the carry is 0 or 1 -#[test] -fn uint_add_invalid_carry_air_test() { - let bad_carry = F::from_canonical_u32(1 << 8).inverse().as_canonical_u32(); - - run_bad_uint_arithmetic_test( - Opcode::ADD256, - vec![0; NUM_LIMBS - 1] - .into_iter() - .chain(std::iter::once(1)) - .collect(), - vec![0; NUM_LIMBS - 1] - .into_iter() - .chain(std::iter::once(1)) - .collect(), - vec![0; NUM_LIMBS - 1] - .into_iter() - .chain(std::iter::once(1)) - .collect(), - vec![0; NUM_LIMBS - 1] - .into_iter() - .chain(std::iter::once(bad_carry)) - .collect(), - false, - true, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn uint_sub_out_of_range_air_test() { - run_bad_uint_arithmetic_test( - Opcode::SUB256, - std::iter::once(1) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(2) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(F::neg_one().as_canonical_u32()) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(0) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - false, - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn uint_sub_wrong_subtraction_air_test() { - run_bad_uint_arithmetic_test( - Opcode::SUB256, - std::iter::once(1) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(2) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once((1 << 8) - 1) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - std::iter::once(0) - .chain(std::iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - false, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn uint_sub_invalid_carry_air_test() { - let bad_carry = F::from_canonical_u32(1 << 8).inverse().as_canonical_u32(); - - run_bad_uint_arithmetic_test( - Opcode::SUB256, - vec![0; NUM_LIMBS - 1] - .into_iter() - .chain(std::iter::once(1)) - .collect(), - vec![0; NUM_LIMBS - 1] - 
.into_iter() - .chain(std::iter::once(1)) - .collect(), - vec![0; NUM_LIMBS - 1] - .into_iter() - .chain(std::iter::once(1)) - .collect(), - vec![0; NUM_LIMBS - 1] - .into_iter() - .chain(std::iter::once(bad_carry)) - .collect(), - false, - true, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn uint_lt_rand_air_test() { - const ARG_SIZE: usize = 256; - const LIMB_SIZE: usize = 8; - let num_ops: usize = 15; - let address_space_range = || 1usize..=2; - let address_range = || 0usize..1 << 29; - - let mut tester = MachineChipTestBuilder::default(); - let mut chip = UintArithmeticChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_chip(), - ); - - let mut rng = create_seeded_rng(); - - for _ in 0..num_ops { - let opcode = Opcode::LT256; - let operand1 = generate_uint_number::(&mut rng); - let operand2 = generate_uint_number::(&mut rng); - - let ptr_as = rng.gen_range(address_space_range()); // d - let result_as = rng.gen_range(address_space_range()); // e - let as1 = rng.gen_range(address_space_range()); // f - let as2 = rng.gen_range(address_space_range()); // g - let address1 = rng.gen_range(address_range()); - let address2 = rng.gen_range(address_range()); - let address1_ptr = rng.gen_range(address_range()); - let address2_ptr = rng.gen_range(address_range()); - let result_ptr = rng.gen_range(address_range()); - let result_address = rng.gen_range(address_range()); - - let operand1_f = operand1 - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - let operand2_f = operand2 - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - - tester.write::(as1, address1, operand1_f.as_slice().try_into().unwrap()); - tester.write_cell(ptr_as, address1_ptr, F::from_canonical_usize(address1)); - tester.write::(as2, address2, operand2_f.as_slice().try_into().unwrap()); - tester.write_cell(ptr_as, address2_ptr, F::from_canonical_usize(address2)); - tester.write_cell(ptr_as, result_ptr, 
F::from_canonical_usize(result_address)); - - let result = - UintArithmetic::::solve(opcode, (&operand1, &operand2)); - - tester.execute( - &mut chip, - Instruction::from_usize( - opcode, - [ - result_ptr, - address1_ptr, - address2_ptr, - ptr_as, - result_as, - as1, - as2, - ], - ), - ); - match result.0 { - CalculationResult::Uint(_) => unreachable!(), - CalculationResult::Short(result) => { - assert_eq!( - [F::from_bool(result)], - tester.read::<1>(result_as, result_address) - ) - } - } - } - - let tester = tester.build().load(chip).finalize(); - - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn uint_eq_rand_air_test() { - const ARG_SIZE: usize = 256; - const LIMB_SIZE: usize = 8; - let num_ops: usize = 15; - let address_space_range = || 1usize..=2; - let address_range = || 0usize..1 << 29; - - let mut tester = MachineChipTestBuilder::default(); - let mut chip = UintArithmeticChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_chip(), - ); - - let mut rng = create_seeded_rng(); - - for _ in 0..num_ops { - let opcode = Opcode::EQ256; - let operand1 = generate_uint_number::(&mut rng); - let operand2 = if rng.gen_bool(0.5) { - generate_uint_number::(&mut rng) - } else { - operand1.clone() - }; - - let ptr_as = rng.gen_range(address_space_range()); // d - let result_as = rng.gen_range(address_space_range()); // e - let as1 = rng.gen_range(address_space_range()); // f - let as2 = rng.gen_range(address_space_range()); // g - let address1 = rng.gen_range(address_range()); - let address2 = rng.gen_range(address_range()); - let address1_ptr = rng.gen_range(address_range()); - let address2_ptr = rng.gen_range(address_range()); - let result_ptr = rng.gen_range(address_range()); - let result_address = rng.gen_range(address_range()); - - let operand1_f = operand1 - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - let operand2_f = operand2 - .clone() - .into_iter() - .map(F::from_canonical_u32) - 
.collect::>(); - - tester.write::(as1, address1, operand1_f.as_slice().try_into().unwrap()); - tester.write_cell(ptr_as, address1_ptr, F::from_canonical_usize(address1)); - tester.write::(as2, address2, operand2_f.as_slice().try_into().unwrap()); - tester.write_cell(ptr_as, address2_ptr, F::from_canonical_usize(address2)); - tester.write_cell(ptr_as, result_ptr, F::from_canonical_usize(result_address)); - - let result = - UintArithmetic::::solve(opcode, (&operand1, &operand2)); - - tester.execute( - &mut chip, - Instruction::from_usize( - opcode, - [ - result_ptr, - address1_ptr, - address2_ptr, - ptr_as, - result_as, - as1, - as2, - ], - ), - ); - match result.0 { - CalculationResult::Uint(_) => unreachable!(), - CalculationResult::Short(result) => { - assert_eq!( - [F::from_bool(result)], - tester.read::<1>(result_as, result_address) - ) - } - } - } - - let tester = tester.build().load(chip).finalize(); - - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn uint_lt_wrong_subtraction_test() { - run_bad_uint_arithmetic_test( - Opcode::LT256, - std::iter::once(65_000) - .chain(std::iter::repeat(0).take(31)) - .collect(), - std::iter::once(65_000) - .chain(std::iter::repeat(0).take(31)) - .collect(), - std::iter::once(1) - .chain(std::iter::repeat(0).take(31)) - .collect(), - std::iter::once(0) - .chain(std::iter::repeat(0).take(31)) - .collect(), - false, - false, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn uint_lt_wrong_carry_test() { - run_bad_uint_arithmetic_test( - Opcode::LT256, - vec![0; 31] - .into_iter() - .chain(std::iter::once(65_000)) - .collect(), - vec![0; 31] - .into_iter() - .chain(std::iter::once(65_000)) - .collect(), - vec![0; 31].into_iter().chain(std::iter::once(0)).collect(), - vec![0; 31].into_iter().chain(std::iter::once(1)).collect(), - true, - false, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn uint_eq_wrong_positive_test() { - run_bad_uint_arithmetic_test( - Opcode::EQ256, - 
vec![0; 31] - .into_iter() - .chain(std::iter::once(123)) - .collect(), - vec![0; 31] - .into_iter() - .chain(std::iter::once(456)) - .collect(), - vec![0; 31].into_iter().chain(std::iter::once(0)).collect(), - vec![0; 31].into_iter().chain(std::iter::once(0)).collect(), - true, - false, - VerificationError::OodEvaluationMismatch, - ); -} diff --git a/vm/src/uint_arithmetic/trace.rs b/vm/src/uint_arithmetic/trace.rs deleted file mode 100644 index 12a73815cd..0000000000 --- a/vm/src/uint_arithmetic/trace.rs +++ /dev/null @@ -1,155 +0,0 @@ -use std::array::from_fn; - -use afs_stark_backend::{config::StarkGenericConfig, rap::AnyRap}; -use p3_commit::PolynomialSpace; -use p3_field::PrimeField32; -use p3_matrix::dense::RowMajorMatrix; -use p3_uni_stark::Domain; - -use super::{ - columns::{MemoryData, UintArithmeticAuxCols, UintArithmeticCols, UintArithmeticIoCols}, - num_limbs, UintArithmeticChip, WriteRecord, -}; -use crate::{ - arch::{chips::MachineChip, instructions::Opcode}, - memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; - -impl MachineChip - for UintArithmeticChip -{ - fn generate_trace(self) -> RowMajorMatrix { - let aux_cols_factory = self.memory_chip.borrow().aux_cols_factory(); - let num_limbs = num_limbs::(); - let rows = self - .data - .iter() - .map(|operation| { - { - let super::UintArithmeticRecord:: { - from_state, - instruction, - x_ptr_read, - y_ptr_read, - z_ptr_read, - x_read, - y_read, - z_write, - result, - buffer, - } = operation; - - UintArithmeticCols { - io: UintArithmeticIoCols { - from_state: from_state.map(F::from_canonical_usize), - x: MemoryData:: { - data: x_read.data.to_vec(), - address_space: x_read.address_space, - address: x_read.pointer, - ptr: x_ptr_read.pointer, - }, - y: MemoryData { - data: y_read.data.to_vec(), - address_space: y_read.address_space, - address: y_read.pointer, - ptr: y_ptr_read.pointer, - }, - z: match &z_write { - WriteRecord::Uint(z) => MemoryData { - data: z.data.to_vec(), - 
address_space: z.address_space, - address: z.pointer, - ptr: z_ptr_read.pointer, - }, - WriteRecord::Short(z) => MemoryData { - data: result - .iter() - .cloned() - .chain(std::iter::repeat(F::zero())) - .take(num_limbs) - .collect(), - address_space: z.address_space, - address: z.pointer, - ptr: z_ptr_read.pointer, - }, - }, - d: instruction.d, - cmp_result: match &z_write { - WriteRecord::Uint(_) => F::zero(), - WriteRecord::Short(z) => z.data[0], - }, - }, - aux: UintArithmeticAuxCols { - is_valid: F::one(), - opcode_add_flag: F::from_bool(instruction.opcode == Opcode::ADD256), - opcode_sub_flag: F::from_bool(instruction.opcode == Opcode::SUB256), - opcode_lt_flag: F::from_bool(instruction.opcode == Opcode::LT256), - opcode_eq_flag: F::from_bool(instruction.opcode == Opcode::EQ256), - buffer: buffer.clone(), - read_ptr_aux_cols: [z_ptr_read, x_ptr_read, y_ptr_read] - .map(|read| aux_cols_factory.make_read_aux_cols(read.clone())), - read_x_aux_cols: aux_cols_factory.make_read_aux_cols(x_read.clone()), - read_y_aux_cols: aux_cols_factory.make_read_aux_cols(y_read.clone()), - write_z_aux_cols: match &z_write { - WriteRecord::Uint(z) => { - aux_cols_factory.make_write_aux_cols(z.clone()) - } - WriteRecord::Short(_) => MemoryWriteAuxCols::disabled(), - }, - write_cmp_aux_cols: match &z_write { - WriteRecord::Uint(_) => MemoryWriteAuxCols::disabled(), - WriteRecord::Short(z) => { - aux_cols_factory.make_write_aux_cols(z.clone()) - } - }, - }, - } - } - .flatten() - }) - .collect::>(); - - let height = rows.len(); - let padded_height = height.next_power_of_two(); - - let blank_row = UintArithmeticCols:: { - io: Default::default(), - aux: UintArithmeticAuxCols { - is_valid: Default::default(), - opcode_add_flag: Default::default(), - opcode_sub_flag: Default::default(), - opcode_lt_flag: Default::default(), - opcode_eq_flag: Default::default(), - buffer: vec![Default::default(); num_limbs], - read_ptr_aux_cols: from_fn(|_| MemoryReadAuxCols::disabled()), - 
read_x_aux_cols: MemoryReadAuxCols::disabled(), - read_y_aux_cols: MemoryReadAuxCols::disabled(), - write_z_aux_cols: MemoryWriteAuxCols::disabled(), - write_cmp_aux_cols: MemoryWriteAuxCols::disabled(), - }, - } - .flatten(); - let width = blank_row.len(); - - let mut padded_rows = rows; - - padded_rows.extend(std::iter::repeat(blank_row).take(padded_height - height)); - - RowMajorMatrix::new(padded_rows.concat(), width) - } - - fn air(&self) -> Box> - where - Domain: PolynomialSpace, - { - Box::new(self.air) - } - - fn current_trace_height(&self) -> usize { - self.data.len() - } - - fn trace_width(&self) -> usize { - UintArithmeticCols::::width() - } -} diff --git a/vm/src/vm/segment.rs b/vm/src/vm/segment.rs index 2291d77699..2ac65fccc6 100644 --- a/vm/src/vm/segment.rs +++ b/vm/src/vm/segment.rs @@ -26,13 +26,14 @@ use super::{ VmCycleTracker, VmMetrics, }; use crate::{ + alu::ArithmeticLogicChip, arch::{ bus::ExecutionBus, chips::{InstructionExecutor, InstructionExecutorVariant, MachineChip, MachineChipVariant}, columns::ExecutionState, instructions::{ - Opcode, CORE_INSTRUCTIONS, FIELD_ARITHMETIC_INSTRUCTIONS, FIELD_EXTENSION_INSTRUCTIONS, - SHIFT_256_INSTRUCTIONS, UINT256_ARITHMETIC_INSTRUCTIONS, UI_32_INSTRUCTIONS, + Opcode, ALU_256_INSTRUCTIONS, CORE_INSTRUCTIONS, FIELD_ARITHMETIC_INSTRUCTIONS, + FIELD_EXTENSION_INSTRUCTIONS, SHIFT_256_INSTRUCTIONS, UI_32_INSTRUCTIONS, }, }, castf::CastFChip, @@ -50,7 +51,6 @@ use crate::{ program::{bridge::ProgramBus, ExecutionError, Program, ProgramChip}, shift::ShiftChip, ui::UiChip, - uint_arithmetic::UintArithmeticChip, uint_multiplication::UintMultiplicationChip, }; @@ -102,6 +102,7 @@ impl ExecutionSegment { let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, config.memory_config.decomp); let range_checker = Arc::new(VariableRangeCheckerChip::new(range_bus)); + let byte_xor_chip = Arc::new(XorLookupChip::new(BYTE_XOR_BUS)); let memory_chip = Rc::new(RefCell::new(MemoryChip::with_volatile_memory( 
memory_bus, @@ -175,7 +176,6 @@ impl ExecutionSegment { chips.push(MachineChipVariant::Poseidon2(poseidon2_chip.clone())); } if config.keccak_enabled { - let byte_xor_chip = Arc::new(XorLookupChip::new(BYTE_XOR_BUS)); let keccak_chip = Rc::new(RefCell::new(KeccakVmChip::new( execution_bus, program_bus, @@ -184,7 +184,6 @@ impl ExecutionSegment { ))); assign!([Opcode::KECCAK256], keccak_chip); chips.push(MachineChipVariant::Keccak256(keccak_chip)); - chips.push(MachineChipVariant::ByteXor(byte_xor_chip)); } if config.modular_addsub_enabled { let mod_addsub_coord: ModularAddSubChip = ModularAddSubChip::new( @@ -232,13 +231,16 @@ impl ExecutionSegment { } // Modular multiplication also depends on U256 arithmetic. if config.modular_multdiv_enabled || config.u256_arithmetic_enabled { - let u256_chip = Rc::new(RefCell::new(UintArithmeticChip::new( + let u256_chip = Rc::new(RefCell::new(ArithmeticLogicChip::new( execution_bus, program_bus, memory_chip.clone(), + byte_xor_chip.clone(), ))); - chips.push(MachineChipVariant::U256Arithmetic(u256_chip.clone())); - assign!(UINT256_ARITHMETIC_INSTRUCTIONS, u256_chip); + chips.push(MachineChipVariant::ArithmeticLogicUnit256( + u256_chip.clone(), + )); + assign!(ALU_256_INSTRUCTIONS, u256_chip); } if config.u256_multiplication_enabled { let range_tuple_bus = @@ -298,6 +300,7 @@ impl ExecutionSegment { )); chips.push(MachineChipVariant::Secp256k1Double(secp256k1_double_chip)); } + chips.push(MachineChipVariant::ByteXor(byte_xor_chip)); // Most chips have a reference to the memory chip, and the memory chip has a reference to // the range checker chip. chips.push(MachineChipVariant::Memory(memory_chip.clone()));