Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[chore] use expr builder for modular chips #1212

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 64 additions & 217 deletions extensions/algebra/circuit/src/modular_chip/addsub.rs
Original file line number Diff line number Diff line change
@@ -1,232 +1,79 @@
use std::{cell::RefCell, rc::Rc, sync::Arc};
use std::{
cell::RefCell,
rc::Rc,
sync::{Arc, Mutex},
};

use itertools::Itertools;
use num_bigint_dig::BigUint;
use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
use openvm_circuit::arch::{
instructions::UsizeOpcode, AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface,
DynArray, MinimalInstruction, Result, VmAdapterInterface, VmCoreAir, VmCoreChip,
};
use openvm_circuit_primitives::{
var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
SubAir, TraceSubRowGenerator,
};
use openvm_instructions::instruction::Instruction;
use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory};
use openvm_circuit_derive::{InstructionExecutor, Stateful};
use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip};
use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
use openvm_mod_circuit_builder::{
utils::{biguint_to_limbs_vec, limbs_to_biguint},
ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExprCols, FieldVariable,
ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable,
};
use openvm_stark_backend::{
interaction::InteractionBuilder,
p3_air::BaseAir,
p3_field::{Field, FieldAlgebra, PrimeField32},
rap::BaseAirWithPublicValues,
};
use serde::Deserialize;
use serde_with::serde_derive::Serialize;

/// The number of limbs and limb bits are determined at runtime.
#[derive(Clone)]
pub struct ModularAddSubCoreAir {
pub expr: FieldExpr,
pub offset: usize,
}

impl ModularAddSubCoreAir {
pub fn new(
config: ExprBuilderConfig,
range_bus: VariableRangeCheckerBus,
offset: usize,
) -> Self {
config.check_valid();
let builder = ExprBuilder::new(config, range_bus.range_max_bits);
let builder = Rc::new(RefCell::new(builder));
let x1 = ExprBuilder::new_input(builder.clone());
let x2 = ExprBuilder::new_input(builder.clone());
let x3 = x1.clone() + x2.clone();
let x4 = x1.clone() - x2.clone();
let is_add_flag = builder.borrow_mut().new_flag();
let is_sub_flag = builder.borrow_mut().new_flag();
let x5 = FieldVariable::select(is_sub_flag, &x4, &x1);
let mut x6 = FieldVariable::select(is_add_flag, &x3, &x5);
x6.save();
let builder = builder.borrow().clone();

let expr = FieldExpr::new(builder, range_bus, true);
Self { expr, offset }
}
}

impl<F: Field> BaseAir<F> for ModularAddSubCoreAir {
fn width(&self) -> usize {
BaseAir::<F>::width(&self.expr)
}
use openvm_rv32_adapters::Rv32VecHeapAdapterChip;
use openvm_stark_backend::p3_field::PrimeField32;

pub fn addsub_expr(
config: ExprBuilderConfig,
range_bus: VariableRangeCheckerBus,
) -> (FieldExpr, usize, usize) {
config.check_valid();
let builder = ExprBuilder::new(config, range_bus.range_max_bits);
let builder = Rc::new(RefCell::new(builder));

let x1 = ExprBuilder::new_input(builder.clone());
let x2 = ExprBuilder::new_input(builder.clone());
let x3 = x1.clone() + x2.clone();
let x4 = x1.clone() - x2.clone();
let is_add_flag = builder.borrow_mut().new_flag();
let is_sub_flag = builder.borrow_mut().new_flag();
let x5 = FieldVariable::select(is_sub_flag, &x4, &x1);
let mut x6 = FieldVariable::select(is_add_flag, &x3, &x5);
x6.save_output();
let builder = builder.borrow().clone();

(
FieldExpr::new(builder, range_bus, true),
is_add_flag,
is_sub_flag,
)
}

impl<F: Field> BaseAirWithPublicValues<F> for ModularAddSubCoreAir {}

impl<AB: InteractionBuilder, I> VmCoreAir<AB, I> for ModularAddSubCoreAir
where
I: VmAdapterInterface<AB::Expr>,
AdapterAirContext<AB::Expr, I>:
From<AdapterAirContext<AB::Expr, DynAdapterInterface<AB::Expr>>>,
#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)]
pub struct ModularAddSubChip<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>(
pub VmChipWrapper<
F,
Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
FieldExpressionCoreChip,
>,
);

impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize>
ModularAddSubChip<F, BLOCKS, BLOCK_SIZE>
{
fn eval(
&self,
builder: &mut AB,
local: &[AB::Var],
_from_pc: AB::Var,
) -> AdapterAirContext<AB::Expr, I> {
assert_eq!(local.len(), BaseAir::<AB::F>::width(&self.expr));
self.expr.eval(builder, local);

let FieldExprCols {
is_valid,
inputs,
vars,
flags,
..
} = self.expr.load_vars(local);
assert_eq!(inputs.len(), 2);
assert_eq!(vars.len(), 1);
assert_eq!(flags.len(), 2);
let reads: Vec<AB::Expr> = inputs.concat().iter().map(|x| (*x).into()).collect();
let writes: Vec<AB::Expr> = vars[0].iter().map(|x| (*x).into()).collect();

let local_opcode_idx = flags[0]
* AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::ADD as usize)
+ flags[1] * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::SUB as usize)
+ (AB::Expr::ONE - flags[0] - flags[1])
* AB::Expr::from_canonical_usize(
Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize,
);

let instruction = MinimalInstruction {
is_valid: is_valid.into(),
opcode: local_opcode_idx + AB::Expr::from_canonical_usize(self.offset),
};

let ctx: AdapterAirContext<_, DynAdapterInterface<_>> = AdapterAirContext {
to_pc: None,
reads: reads.into(),
writes: writes.into(),
instruction: instruction.into(),
};
ctx.into()
}
}

/// Number of limbs and limb size are determined purely at runtime
pub struct ModularAddSubCoreChip {
pub air: ModularAddSubCoreAir,
pub range_checker: Arc<VariableRangeCheckerChip>,
}

impl ModularAddSubCoreChip {
pub fn new(
adapter: Rv32VecHeapAdapterChip<F, 2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>,
config: ExprBuilderConfig,
range_checker: Arc<VariableRangeCheckerChip>,
offset: usize,
range_checker: Arc<VariableRangeCheckerChip>,
offline_memory: Arc<Mutex<OfflineMemory<F>>>,
) -> Self {
let air = ModularAddSubCoreAir::new(config, range_checker.bus(), offset);
Self { air, range_checker }
}
}

#[derive(Serialize, Deserialize)]
pub struct ModularAddSubCoreRecord {
pub x: BigUint,
pub y: BigUint,
pub is_add_flag: bool,
pub is_sub_flag: bool,
}

impl<F: PrimeField32, I> VmCoreChip<F, I> for ModularAddSubCoreChip
where
I: VmAdapterInterface<F>,
I::Reads: Into<DynArray<F>>,
AdapterRuntimeContext<F, I>: From<AdapterRuntimeContext<F, DynAdapterInterface<F>>>,
{
type Record = ModularAddSubCoreRecord;
type Air = ModularAddSubCoreAir;

fn execute_instruction(
&self,
instruction: &Instruction<F>,
_from_pc: u32,
reads: I::Reads,
) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
let num_limbs = self.air.expr.canonical_num_limbs();
let limb_bits = self.air.expr.canonical_limb_bits();
let Instruction { opcode, .. } = instruction;
let local_opcode_idx = opcode.local_opcode_idx(self.air.offset);
let data: DynArray<_> = reads.into();
let data = data.0;
debug_assert_eq!(data.len(), 2 * num_limbs);
let x = data[..num_limbs]
.iter()
.map(|x| x.as_canonical_u32())
.collect_vec();
let y = data[num_limbs..]
.iter()
.map(|x| x.as_canonical_u32())
.collect_vec();

let x_biguint = limbs_to_biguint(&x, limb_bits);
let y_biguint = limbs_to_biguint(&y, limb_bits);

let local_opcode = Rv32ModularArithmeticOpcode::from_usize(local_opcode_idx);
let is_add_flag = match local_opcode {
Rv32ModularArithmeticOpcode::ADD => true,
Rv32ModularArithmeticOpcode::SUB | Rv32ModularArithmeticOpcode::SETUP_ADDSUB => false,
_ => panic!("Unsupported opcode: {:?}", local_opcode),
};
let is_sub_flag = match local_opcode {
Rv32ModularArithmeticOpcode::SUB => true,
Rv32ModularArithmeticOpcode::ADD | Rv32ModularArithmeticOpcode::SETUP_ADDSUB => false,
_ => panic!("Unsupported opcode: {:?}", local_opcode),
};

let vars = self.air.expr.execute(
vec![x_biguint.clone(), y_biguint.clone()],
let (expr, is_add_flag, is_sub_flag) = addsub_expr(config, range_checker.bus());
let core = FieldExpressionCoreChip::new(
expr,
offset,
vec![
Rv32ModularArithmeticOpcode::ADD as usize,
Rv32ModularArithmeticOpcode::SUB as usize,
Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize,
],
vec![is_add_flag, is_sub_flag],
range_checker,
"ModularAddSub",
false,
);
assert_eq!(vars.len(), 1);
let z_biguint = vars[0].clone();
tracing::trace!(
"ModularArithmeticOpcode | {local_opcode:?} | {z_biguint:?} | {x_biguint:?} | {y_biguint:?}",
);
let z_limbs = biguint_to_limbs_vec(z_biguint, limb_bits, num_limbs);
let writes = z_limbs.into_iter().map(F::from_canonical_u32).collect_vec();
let ctx = AdapterRuntimeContext::<_, DynAdapterInterface<_>>::without_pc(writes);

Ok((
ctx.into(),
ModularAddSubCoreRecord {
x: x_biguint,
y: y_biguint,
is_add_flag,
is_sub_flag,
},
))
}

fn get_opcode_name(&self, _opcode: usize) -> String {
"ModularAddSub".to_string()
}

fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
self.air.expr.generate_subrow(
(
&self.range_checker,
vec![record.x, record.y],
vec![record.is_add_flag, record.is_sub_flag],
),
row_slice,
);
}

fn air(&self) -> &Self::Air {
&self.air
Self(VmChipWrapper::new(adapter, core, offline_memory))
}
}
31 changes: 2 additions & 29 deletions extensions/algebra/circuit/src/modular_chip/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,13 @@ mod is_eq;
pub use is_eq::*;
mod muldiv;
pub use muldiv::*;
use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper};
use openvm_circuit::arch::VmChipWrapper;
use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
use openvm_rv32_adapters::{
Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterAir, Rv32VecHeapAdapterChip,
};
use openvm_rv32_adapters::Rv32IsEqualModAdapterChip;

#[cfg(test)]
mod tests;

/// Each prime field element will be represented as `NUM_LANES * LANE_SIZE` cells in memory.
/// The `LANE_SIZE` must be a power of 2 and determines the size of the batch memory read/writes.
pub type ModularAddSubAir<const NUM_LANES: usize, const LANE_SIZE: usize> = VmAirWrapper<
Rv32VecHeapAdapterAir<2, NUM_LANES, NUM_LANES, LANE_SIZE, LANE_SIZE>,
ModularAddSubCoreAir,
>;
/// See [ModularAddSubAir].
pub type ModularAddSubChip<F, const NUM_LANES: usize, const LANE_SIZE: usize> = VmChipWrapper<
F,
Rv32VecHeapAdapterChip<F, 2, NUM_LANES, NUM_LANES, LANE_SIZE, LANE_SIZE>,
ModularAddSubCoreChip,
>;
/// Each prime field element will be represented as `NUM_LANES * LANE_SIZE` cells in memory.
/// The `LANE_SIZE` must be a power of 2 and determines the size of the batch memory read/writes.
pub type ModularMulDivAir<const NUM_LANES: usize, const LANE_SIZE: usize> = VmAirWrapper<
Rv32VecHeapAdapterAir<2, NUM_LANES, NUM_LANES, LANE_SIZE, LANE_SIZE>,
ModularMulDivCoreAir,
>;
/// See [ModularMulDivAir].
pub type ModularMulDivChip<F, const NUM_LANES: usize, const LANE_SIZE: usize> = VmChipWrapper<
F,
Rv32VecHeapAdapterChip<F, 2, NUM_LANES, NUM_LANES, LANE_SIZE, LANE_SIZE>,
ModularMulDivCoreChip,
>;

// Must have TOTAL_LIMBS = NUM_LANES * LANE_SIZE
pub type ModularIsEqualChip<
F,
Expand Down
Loading
Loading