Skip to content

Commit

Permalink
feat: Adapter + Integration runtime impl for shift and LT (#507)
Browse files Browse the repository at this point in the history
* feat: Adapter + Integration runtime impl for shift and LT

* refactor: make instruction offsets closer for ALU opcodes
  • Loading branch information
stephenh-axiom-xyz authored Oct 8, 2024
1 parent 21140b5 commit 930d48d
Show file tree
Hide file tree
Showing 11 changed files with 578 additions and 6 deletions.
10 changes: 8 additions & 2 deletions vm/src/arch/chips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use crate::{
modular_addsub::ModularAddSubChip,
modular_multdiv::ModularMultDivChip,
new_alu::Rv32ArithmeticLogicChip,
new_lt::Rv32LessThanChip,
new_shift::Rv32ShiftChip,
program::{ExecutionError, Instruction, ProgramChip},
shift::ShiftChip,
ui::UiChip,
Expand Down Expand Up @@ -183,9 +185,11 @@ pub enum InstructionExecutorVariant<F: PrimeField32> {
Keccak256(Rc<RefCell<KeccakVmChip<F>>>),
ModularAddSub(Rc<RefCell<ModularAddSubChip<F, 32, 8>>>),
ModularMultDiv(Rc<RefCell<ModularMultDivChip<F, 63, 32, 8>>>),
ArithmeticLogicUnit256(Rc<RefCell<ArithmeticLogicChip<F, 32, 8>>>),
ArithmeticLogicUnitRv32(Rc<RefCell<Rv32ArithmeticLogicChip<F>>>),
ArithmeticLogicUnit256(Rc<RefCell<ArithmeticLogicChip<F, 32, 8>>>),
LessThanRv32(Rc<RefCell<Rv32LessThanChip<F>>>),
U256Multiplication(Rc<RefCell<UintMultiplicationChip<F, 32, 8>>>),
ShiftRv32(Rc<RefCell<Rv32ShiftChip<F>>>),
Shift256(Rc<RefCell<ShiftChip<F, 32, 8>>>),
Ui(Rc<RefCell<UiChip<F>>>),
CastF(Rc<RefCell<CastFChip<F>>>),
Expand All @@ -206,9 +210,11 @@ pub enum MachineChipVariant<F: PrimeField32> {
RangeTupleChecker(Arc<RangeTupleCheckerChip>),
Keccak256(Rc<RefCell<KeccakVmChip<F>>>),
ByteXor(Arc<XorLookupChip<8>>),
ArithmeticLogicUnit256(Rc<RefCell<ArithmeticLogicChip<F, 32, 8>>>),
ArithmeticLogicUnitRv32(Rc<RefCell<Rv32ArithmeticLogicChip<F>>>),
ArithmeticLogicUnit256(Rc<RefCell<ArithmeticLogicChip<F, 32, 8>>>),
LessThanRv32(Rc<RefCell<Rv32LessThanChip<F>>>),
U256Multiplication(Rc<RefCell<UintMultiplicationChip<F, 32, 8>>>),
ShiftRv32(Rc<RefCell<Rv32ShiftChip<F>>>),
Shift256(Rc<RefCell<ShiftChip<F, 32, 8>>>),
Ui(Rc<RefCell<UiChip<F>>>),
CastF(Rc<RefCell<CastFChip<F>>>),
Expand Down
23 changes: 23 additions & 0 deletions vm/src/arch/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,26 @@ pub enum AluOpcode {
OR,
AND,
}

#[derive(
Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, UsizeOpcode,
)]
#[opcode_offset = 0x305]
#[repr(usize)]
#[allow(non_camel_case_types)]
pub enum ShiftOpcode {
SLL,
SRL,
SRA,
}

#[derive(
Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, UsizeOpcode,
)]
#[opcode_offset = 0x310]
#[repr(usize)]
#[allow(non_camel_case_types)]
pub enum LessThanOpcode {
SLT,
SLTU,
}
2 changes: 2 additions & 0 deletions vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub mod memory;
pub mod modular_addsub;
pub mod modular_multdiv;
pub mod new_alu;
pub mod new_lt;
pub mod new_shift;
pub mod program;
/// SDK functions for running and proving programs in the VM.
#[cfg(feature = "sdk")]
Expand Down
163 changes: 163 additions & 0 deletions vm/src/new_lt/integration.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use std::{mem::size_of, sync::Arc};

use afs_derive::AlignedBorrow;
use afs_primitives::xor::{bus::XorBus, lookup::XorLookupChip};
use afs_stark_backend::interaction::InteractionBuilder;
use p3_air::{Air, AirBuilderWithPublicValues, BaseAir, PairBuilder};
use p3_field::{Field, PrimeField32};

use crate::{
arch::{
instructions::{LessThanOpcode, UsizeOpcode},
InstructionOutput, IntegrationInterface, MachineAdapter, MachineAdapterInterface,
MachineIntegration, Result,
},
program::Instruction,
};

#[repr(C)]
#[derive(AlignedBorrow)]
pub struct LessThanCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
pub b: [T; NUM_LIMBS],
pub c: [T; NUM_LIMBS],
pub cmp_result: T,

pub opcode_slt_flag: T,
pub opcode_sltu_flag: T,

pub x_sign: T,
pub y_sign: T,

// 1 at the most significant index i such that b[i] != c[i], otherwise 0. If such
// an i exists, diff_val = c[i] - b[i]
pub diff_marker: [T; LIMB_BITS],
pub diff_val: T,
}

impl<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> LessThanCols<T, NUM_LIMBS, LIMB_BITS> {
pub fn width() -> usize {
size_of::<LessThanCols<u8, NUM_LIMBS, LIMB_BITS>>()
}
}

#[derive(Copy, Clone, Debug)]
pub struct LessThanAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
pub bus: XorBus,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
for LessThanAir<NUM_LIMBS, LIMB_BITS>
{
fn width(&self) -> usize {
LessThanCols::<F, NUM_LIMBS, LIMB_BITS>::width()
}
}

impl<AB: InteractionBuilder, const NUM_LIMBS: usize, const LIMB_BITS: usize> Air<AB>
for LessThanAir<NUM_LIMBS, LIMB_BITS>
{
fn eval(&self, _builder: &mut AB) {
todo!();
}
}

#[derive(Debug)]
pub struct LessThanIntegration<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
pub air: LessThanAir<NUM_LIMBS, LIMB_BITS>,
pub xor_lookup_chip: Arc<XorLookupChip<LIMB_BITS>>,
offset: usize,
}

impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> LessThanIntegration<NUM_LIMBS, LIMB_BITS> {
pub fn new(xor_lookup_chip: Arc<XorLookupChip<LIMB_BITS>>, offset: usize) -> Self {
Self {
air: LessThanAir {
bus: xor_lookup_chip.bus(),
},
xor_lookup_chip,
offset,
}
}
}

impl<F: PrimeField32, A: MachineAdapter<F>, const NUM_LIMBS: usize, const LIMB_BITS: usize>
MachineIntegration<F, A> for LessThanIntegration<NUM_LIMBS, LIMB_BITS>
where
A::Interface<F>: MachineAdapterInterface<F>,
<A::Interface<F> as MachineAdapterInterface<F>>::Reads: Into<[[F; NUM_LIMBS]; 2]>,
<A::Interface<F> as MachineAdapterInterface<F>>::Writes: From<[F; NUM_LIMBS]>,
{
// TODO: update for trace generation
type Record = u32;
type Cols<T> = LessThanCols<T, NUM_LIMBS, LIMB_BITS>;
type Air = LessThanAir<NUM_LIMBS, LIMB_BITS>;

#[allow(clippy::type_complexity)]
fn execute_instruction(
&self,
instruction: &Instruction<F>,
from_pc: F,
reads: <A::Interface<F> as MachineAdapterInterface<F>>::Reads,
) -> Result<(InstructionOutput<F, A::Interface<F>>, Self::Record)> {
let Instruction { opcode, .. } = instruction;
let opcode = LessThanOpcode::from_usize(opcode - self.offset);

let data: [[F; NUM_LIMBS]; 2] = reads.into();
let x = data[0].map(|x| x.as_canonical_u32());
let y = data[1].map(|y| y.as_canonical_u32());
let (cmp_result, _diff_idx, _x_sign, _y_sign) =
solve_less_than::<NUM_LIMBS, LIMB_BITS>(opcode, &x, &y);

let mut writes = [0u32; NUM_LIMBS];
writes[0] = cmp_result as u32;

// Integration doesn't modify PC directly, so we let Adapter handle the increment
let output: InstructionOutput<F, A::Interface<F>> = InstructionOutput {
to_pc: from_pc,
writes: writes.map(F::from_canonical_u32).into(),
};

// TODO: send XorLookupChip requests
// TODO: create Record and return

Ok((output, 0))
}

fn get_opcode_name(&self, _opcode: usize) -> String {
todo!()
}

fn generate_trace_row(&self, _row_slice: &mut Self::Cols<F>, _record: Self::Record) {
todo!()
}

/// Returns `(to_pc, interface)`.
fn eval_primitive<AB: InteractionBuilder<F = F> + PairBuilder + AirBuilderWithPublicValues>(
_air: &Self::Air,
_builder: &mut AB,
_local: &Self::Cols<AB::Var>,
_local_adapter: &A::Cols<AB::Var>,
) -> IntegrationInterface<AB::Expr, A::Interface<AB::Expr>> {
todo!()
}

fn air(&self) -> Self::Air {
self.air
}
}

// Returns (cmp_result, diff_idx, x_sign, y_sign)
pub(super) fn solve_less_than<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
opcode: LessThanOpcode,
x: &[u32; NUM_LIMBS],
y: &[u32; NUM_LIMBS],
) -> (bool, usize, bool, bool) {
let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT;
let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT;
for i in (0..NUM_LIMBS).rev() {
if x[i] != y[i] {
return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign);
}
}
(false, 0, x_sign, y_sign)
}
10 changes: 10 additions & 0 deletions vm/src/new_lt/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use crate::arch::{MachineChipWrapper, Rv32AluAdapter};

mod integration;
pub use integration::*;

#[cfg(test)]
mod tests;

// TODO: Replace current ALU less than commands upon completion
pub type Rv32LessThanChip<F> = MachineChipWrapper<F, Rv32AluAdapter<F>, LessThanIntegration<4, 8>>;
41 changes: 41 additions & 0 deletions vm/src/new_lt/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use super::integration::solve_less_than;
use crate::arch::instructions::LessThanOpcode;

const RV32_NUM_LIMBS: usize = 4;
const RV32_LIMB_BITS: usize = 8;

#[test]
fn solve_sltu_sanity_test() {
let x: [u32; RV32_NUM_LIMBS] = [145, 34, 25, 205];
let y: [u32; RV32_NUM_LIMBS] = [73, 35, 25, 205];
let (cmp_result, diff_idx, x_sign, y_sign) =
solve_less_than::<RV32_NUM_LIMBS, RV32_LIMB_BITS>(LessThanOpcode::SLTU, &x, &y);
assert!(cmp_result);
assert_eq!(diff_idx, 1);
assert!(!x_sign); // unsigned
assert!(!y_sign); // unsigned
}

#[test]
fn solve_slt_same_sign_sanity_test() {
let x: [u32; RV32_NUM_LIMBS] = [145, 34, 25, 205];
let y: [u32; RV32_NUM_LIMBS] = [73, 35, 25, 205];
let (cmp_result, diff_idx, x_sign, y_sign) =
solve_less_than::<RV32_NUM_LIMBS, RV32_LIMB_BITS>(LessThanOpcode::SLT, &x, &y);
assert!(cmp_result);
assert_eq!(diff_idx, 1);
assert!(x_sign); // negative
assert!(y_sign); // negative
}

#[test]
fn solve_slt_diff_sign_sanity_test() {
let x: [u32; RV32_NUM_LIMBS] = [45, 35, 25, 55];
let y: [u32; RV32_NUM_LIMBS] = [173, 34, 25, 205];
let (cmp_result, diff_idx, x_sign, y_sign) =
solve_less_than::<RV32_NUM_LIMBS, RV32_LIMB_BITS>(LessThanOpcode::SLT, &x, &y);
assert!(!cmp_result);
assert_eq!(diff_idx, 3);
assert!(!x_sign); // positive
assert!(y_sign); // negative
}
Loading

0 comments on commit 930d48d

Please sign in to comment.