From 818cbea56690bb207efb26c427cee568d99bd933 Mon Sep 17 00:00:00 2001 From: luffykai Date: Mon, 13 Jan 2025 22:13:08 +0800 Subject: [PATCH 1/2] use expr builder for modular chips --- .../circuit/src/modular_chip/addsub.rs | 88 +++++++++++++++++-- .../algebra/circuit/src/modular_chip/mod.rs | 20 ++--- .../algebra/circuit/src/modular_chip/tests.rs | 25 +++--- .../algebra/circuit/src/modular_extension.rs | 20 ++--- 4 files changed, 115 insertions(+), 38 deletions(-) diff --git a/extensions/algebra/circuit/src/modular_chip/addsub.rs b/extensions/algebra/circuit/src/modular_chip/addsub.rs index 918589391a..fd4282c91a 100644 --- a/extensions/algebra/circuit/src/modular_chip/addsub.rs +++ b/extensions/algebra/circuit/src/modular_chip/addsub.rs @@ -1,21 +1,33 @@ -use std::{cell::RefCell, rc::Rc, sync::Arc}; +use std::{ + cell::RefCell, + rc::Rc, + sync::{Arc, Mutex}, +}; use itertools::Itertools; use num_bigint_dig::BigUint; -use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::arch::{ - instructions::UsizeOpcode, AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, - DynArray, MinimalInstruction, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode}; +use openvm_circuit::{ + arch::{ + instructions::UsizeOpcode, AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, + DynArray, MinimalInstruction, Result, VmAdapterInterface, VmChipWrapper, VmCoreAir, + VmCoreChip, + }, + system::memory::OfflineMemory, }; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; use openvm_circuit_primitives::{ var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, SubAir, TraceSubRowGenerator, }; +use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::instruction::Instruction; use openvm_mod_circuit_builder::{ utils::{biguint_to_limbs_vec, limbs_to_biguint}, - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExprCols, FieldVariable, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExprCols, FieldExpressionCoreChip, + FieldVariable, }; +use openvm_rv32_adapters::Rv32VecHeapAdapterChip; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, @@ -25,6 +37,69 @@ use openvm_stark_backend::{ use serde::Deserialize; use serde_with::serde_derive::Serialize; +pub fn addsub_expr( + config: ExprBuilderConfig, + range_bus: VariableRangeCheckerBus, +) -> (FieldExpr, usize, usize) { + config.check_valid(); + let builder = ExprBuilder::new(config, range_bus.range_max_bits); + let builder = Rc::new(RefCell::new(builder)); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let x3 = x1.clone() + x2.clone(); + let x4 = x1.clone() - x2.clone(); + let is_add_flag = builder.borrow_mut().new_flag(); + let is_sub_flag = builder.borrow_mut().new_flag(); + let x5 = FieldVariable::select(is_sub_flag, &x4, &x1); + let mut x6 = FieldVariable::select(is_add_flag, &x3, &x5); + x6.save_output(); + let builder = builder.borrow().clone(); + + ( + FieldExpr::new(builder, range_bus, true), + is_add_flag, + is_sub_flag, + ) +} + +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +pub struct ModularAddSubChip( + pub VmChipWrapper< + F, + Rv32VecHeapAdapterChip, + FieldExpressionCoreChip, + >, +); + +impl + ModularAddSubChip +{ + pub fn new( + adapter: Rv32VecHeapAdapterChip, + config: ExprBuilderConfig, + offset: usize, + range_checker: Arc, + offline_memory: Arc>>, + ) -> Self { + let (expr, is_add_flag, is_sub_flag) = addsub_expr(config, range_checker.bus()); + let core = FieldExpressionCoreChip::new( + expr, + offset, + vec![ + Rv32ModularArithmeticOpcode::ADD as usize, + Rv32ModularArithmeticOpcode::SUB as usize, + Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize, + ], + vec![is_add_flag, is_sub_flag], + range_checker, + "ModularAddSub", + false, + ); + Self(VmChipWrapper::new(adapter, core, offline_memory)) + } +} +/* /// The number of limbs and limb bits are determined at runtime. #[derive(Clone)] pub struct ModularAddSubCoreAir { @@ -230,3 +305,4 @@ where &self.air } } +*/ diff --git a/extensions/algebra/circuit/src/modular_chip/mod.rs b/extensions/algebra/circuit/src/modular_chip/mod.rs index 87279fa3a8..c15b1289ac 100644 --- a/extensions/algebra/circuit/src/modular_chip/mod.rs +++ b/extensions/algebra/circuit/src/modular_chip/mod.rs @@ -15,16 +15,16 @@ mod tests; /// Each prime field element will be represented as `NUM_LANES * LANE_SIZE` cells in memory. /// The `LANE_SIZE` must be a power of 2 and determines the size of the batch memory read/writes. -pub type ModularAddSubAir = VmAirWrapper< - Rv32VecHeapAdapterAir<2, NUM_LANES, NUM_LANES, LANE_SIZE, LANE_SIZE>, - ModularAddSubCoreAir, ->; -/// See [ModularAddSubAir]. -pub type ModularAddSubChip = VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - ModularAddSubCoreChip, ->; +// pub type ModularAddSubAir = VmAirWrapper< +// Rv32VecHeapAdapterAir<2, NUM_LANES, NUM_LANES, LANE_SIZE, LANE_SIZE>, +// ModularAddSubCoreAir, +// >; +// /// See [ModularAddSubAir]. +// pub type ModularAddSubChip = VmChipWrapper< +// F, +// Rv32VecHeapAdapterChip, +// ModularAddSubCoreChip, // should be FieldExpressionCoreChip +// >; /// Each prime field element will be represented as `NUM_LANES * LANE_SIZE` cells in memory. /// The `LANE_SIZE` must be a power of 2 and determines the size of the batch memory read/writes. pub type ModularMulDivAir = VmAirWrapper< diff --git a/extensions/algebra/circuit/src/modular_chip/tests.rs b/extensions/algebra/circuit/src/modular_chip/tests.rs index 729cfe5461..8739b6f4ef 100644 --- a/extensions/algebra/circuit/src/modular_chip/tests.rs +++ b/extensions/algebra/circuit/src/modular_chip/tests.rs @@ -26,9 +26,7 @@ use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::Rng; -use super::{ - ModularAddSubCoreChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivCoreChip, -}; +use super::{ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivCoreChip}; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; @@ -59,11 +57,11 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { num_limbs: NUM_LIMBS, limb_bits: LIMB_BITS, }; - let core = ModularAddSubCoreChip::new( - config, - tester.memory_controller().borrow().range_checker.clone(), - Rv32ModularArithmeticOpcode::default_offset() + opcode_offset, - ); + // let core = ModularAddSubCoreChip::new( + // config, + // tester.memory_controller().borrow().range_checker.clone(), + // Rv32ModularArithmeticOpcode::default_offset() + opcode_offset, + // ); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); @@ -75,7 +73,14 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { tester.address_bits(), bitwise_chip.clone(), ); - let mut chip = VmChipWrapper::new(adapter, core, tester.offline_memory_mutex_arc()); + // let mut chip = VmChipWrapper::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut chip = ModularAddSubChip::new( + adapter, + config, + Rv32ModularArithmeticOpcode::default_offset() + opcode_offset, + tester.range_checker(), + tester.offline_memory_mutex_arc(), + ); let mut rng = create_seeded_rng(); let num_tests = 50; let mut all_ops = vec![ADD_LOCAL + 2]; // setup @@ -145,7 +150,7 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { tester.write(data_as, address2 as usize, b_limbs); let instruction = Instruction::from_isize( - VmOpcode::from_usize(chip.core.air.offset + op), + VmOpcode::from_usize(chip.0.core.air.offset + op), addr_ptr3 as isize, addr_ptr1 as isize, addr_ptr2 as isize, diff --git a/extensions/algebra/circuit/src/modular_extension.rs b/extensions/algebra/circuit/src/modular_extension.rs index 8fdc9770b1..dfc2cf9d01 100644 --- a/extensions/algebra/circuit/src/modular_extension.rs +++ b/extensions/algebra/circuit/src/modular_extension.rs @@ -20,8 +20,8 @@ use serde_with::{serde_as, DisplayFromStr}; use strum::EnumCount; use crate::modular_chip::{ - ModularAddSubChip, ModularAddSubCoreChip, ModularIsEqualChip, ModularIsEqualCoreChip, - ModularMulDivChip, ModularMulDivCoreChip, + ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivChip, + ModularMulDivCoreChip, }; #[serde_as] @@ -120,11 +120,9 @@ impl VmExtension for ModularExtension { if bytes <= 32 { let addsub_chip = ModularAddSubChip::new( adapter_chip_32.clone(), - ModularAddSubCoreChip::new( - config32.clone(), - range_checker.clone(), - class_offset, - ), + config32.clone(), + class_offset, + range_checker.clone(), offline_memory.clone(), ); inventory.add_executor( @@ -172,11 +170,9 @@ impl VmExtension for ModularExtension { } else if bytes <= 48 { let addsub_chip = ModularAddSubChip::new( adapter_chip_48.clone(), - ModularAddSubCoreChip::new( - config48.clone(), - range_checker.clone(), - class_offset, - ), + config48.clone(), + class_offset, + range_checker.clone(), offline_memory.clone(), ); inventory.add_executor( From 34159bbf21c58117ab7763c3a4b49c13ce8493e6 Mon Sep 17 00:00:00 2001 From: luffykai Date: Mon, 13 Jan 2025 23:04:06 +0800 Subject: [PATCH 2/2] muldiv --- .../circuit/src/modular_chip/addsub.rs | 239 +------------- .../algebra/circuit/src/modular_chip/mod.rs | 31 +- .../circuit/src/modular_chip/muldiv.rs | 310 +++++------------- .../algebra/circuit/src/modular_chip/tests.rs | 25 +- .../algebra/circuit/src/modular_extension.rs | 17 +- 5 files changed, 101 insertions(+), 521 deletions(-) diff --git a/extensions/algebra/circuit/src/modular_chip/addsub.rs b/extensions/algebra/circuit/src/modular_chip/addsub.rs index fd4282c91a..846c50d075 100644 --- a/extensions/algebra/circuit/src/modular_chip/addsub.rs +++ b/extensions/algebra/circuit/src/modular_chip/addsub.rs @@ -4,38 +4,16 @@ use std::{ sync::{Arc, Mutex}, }; -use itertools::Itertools; -use num_bigint_dig::BigUint; -use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode}; -use openvm_circuit::{ - arch::{ - instructions::UsizeOpcode, AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, - DynArray, MinimalInstruction, Result, VmAdapterInterface, VmChipWrapper, VmCoreAir, - VmCoreChip, - }, - system::memory::OfflineMemory, -}; +use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; +use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; use openvm_circuit_derive::{InstructionExecutor, Stateful}; -use openvm_circuit_primitives::{ - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, - SubAir, TraceSubRowGenerator, -}; +use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_instructions::instruction::Instruction; use openvm_mod_circuit_builder::{ - utils::{biguint_to_limbs_vec, limbs_to_biguint}, - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExprCols, FieldExpressionCoreChip, - FieldVariable, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, }; use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::{ - interaction::InteractionBuilder, - p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, - rap::BaseAirWithPublicValues, -}; -use serde::Deserialize; -use serde_with::serde_derive::Serialize; +use openvm_stark_backend::p3_field::PrimeField32; pub fn addsub_expr( config: ExprBuilderConfig, @@ -99,210 +77,3 @@ impl Self(VmChipWrapper::new(adapter, core, offline_memory)) } } -/* -/// The number of limbs and limb bits are determined at runtime. -#[derive(Clone)] -pub struct ModularAddSubCoreAir { - pub expr: FieldExpr, - pub offset: usize, -} - -impl ModularAddSubCoreAir { - pub fn new( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - offset: usize, - ) -> Self { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - let x1 = ExprBuilder::new_input(builder.clone()); - let x2 = ExprBuilder::new_input(builder.clone()); - let x3 = x1.clone() + x2.clone(); - let x4 = x1.clone() - x2.clone(); - let is_add_flag = builder.borrow_mut().new_flag(); - let is_sub_flag = builder.borrow_mut().new_flag(); - let x5 = FieldVariable::select(is_sub_flag, &x4, &x1); - let mut x6 = FieldVariable::select(is_add_flag, &x3, &x5); - x6.save(); - let builder = builder.borrow().clone(); - - let expr = FieldExpr::new(builder, range_bus, true); - Self { expr, offset } - } -} - -impl BaseAir for ModularAddSubCoreAir { - fn width(&self) -> usize { - BaseAir::::width(&self.expr) - } -} - -impl BaseAirWithPublicValues for ModularAddSubCoreAir {} - -impl VmCoreAir for ModularAddSubCoreAir -where - I: VmAdapterInterface, - AdapterAirContext: - From>>, -{ - fn eval( - &self, - builder: &mut AB, - local: &[AB::Var], - _from_pc: AB::Var, - ) -> AdapterAirContext { - assert_eq!(local.len(), BaseAir::::width(&self.expr)); - self.expr.eval(builder, local); - - let FieldExprCols { - is_valid, - inputs, - vars, - flags, - .. - } = self.expr.load_vars(local); - assert_eq!(inputs.len(), 2); - assert_eq!(vars.len(), 1); - assert_eq!(flags.len(), 2); - let reads: Vec = inputs.concat().iter().map(|x| (*x).into()).collect(); - let writes: Vec = vars[0].iter().map(|x| (*x).into()).collect(); - - let local_opcode_idx = flags[0] - * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::ADD as usize) - + flags[1] * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::SUB as usize) - + (AB::Expr::ONE - flags[0] - flags[1]) - * AB::Expr::from_canonical_usize( - Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize, - ); - - let instruction = MinimalInstruction { - is_valid: is_valid.into(), - opcode: local_opcode_idx + AB::Expr::from_canonical_usize(self.offset), - }; - - let ctx: AdapterAirContext<_, DynAdapterInterface<_>> = AdapterAirContext { - to_pc: None, - reads: reads.into(), - writes: writes.into(), - instruction: instruction.into(), - }; - ctx.into() - } -} - -/// Number of limbs and limb size are determined purely at runtime -pub struct ModularAddSubCoreChip { - pub air: ModularAddSubCoreAir, - pub range_checker: Arc, -} - -impl ModularAddSubCoreChip { - pub fn new( - config: ExprBuilderConfig, - range_checker: Arc, - offset: usize, - ) -> Self { - let air = ModularAddSubCoreAir::new(config, range_checker.bus(), offset); - Self { air, range_checker } - } -} - -#[derive(Serialize, Deserialize)] -pub struct ModularAddSubCoreRecord { - pub x: BigUint, - pub y: BigUint, - pub is_add_flag: bool, - pub is_sub_flag: bool, -} - -impl VmCoreChip for ModularAddSubCoreChip -where - I: VmAdapterInterface, - I::Reads: Into>, - AdapterRuntimeContext: From>>, -{ - type Record = ModularAddSubCoreRecord; - type Air = ModularAddSubCoreAir; - - fn execute_instruction( - &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let num_limbs = self.air.expr.canonical_num_limbs(); - let limb_bits = self.air.expr.canonical_limb_bits(); - let Instruction { opcode, .. } = instruction; - let local_opcode_idx = opcode.local_opcode_idx(self.air.offset); - let data: DynArray<_> = reads.into(); - let data = data.0; - debug_assert_eq!(data.len(), 2 * num_limbs); - let x = data[..num_limbs] - .iter() - .map(|x| x.as_canonical_u32()) - .collect_vec(); - let y = data[num_limbs..] - .iter() - .map(|x| x.as_canonical_u32()) - .collect_vec(); - - let x_biguint = limbs_to_biguint(&x, limb_bits); - let y_biguint = limbs_to_biguint(&y, limb_bits); - - let local_opcode = Rv32ModularArithmeticOpcode::from_usize(local_opcode_idx); - let is_add_flag = match local_opcode { - Rv32ModularArithmeticOpcode::ADD => true, - Rv32ModularArithmeticOpcode::SUB | Rv32ModularArithmeticOpcode::SETUP_ADDSUB => false, - _ => panic!("Unsupported opcode: {:?}", local_opcode), - }; - let is_sub_flag = match local_opcode { - Rv32ModularArithmeticOpcode::SUB => true, - Rv32ModularArithmeticOpcode::ADD | Rv32ModularArithmeticOpcode::SETUP_ADDSUB => false, - _ => panic!("Unsupported opcode: {:?}", local_opcode), - }; - - let vars = self.air.expr.execute( - vec![x_biguint.clone(), y_biguint.clone()], - vec![is_add_flag, is_sub_flag], - ); - assert_eq!(vars.len(), 1); - let z_biguint = vars[0].clone(); - tracing::trace!( - "ModularArithmeticOpcode | {local_opcode:?} | {z_biguint:?} | {x_biguint:?} | {y_biguint:?}", - ); - let z_limbs = biguint_to_limbs_vec(z_biguint, limb_bits, num_limbs); - let writes = z_limbs.into_iter().map(F::from_canonical_u32).collect_vec(); - let ctx = AdapterRuntimeContext::<_, DynAdapterInterface<_>>::without_pc(writes); - - Ok(( - ctx.into(), - ModularAddSubCoreRecord { - x: x_biguint, - y: y_biguint, - is_add_flag, - is_sub_flag, - }, - )) - } - - fn get_opcode_name(&self, _opcode: usize) -> String { - "ModularAddSub".to_string() - } - - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - self.air.expr.generate_subrow( - ( - &self.range_checker, - vec![record.x, record.y], - vec![record.is_add_flag, record.is_sub_flag], - ), - row_slice, - ); - } - - fn air(&self) -> &Self::Air { - &self.air - } -} -*/ diff --git a/extensions/algebra/circuit/src/modular_chip/mod.rs b/extensions/algebra/circuit/src/modular_chip/mod.rs index c15b1289ac..2dd9838206 100644 --- a/extensions/algebra/circuit/src/modular_chip/mod.rs +++ b/extensions/algebra/circuit/src/modular_chip/mod.rs @@ -4,40 +4,13 @@ mod is_eq; pub use is_eq::*; mod muldiv; pub use muldiv::*; -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::VmChipWrapper; use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use openvm_rv32_adapters::{ - Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterAir, Rv32VecHeapAdapterChip, -}; +use openvm_rv32_adapters::Rv32IsEqualModAdapterChip; #[cfg(test)] mod tests; -/// Each prime field element will be represented as `NUM_LANES * LANE_SIZE` cells in memory. -/// The `LANE_SIZE` must be a power of 2 and determines the size of the batch memory read/writes. -// pub type ModularAddSubAir = VmAirWrapper< -// Rv32VecHeapAdapterAir<2, NUM_LANES, NUM_LANES, LANE_SIZE, LANE_SIZE>, -// ModularAddSubCoreAir, -// >; -// /// See [ModularAddSubAir]. -// pub type ModularAddSubChip = VmChipWrapper< -// F, -// Rv32VecHeapAdapterChip, -// ModularAddSubCoreChip, // should be FieldExpressionCoreChip -// >; -/// Each prime field element will be represented as `NUM_LANES * LANE_SIZE` cells in memory. -/// The `LANE_SIZE` must be a power of 2 and determines the size of the batch memory read/writes. -pub type ModularMulDivAir = VmAirWrapper< - Rv32VecHeapAdapterAir<2, NUM_LANES, NUM_LANES, LANE_SIZE, LANE_SIZE>, - ModularMulDivCoreAir, ->; -/// See [ModularMulDivAir]. -pub type ModularMulDivChip = VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - ModularMulDivCoreChip, ->; - // Must have TOTAL_LIMBS = NUM_LANES * LANE_SIZE pub type ModularIsEqualChip< F, diff --git a/extensions/algebra/circuit/src/modular_chip/muldiv.rs b/extensions/algebra/circuit/src/modular_chip/muldiv.rs index 7493b2aba0..c24c7284be 100644 --- a/extensions/algebra/circuit/src/modular_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/modular_chip/muldiv.rs @@ -1,247 +1,93 @@ -use std::{cell::RefCell, rc::Rc, sync::Arc}; +use std::{ + cell::RefCell, + rc::Rc, + sync::{Arc, Mutex}, +}; -use itertools::Itertools; -use num_bigint_dig::BigUint; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, MinimalInstruction, - Result, VmAdapterInterface, VmCoreAir, VmCoreChip, -}; -use openvm_circuit_primitives::{ - var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}, - SubAir, TraceSubRowGenerator, -}; -use openvm_instructions::{instruction::Instruction, UsizeOpcode}; +use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; +use openvm_circuit_derive::{InstructionExecutor, Stateful}; +use openvm_circuit_primitives::var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip}; +use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_mod_circuit_builder::{ - utils::{biguint_to_limbs_vec, limbs_to_biguint}, - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExprCols, FieldVariable, SymbolicExpr, -}; -use openvm_stark_backend::{ - interaction::InteractionBuilder, - p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, - rap::BaseAirWithPublicValues, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr, }; -use serde::Deserialize; -use serde_with::serde_derive::Serialize; - -/// The number of limbs and limb bits are determined at runtime. -#[derive(Clone)] -pub struct ModularMulDivCoreAir { - pub expr: FieldExpr, - pub offset: usize, +use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_stark_backend::p3_field::PrimeField32; + +pub fn muldiv_expr( + config: ExprBuilderConfig, + range_bus: VariableRangeCheckerBus, +) -> (FieldExpr, usize, usize) { + config.check_valid(); + let builder = ExprBuilder::new(config, range_bus.range_max_bits); + let builder = Rc::new(RefCell::new(builder)); + let x = ExprBuilder::new_input(builder.clone()); + let y = ExprBuilder::new_input(builder.clone()); + let (z_idx, z) = builder.borrow_mut().new_var(); + let mut z = FieldVariable::from_var(builder.clone(), z); + let is_mul_flag = builder.borrow_mut().new_flag(); + let is_div_flag = builder.borrow_mut().new_flag(); + // constraint is x * y = z, or z * y = x + let lvar = FieldVariable::select(is_mul_flag, &x, &z); + let rvar = FieldVariable::select(is_mul_flag, &z, &x); + // When it's SETUP op, x = p == 0, y = 0, both flags are false, and it still works: z * 0 - x = 0, whatever z is. + let constraint = lvar * y.clone() - rvar; + builder.borrow_mut().set_constraint(z_idx, constraint.expr); + let compute = SymbolicExpr::Select( + is_mul_flag, + Box::new(x.expr.clone() * y.expr.clone()), + Box::new(SymbolicExpr::Select( + is_div_flag, + Box::new(x.expr.clone() / y.expr.clone()), + Box::new(x.expr.clone()), + )), + ); + builder.borrow_mut().set_compute(z_idx, compute); + z.save_output(); + + let builder = builder.borrow().clone(); + + ( + FieldExpr::new(builder, range_bus, true), + is_mul_flag, + is_div_flag, + ) } -impl ModularMulDivCoreAir { - pub fn new( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - offset: usize, - ) -> Self { - config.check_valid(); - - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - let x = ExprBuilder::new_input(builder.clone()); - let y = ExprBuilder::new_input(builder.clone()); - let (z_idx, z) = builder.borrow_mut().new_var(); - let z = FieldVariable::from_var(builder.clone(), z); - let is_mul_flag = builder.borrow_mut().new_flag(); - let is_div_flag = builder.borrow_mut().new_flag(); - // constraint is x * y = z, or z * y = x - let lvar = FieldVariable::select(is_mul_flag, &x, &z); - let rvar = FieldVariable::select(is_mul_flag, &z, &x); - // When it's SETUP op, x = p == 0, y = 0, both flags are false, and it still works: z * 0 - x = 0, whatever z is. - let constraint = lvar * y.clone() - rvar; - builder.borrow_mut().set_constraint(z_idx, constraint.expr); - let compute = SymbolicExpr::Select( - is_mul_flag, - Box::new(x.expr.clone() * y.expr.clone()), - Box::new(SymbolicExpr::Select( - is_div_flag, - Box::new(x.expr.clone() / y.expr.clone()), - Box::new(x.expr.clone()), - )), - ); - builder.borrow_mut().set_compute(z_idx, compute); - - let builder = builder.borrow().clone(); - - let expr = FieldExpr::new(builder, range_bus, true); - Self { expr, offset } - } -} - -impl BaseAir for ModularMulDivCoreAir { - fn width(&self) -> usize { - BaseAir::::width(&self.expr) - } -} - -impl BaseAirWithPublicValues for ModularMulDivCoreAir {} - -impl VmCoreAir for ModularMulDivCoreAir -where - I: VmAdapterInterface, - AdapterAirContext: - From>>, +#[derive(Chip, ChipUsageGetter, InstructionExecutor, Stateful)] +pub struct ModularMulDivChip( + pub VmChipWrapper< + F, + Rv32VecHeapAdapterChip, + FieldExpressionCoreChip, + >, +); + +impl + ModularMulDivChip { - fn eval( - &self, - builder: &mut AB, - local: &[AB::Var], - _from_pc: AB::Var, - ) -> AdapterAirContext { - assert_eq!(local.len(), BaseAir::::width(&self.expr)); - self.expr.eval(builder, local); - - let FieldExprCols { - is_valid, - inputs, - vars, - flags, - .. - } = self.expr.load_vars(local); - assert_eq!(inputs.len(), 2); - assert_eq!(vars.len(), 1); - assert_eq!(flags.len(), 2); - let reads: Vec = inputs.concat().iter().map(|x| (*x).into()).collect(); - let writes: Vec = vars[0].iter().map(|x| (*x).into()).collect(); - - // Attention: we multiply in the setup case, hence flags[0] (is_mul_flag) does NOT imply that is_setup is false! - let local_opcode_idx = flags[0] - * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::MUL as usize) - + flags[1] * AB::Expr::from_canonical_usize(Rv32ModularArithmeticOpcode::DIV as usize) - + (AB::Expr::ONE - flags[0] - flags[1]) - * AB::Expr::from_canonical_usize( - Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize, - ); - - let instruction = MinimalInstruction { - is_valid: is_valid.into(), - opcode: local_opcode_idx + AB::Expr::from_canonical_usize(self.offset), - }; - - let ctx: AdapterAirContext<_, DynAdapterInterface<_>> = AdapterAirContext { - to_pc: None, - reads: reads.into(), - writes: writes.into(), - instruction: instruction.into(), - }; - ctx.into() - } -} - -pub struct ModularMulDivCoreChip { - pub air: ModularMulDivCoreAir, - pub range_checker: Arc, -} - -impl ModularMulDivCoreChip { pub fn new( + adapter: Rv32VecHeapAdapterChip, config: ExprBuilderConfig, - range_checker: Arc, offset: usize, + range_checker: Arc, + offline_memory: Arc>>, ) -> Self { - let air = ModularMulDivCoreAir::new(config, range_checker.bus(), offset); - Self { air, range_checker } - } -} - -#[derive(Serialize, Deserialize)] -pub struct ModularMulDivCoreRecord { - pub x: BigUint, - pub y: BigUint, - pub is_mul_flag: bool, - pub is_div_flag: bool, -} - -impl VmCoreChip for ModularMulDivCoreChip -where - I: VmAdapterInterface, - I::Reads: Into>, - AdapterRuntimeContext: From>>, -{ - type Record = ModularMulDivCoreRecord; - type Air = ModularMulDivCoreAir; - - fn execute_instruction( - &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let num_limbs = self.air.expr.canonical_num_limbs(); - let limb_bits = self.air.expr.canonical_limb_bits(); - let Instruction { opcode, .. } = instruction; - let local_opcode_idx = opcode.local_opcode_idx(self.air.offset); - let data: DynArray<_> = reads.into(); - let data = data.0; - assert_eq!(data.len(), 2 * num_limbs); - let x = data[..num_limbs] - .iter() - .map(|x| x.as_canonical_u32()) - .collect_vec(); - let y = data[num_limbs..] - .iter() - .map(|x| x.as_canonical_u32()) - .collect_vec(); - - let x_biguint = limbs_to_biguint(&x, limb_bits); - let y_biguint = limbs_to_biguint(&y, limb_bits); - - let local_opcode = Rv32ModularArithmeticOpcode::from_usize(local_opcode_idx); - let is_mul_flag = match local_opcode { - Rv32ModularArithmeticOpcode::MUL => true, - Rv32ModularArithmeticOpcode::DIV | Rv32ModularArithmeticOpcode::SETUP_MULDIV => false, - _ => panic!("Unsupported opcode: {:?}", local_opcode), - }; - let is_div_flag = match local_opcode { - Rv32ModularArithmeticOpcode::DIV => true, - Rv32ModularArithmeticOpcode::MUL | Rv32ModularArithmeticOpcode::SETUP_MULDIV => false, - _ => panic!("Unsupported opcode: {:?}", local_opcode), - }; - - let vars = self.air.expr.execute( - vec![x_biguint.clone(), y_biguint.clone()], + let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus()); + let core = FieldExpressionCoreChip::new( + expr, + offset, + vec![ + Rv32ModularArithmeticOpcode::MUL as usize, + Rv32ModularArithmeticOpcode::DIV as usize, + Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize, + ], vec![is_mul_flag, is_div_flag], + range_checker, + "ModularMulDiv", + false, ); - assert_eq!(vars.len(), 1); - let z_biguint = vars[0].clone(); - tracing::trace!( - "ModularArithmeticOpcode | {local_opcode:?} | {z_biguint:?} | {x_biguint:?} | {y_biguint:?}", - ); - let z_limbs = biguint_to_limbs_vec(z_biguint, limb_bits, num_limbs); - let writes = z_limbs.into_iter().map(F::from_canonical_u32).collect_vec(); - let ctx = AdapterRuntimeContext::<_, DynAdapterInterface<_>>::without_pc(writes); - - Ok(( - ctx.into(), - ModularMulDivCoreRecord { - x: x_biguint, - y: y_biguint, - is_mul_flag, - is_div_flag, - }, - )) - } - - fn get_opcode_name(&self, _opcode: usize) -> String { - "ModularMulDiv".to_string() - } - - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - self.air.expr.generate_subrow( - ( - &self.range_checker, - vec![record.x, record.y], - vec![record.is_mul_flag, record.is_div_flag], - ), - row_slice, - ); - } - - fn air(&self) -> &Self::Air { - &self.air + Self(VmChipWrapper::new(adapter, core, offline_memory)) } } diff --git a/extensions/algebra/circuit/src/modular_chip/tests.rs b/extensions/algebra/circuit/src/modular_chip/tests.rs index 8739b6f4ef..6faadc3853 100644 --- a/extensions/algebra/circuit/src/modular_chip/tests.rs +++ b/extensions/algebra/circuit/src/modular_chip/tests.rs @@ -4,7 +4,7 @@ use num_bigint_dig::BigUint; use num_traits::Zero; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; use openvm_circuit::arch::{ - instructions::UsizeOpcode, testing::VmChipTestBuilder, VmChipWrapper, BITWISE_OP_LOOKUP_BUS, + instructions::UsizeOpcode, testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, }; use openvm_circuit_primitives::{ bigint::utils::{ @@ -26,7 +26,7 @@ use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::Rng; -use super::{ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivCoreChip}; +use super::{ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivChip}; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; @@ -57,11 +57,6 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { num_limbs: NUM_LIMBS, limb_bits: LIMB_BITS, }; - // let core = ModularAddSubCoreChip::new( - // config, - // tester.memory_controller().borrow().range_checker.clone(), - // Rv32ModularArithmeticOpcode::default_offset() + opcode_offset, - // ); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); @@ -73,7 +68,6 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { tester.address_bits(), bitwise_chip.clone(), ); - // let mut chip = VmChipWrapper::new(adapter, core, tester.offline_memory_mutex_arc()); let mut chip = ModularAddSubChip::new( adapter, config, @@ -192,11 +186,6 @@ fn test_muldiv(opcode_offset: usize, modulus: BigUint) { num_limbs: NUM_LIMBS, limb_bits: LIMB_BITS, }; - let core = ModularMulDivCoreChip::new( - config, - tester.memory_controller().borrow().range_checker.clone(), - Rv32ModularArithmeticOpcode::default_offset() + opcode_offset, - ); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); // doing 1xNUM_LIMBS reads and writes @@ -207,7 +196,13 @@ fn test_muldiv(opcode_offset: usize, modulus: BigUint) { tester.address_bits(), bitwise_chip.clone(), ); - let mut chip = VmChipWrapper::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut chip = ModularMulDivChip::new( + adapter, + config, + Rv32ModularArithmeticOpcode::default_offset() + opcode_offset, + tester.range_checker(), + tester.offline_memory_mutex_arc(), + ); let mut rng = create_seeded_rng(); let num_tests = 50; let mut all_ops = vec![MUL_LOCAL + 2]; @@ -278,7 +273,7 @@ fn test_muldiv(opcode_offset: usize, modulus: BigUint) { tester.write(data_as, address2 as usize, b_limbs); let instruction = Instruction::from_isize( - VmOpcode::from_usize(chip.core.air.offset + op), + VmOpcode::from_usize(chip.0.core.air.offset + op), addr_ptr3 as isize, addr_ptr1 as isize, addr_ptr2 as isize, diff --git a/extensions/algebra/circuit/src/modular_extension.rs b/extensions/algebra/circuit/src/modular_extension.rs index dfc2cf9d01..fc0798e020 100644 --- a/extensions/algebra/circuit/src/modular_extension.rs +++ b/extensions/algebra/circuit/src/modular_extension.rs @@ -21,7 +21,6 @@ use strum::EnumCount; use crate::modular_chip::{ ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivChip, - ModularMulDivCoreChip, }; #[serde_as] @@ -133,11 +132,9 @@ impl VmExtension for ModularExtension { )?; let muldiv_chip = ModularMulDivChip::new( adapter_chip_32.clone(), - ModularMulDivCoreChip::new( - config32.clone(), - range_checker.clone(), - class_offset, - ), + config32.clone(), + class_offset, + range_checker.clone(), offline_memory.clone(), ); inventory.add_executor( @@ -183,11 +180,9 @@ impl VmExtension for ModularExtension { )?; let muldiv_chip = ModularMulDivChip::new( adapter_chip_48.clone(), - ModularMulDivCoreChip::new( - config48.clone(), - range_checker.clone(), - class_offset, - ), + config48.clone(), + class_offset, + range_checker.clone(), offline_memory.clone(), ); inventory.add_executor(