From 408d9a5a7d8792df357a149d00cef8dfc605c239 Mon Sep 17 00:00:00 2001 From: stephenh-axiom-xyz Date: Thu, 31 Oct 2024 17:24:30 -0400 Subject: [PATCH] feat: RISC-V 256-bit Integer Chip Implementations (#716) * feat: RISC-V base ALU int256 chip * feat: remove eDSL support for U256 arithmetic, ECC, and modular arithmetic --- .github/workflows/benchmark-call.yml | 7 - .github/workflows/recursion-bench.yml | 14 - lib/recursion/Cargo.toml | 3 - lib/recursion/src/bin/alu256_e2e.rs | 173 ----- lib/recursion/src/hints.rs | 26 +- toolchain/instructions/src/lib.rs | 34 - toolchain/native-compiler/src/asm/compiler.rs | 136 ---- .../native-compiler/src/asm/instruction.rs | 128 ---- .../native-compiler/src/conversion/mod.rs | 168 ----- toolchain/native-compiler/src/ir/bits.rs | 17 +- .../native-compiler/src/ir/elliptic_curve.rs | 113 --- .../native-compiler/src/ir/instructions.rs | 57 -- toolchain/native-compiler/src/ir/mod.rs | 4 - .../src/ir/modular_arithmetic.rs | 252 ------- toolchain/native-compiler/src/ir/uint.rs | 103 --- .../tests/modular_arithmetic.rs | 192 ----- toolchain/native-compiler/tests/uint.rs | 361 --------- vm/src/arch/chip_set.rs | 149 ++-- vm/src/arch/chips.rs | 26 +- vm/src/arch/config.rs | 12 +- vm/src/arch/integration_api.rs | 216 +++++- vm/src/intrinsics/int256/mod.rs | 34 + vm/src/intrinsics/int256/tests.rs | 202 ++++++ vm/src/intrinsics/mod.rs | 1 + vm/src/lib.rs | 3 - vm/src/old/alu/air.rs | 158 ---- vm/src/old/alu/bridge.rs | 149 ---- vm/src/old/alu/columns.rs | 52 -- vm/src/old/alu/mod.rs | 301 -------- vm/src/old/alu/tests.rs | 662 ----------------- vm/src/old/alu/trace.rs | 133 ---- vm/src/old/mod.rs | 3 - vm/src/old/shift/air.rs | 179 ----- vm/src/old/shift/bridge.rs | 113 --- vm/src/old/shift/columns.rs | 58 -- vm/src/old/shift/mod.rs | 270 ------- vm/src/old/shift/tests.rs | 682 ------------------ vm/src/old/shift/trace.rs | 132 ---- vm/src/old/uint_multiplication/air.rs | 74 -- vm/src/old/uint_multiplication/bridge.rs | 92 --- vm/src/old/uint_multiplication/columns.rs | 43 -- vm/src/old/uint_multiplication/mod.rs | 182 ----- vm/src/old/uint_multiplication/tests.rs | 246 ------- vm/src/old/uint_multiplication/trace.rs | 100 --- vm/src/rv32im/adapters/heap.rs | 212 ++++++ vm/src/rv32im/adapters/mod.rs | 5 + vm/src/rv32im/adapters/vec_heap.rs | 233 +++--- vm/src/rv32im/base_alu/core.rs | 2 +- vm/src/rv32im/base_alu/tests.rs | 55 +- vm/src/rv32im/less_than/tests.rs | 87 +-- vm/src/rv32im/mul/tests.rs | 55 +- vm/src/rv32im/shift/tests.rs | 58 +- vm/src/system/program/util.rs | 5 +- vm/src/utils/test_utils.rs | 33 +- 54 files changed, 1023 insertions(+), 5782 deletions(-) delete mode 100644 lib/recursion/src/bin/alu256_e2e.rs delete mode 100644 toolchain/native-compiler/src/ir/elliptic_curve.rs delete mode 100644 toolchain/native-compiler/src/ir/modular_arithmetic.rs delete mode 100644 toolchain/native-compiler/src/ir/uint.rs delete mode 100644 toolchain/native-compiler/tests/modular_arithmetic.rs delete mode 100644 toolchain/native-compiler/tests/uint.rs create mode 100644 vm/src/intrinsics/int256/mod.rs create mode 100644 vm/src/intrinsics/int256/tests.rs delete mode 100644 vm/src/old/alu/air.rs delete mode 100644 vm/src/old/alu/bridge.rs delete mode 100644 vm/src/old/alu/columns.rs delete mode 100644 vm/src/old/alu/mod.rs delete mode 100644 vm/src/old/alu/tests.rs delete mode 100644 vm/src/old/alu/trace.rs delete mode 100644 vm/src/old/mod.rs delete mode 100644 vm/src/old/shift/air.rs delete mode 100644 vm/src/old/shift/bridge.rs delete mode 100644 vm/src/old/shift/columns.rs delete mode 100644 vm/src/old/shift/mod.rs delete mode 100644 vm/src/old/shift/tests.rs delete mode 100644 vm/src/old/shift/trace.rs delete mode 100644 vm/src/old/uint_multiplication/air.rs delete mode 100644 vm/src/old/uint_multiplication/bridge.rs delete mode 100644 vm/src/old/uint_multiplication/columns.rs delete mode 100644 vm/src/old/uint_multiplication/mod.rs delete mode 100644 vm/src/old/uint_multiplication/tests.rs delete mode 100644 vm/src/old/uint_multiplication/trace.rs create mode 100644 vm/src/rv32im/adapters/heap.rs diff --git a/.github/workflows/benchmark-call.yml b/.github/workflows/benchmark-call.yml index 8cc6948eb6..43f3ab88b5 100644 --- a/.github/workflows/benchmark-call.yml +++ b/.github/workflows/benchmark-call.yml @@ -10,7 +10,6 @@ on: options: - verify_fibair - fibonacci - - alu256_e2e - small_e2e instance_type: type: string @@ -93,12 +92,6 @@ jobs: run: | python3 ../ci/scripts/bench.py $BIN_NAME $CMD_ARGS - - name: Run benchmark - if: inputs.benchmark_name == 'alu256_e2e' - working-directory: lib/recursion - run: | - python3 ../../ci/scripts/bench.py $BIN_NAME $CMD_ARGS - - name: Run benchmark if: inputs.benchmark_name == 'small_e2e' working-directory: lib/recursion diff --git a/.github/workflows/recursion-bench.yml b/.github/workflows/recursion-bench.yml index 7524963599..5203413204 100644 --- a/.github/workflows/recursion-bench.yml +++ b/.github/workflows/recursion-bench.yml @@ -24,20 +24,6 @@ env: AXIOM_FAST_TEST: "1" jobs: - benchmark_alu256_e2e: - uses: ./.github/workflows/benchmark-call.yml - # run on pull request with label 'run-benchmark' or 'run-benchmark-e2e' - # and always run on push to main - if: | - (github.event_name == 'pull_request' && - (contains(github.event.pull_request.labels.*.name, 'run-benchmark') || - contains(github.event.pull_request.labels.*.name, 'run-benchmark-e2e'))) || - (github.event_name == 'push' && github.ref == 'refs/heads/main') - with: - instance_type: 64cpu-linux-arm64 - benchmark_name: alu256_e2e - secrets: inherit - benchmark_small_e2e: uses: ./.github/workflows/benchmark-call.yml # run on non-draft pull request with label 'run-benchmark' diff --git a/lib/recursion/Cargo.toml b/lib/recursion/Cargo.toml index 1e11408322..8a26a51917 100644 --- a/lib/recursion/Cargo.toml +++ b/lib/recursion/Cargo.toml @@ -49,9 +49,6 @@ cfg-if = { workspace = true } [dev-dependencies] ax-circuit-primitives.workspace = true -[[bin]] -name = "alu256_e2e" - [[bin]] name = "small_e2e" diff --git a/lib/recursion/src/bin/alu256_e2e.rs b/lib/recursion/src/bin/alu256_e2e.rs deleted file mode 100644 index a6c39c442f..0000000000 --- a/lib/recursion/src/bin/alu256_e2e.rs +++ /dev/null @@ -1,173 +0,0 @@ -/// E2E benchmark to aggregate small program with ALU chips. -/// Proofs: -/// 1. Prove a program with some ALU operations. -/// 2. Verify the proof of 1. in the inner config. -/// 2. Verify the proof of 2. in the outer config. -/// 3. Verify the proof of 3. using a Halo2 static verifier. -/// 4. Wrapper Halo2 circuit to reduce the size of 4. -use std::iter; - -use ax_stark_sdk::{ - bench::run_with_metric_collection, - config::{ - baby_bear_poseidon2::BabyBearPoseidon2Engine, - fri_params::standard_fri_params_with_100_bits_conjectured_security, - }, - engine::{ProofInputForTest, StarkFriEngine}, -}; -use axvm_circuit::{ - arch::{instructions::program::Program, ExecutorName, VmConfig}, - utils::gen_vm_program_test_proof_input, -}; -use axvm_native_compiler::{ - asm::AsmBuilder, - conversion::CompilerOptions, - ir::{RVar, Var}, -}; -use axvm_recursion::testing_utils::inner::build_verification_program; -use num_bigint_dig::BigUint; -use p3_baby_bear::BabyBear; -use p3_commit::PolynomialSpace; -use p3_field::{extension::BinomialExtensionField, AbstractField}; -use p3_uni_stark::{Domain, StarkGenericConfig}; -use tracing::info_span; - -const NUM_DIGITS: usize = 8; - -fn bench_program() -> Program { - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - let sum_digits = iter::repeat(0u32).take(NUM_DIGITS).collect::>(); - let min_digits = iter::repeat(u32::MAX).take(NUM_DIGITS).collect::>(); - let val_digits = iter::once(246) - .chain(iter::repeat(0u32)) - .take(NUM_DIGITS) - .collect::>(); - let one_digits = iter::once(1) - .chain(iter::repeat(0u32)) - .take(NUM_DIGITS) - .collect::>(); - - let n: Var<_> = builder.eval(F::from_canonical_u32(32)); - let sum = builder.eval_biguint(BigUint::new(sum_digits)); - let min = builder.eval_biguint(BigUint::new(min_digits)); - let val = builder.eval_biguint(BigUint::new(val_digits)); - let one = builder.eval_biguint(BigUint::new(one_digits)); - - builder.range(RVar::zero(), n).for_each(|_, builder| { - let add = builder.add_256(&sum, &val); - let sub = builder.sub_256(&min, &val); - - let and = builder.and_256(&add, &sub); - let xor = builder.xor_256(&add, &sub); - let or = builder.or_256(&and, &xor); - - let sltu = builder.sltu_256(&add, &sub); - let slt = builder.slt_256(&add, &sub); - - let shift_val = or.clone(); - builder - .if_eq(sltu, F::from_canonical_u32(1)) - .then(|builder| { - let srl = builder.srl_256(&shift_val, &one); - builder.assign(&shift_val, srl); - }); - builder - .if_eq(slt, F::from_canonical_u32(0)) - .then(|builder| { - let sra = builder.sra_256(&shift_val, &one); - builder.assign(&shift_val, sra); - }); - - let sll = builder.sll_256(&shift_val, &one); - let eq = builder.eq_256(&sll, &or); - builder.if_eq(eq, F::from_canonical_u32(0)).then(|builder| { - let temp = builder.add_256(&add, &one); - builder.assign(&add, temp); - }); - builder.if_eq(eq, F::from_canonical_u32(1)).then(|builder| { - let temp = builder.sub_256(&sub, &one); - builder.assign(&sub, temp); - }); - - builder.assign(&sum, add); - builder.assign(&min, sub); - }); - - builder.halt(); - builder.compile_isa_with_options(CompilerOptions { - word_size: 32, - ..Default::default() - }) -} - -fn bench_program_test_proof_input() -> ProofInputForTest -where - Domain: PolynomialSpace, -{ - let program = bench_program(); - - let vm_config = VmConfig { - ..Default::default() - } - .add_executor(ExecutorName::BranchEqual) - .add_executor(ExecutorName::Jal) - .add_executor(ExecutorName::LoadStore) - .add_executor(ExecutorName::FieldArithmetic) - .add_executor(ExecutorName::ArithmeticLogicUnit256) - .add_executor(ExecutorName::Shift256); - gen_vm_program_test_proof_input(program, vec![], vm_config) -} - -fn main() { - run_with_metric_collection("OUTPUT_PATH", || { - let vdata = - info_span!("Bench Program Inner", group = "bench_program_inner").in_scope(|| { - let program_stark = bench_program_test_proof_input(); - program_stark - .run_test(&BabyBearPoseidon2Engine::new( - standard_fri_params_with_100_bits_conjectured_security(4), - )) - .unwrap() - }); - - let compiler_options = CompilerOptions { - enable_cycle_tracker: true, - ..Default::default() - }; - #[allow(unused_variables)] - let vdata = info_span!("Inner Verifier", group = "inner_verifier").in_scope(|| { - let (program, witness_stream) = - build_verification_program(vdata, compiler_options.clone()); - let inner_verifier_stf = gen_vm_program_test_proof_input( - program, - witness_stream, - VmConfig::aggregation(4, 7), - ); - inner_verifier_stf - .run_test(&BabyBearPoseidon2Engine::new( - // log_blowup = 3 because of poseidon2 chip. - standard_fri_params_with_100_bits_conjectured_security(3), - )) - .unwrap() - }); - - #[cfg(feature = "static-verifier")] - info_span!("Recursive Verify e2e", group = "recursive_verify_e2e").in_scope(|| { - let (program, witness_stream) = - build_verification_program(vdata, compiler_options.clone()); - let outer_verifier_sft = gen_vm_program_test_proof_input( - program, - witness_stream, - VmConfig::aggregation(4, 7), - ); - axvm_recursion::halo2::testing_utils::run_evm_verifier_e2e_test( - outer_verifier_sft, - // log_blowup = 3 because of poseidon2 chip. - Some(standard_fri_params_with_100_bits_conjectured_security(3)), - ); - }); - }); -} diff --git a/lib/recursion/src/hints.rs b/lib/recursion/src/hints.rs index 205c591016..aee24ab11c 100644 --- a/lib/recursion/src/hints.rs +++ b/lib/recursion/src/hints.rs @@ -1,6 +1,5 @@ use std::cmp::Reverse; -use ax_circuit_primitives::bigint::utils::big_uint_to_num_limbs; use ax_stark_backend::{ keygen::types::TraceWidth, prover::{ @@ -10,11 +9,9 @@ use ax_stark_backend::{ }; use ax_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; use axvm_native_compiler::ir::{ - unsafe_array_transmute, Array, BigUintVar, Builder, Config, Ext, Felt, MemVariable, Usize, Var, - DIGEST_SIZE, LIMB_BITS, NUM_LIMBS, + unsafe_array_transmute, Array, Builder, Config, Ext, Felt, MemVariable, Usize, Var, DIGEST_SIZE, }; use itertools::Itertools; -use num_bigint_dig::BigUint; use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; use p3_commit::ExtensionMmcs; use p3_field::{extension::BinomialExtensionField, AbstractExtensionField, AbstractField, Field}; @@ -472,27 +469,6 @@ impl Hintable for Commitments { } } -impl Hintable for BigUint { - type HintVariable = BigUintVar; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let ret = builder.uninit_biguint(); - for i in 0..NUM_LIMBS { - // FIXME: range check for each element. - let v = builder.hint_var(); - builder.set_value(&ret, i, v); - } - ret - } - - fn write(&self) -> Vec::N>> { - vec![big_uint_to_num_limbs(self, LIMB_BITS, NUM_LIMBS) - .iter() - .map(|x| ::N::from_canonical_usize(*x)) - .collect()] - } -} - #[cfg(test)] mod test { use axvm_circuit::system::program::util::execute_program; diff --git a/toolchain/instructions/src/lib.rs b/toolchain/instructions/src/lib.rs index 458ff320df..ff9378c575 100644 --- a/toolchain/instructions/src/lib.rs +++ b/toolchain/instructions/src/lib.rs @@ -150,40 +150,6 @@ pub enum ModularArithmeticOpcode { DIV, } -// to be deleted and replaced by Rv32Alu256Opcodes below -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, UsizeOpcode, -)] -#[opcode_offset = 0x180] -#[repr(usize)] -pub enum U256Opcode { - // maybe later we will make it uint and specify the parameters in the config - ADD, - SUB, - LT, - EQ, - XOR, - AND, - OR, - SLT, - - SLL, - SRL, - SRA, - - MUL, -} -impl U256Opcode { - // Excludes multiplication - pub fn arithmetic_opcodes() -> impl Iterator { - (U256Opcode::ADD as usize..=U256Opcode::SLT as usize).map(U256Opcode::from_usize) - } - - pub fn shift_opcodes() -> impl Iterator { - (U256Opcode::SLL as usize..=U256Opcode::SRA as usize).map(U256Opcode::from_usize) - } -} - // to be deleted and replaced by Rv32SwOpcode #[derive( Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, UsizeOpcode, diff --git a/toolchain/native-compiler/src/asm/compiler.rs b/toolchain/native-compiler/src/asm/compiler.rs index e0bd211bc0..f952ca1f36 100644 --- a/toolchain/native-compiler/src/asm/compiler.rs +++ b/toolchain/native-compiler/src/asm/compiler.rs @@ -285,130 +285,6 @@ impl + TwoAdicField> AsmCo DslIr::MulEFI(dst, lhs, rhs) => { self.mul_ext_felti(dst, lhs, rhs, debug_info); } - DslIr::ModularAdd(modulus, dst, lhs, rhs) => { - self.push( - AsmInstruction::ModularAdd( - modulus, - dst.ptr_fp(), - lhs.ptr_fp(), - rhs.ptr_fp(), - ), - debug_info, - ); - } - DslIr::ModularSub(modulus, dst, lhs, rhs) => { - self.push( - AsmInstruction::ModularSub( - modulus, - dst.ptr_fp(), - lhs.ptr_fp(), - rhs.ptr_fp(), - ), - debug_info, - ); - } - DslIr::ModularMul(modulus, dst, lhs, rhs) => { - self.push( - AsmInstruction::ModularMul( - modulus, - dst.ptr_fp(), - lhs.ptr_fp(), - rhs.ptr_fp(), - ), - debug_info, - ); - } - DslIr::ModularDiv(modulus, dst, lhs, rhs) => { - self.push( - AsmInstruction::ModularDiv( - modulus, - dst.ptr_fp(), - lhs.ptr_fp(), - rhs.ptr_fp(), - ), - debug_info, - ); - } - DslIr::Add256(dst, lhs, rhs) => { - self.push( - AsmInstruction::Add256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::Sub256(dst, lhs, rhs) => { - self.push( - AsmInstruction::Sub256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::Mul256(dst, lhs, rhs) => { - self.push( - AsmInstruction::Mul256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::LessThanU256(dst, lhs, rhs) => { - self.push( - AsmInstruction::LessThanU256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::EqualTo256(dst, lhs, rhs) => { - self.push( - AsmInstruction::EqualTo256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::Xor256(dst, lhs, rhs) => { - self.push( - AsmInstruction::Xor256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::And256(dst, lhs, rhs) => { - self.push( - AsmInstruction::And256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::Or256(dst, lhs, rhs) => { - self.push( - AsmInstruction::Or256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::LessThanI256(dst, lhs, rhs) => { - self.push( - AsmInstruction::LessThanI256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::ShiftLeft256(dst, lhs, rhs) => { - self.push( - AsmInstruction::ShiftLeft256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), - debug_info, - ); - } - DslIr::ShiftRightLogic256(dst, lhs, rhs) => { - self.push( - AsmInstruction::ShiftRightLogic256( - dst.ptr_fp(), - lhs.ptr_fp(), - rhs.ptr_fp(), - ), - debug_info, - ); - } - DslIr::ShiftRightArith256(dst, lhs, rhs) => { - self.push( - AsmInstruction::ShiftRightArith256( - dst.ptr_fp(), - lhs.ptr_fp(), - rhs.ptr_fp(), - ), - debug_info, - ); - } DslIr::IfEq(lhs, rhs, then_block, else_block) => { let if_compiler = IfCompiler { compiler: self, @@ -639,18 +515,6 @@ impl + TwoAdicField> AsmCo }, _ => unimplemented!(), }, - DslIr::Secp256k1AddUnequal(dst, p, q) => { - self.push( - AsmInstruction::Secp256k1AddUnequal(dst.ptr_fp(), p.ptr_fp(), q.ptr_fp()), - debug_info, - ); - } - DslIr::Secp256k1Double(dst, p) => { - self.push( - AsmInstruction::Secp256k1Double(dst.ptr_fp(), p.ptr_fp()), - debug_info, - ); - } DslIr::Error() => self.push(AsmInstruction::j(self.trap_label), debug_info), DslIr::PrintF(dst) => { self.push(AsmInstruction::PrintF(dst.fp()), debug_info); diff --git a/toolchain/native-compiler/src/asm/instruction.rs b/toolchain/native-compiler/src/asm/instruction.rs index e2bfd3b868..62c552d231 100644 --- a/toolchain/native-compiler/src/asm/instruction.rs +++ b/toolchain/native-compiler/src/asm/instruction.rs @@ -1,7 +1,6 @@ use alloc::{collections::BTreeMap, format}; use core::fmt; -use num_bigint_dig::BigUint; use p3_field::{ExtensionField, PrimeField32}; use super::A0; @@ -56,11 +55,6 @@ pub enum AsmInstruction { /// Divide value from immediate, dst = lhs / rhs. DivFIN(i32, F, i32), - /// U256 equal, dst = lhs == rhs. - /// (a, b, c) are memory pointers to (*z, *x, *y), which are - /// themselves memory pointers to (z, x, y) where z = (x == y ? 1 : 0) - EqU256(i32, i32, i32), - /// Add extension, dst = lhs + rhs. AddE(i32, i32, i32), @@ -73,47 +67,6 @@ pub enum AsmInstruction { /// Divide extension, dst = lhs / rhs. DivE(i32, i32, i32), - ModularAdd(BigUint, i32, i32, i32), - ModularSub(BigUint, i32, i32, i32), - ModularMul(BigUint, i32, i32, i32), - ModularDiv(BigUint, i32, i32, i32), - - /// int add, dst = lhs + rhs. - Add256(i32, i32, i32), - - /// int subtract, dst = lhs - rhs. - Sub256(i32, i32, i32), - - /// int multiply, dst = lhs * rhs. - Mul256(i32, i32, i32), - - /// uint less than, dst = lhs < rhs. - LessThanU256(i32, i32, i32), - - /// int equal to, dst = lhs == rhs. - EqualTo256(i32, i32, i32), - - /// int bitwise XOR, dst = lhs ^ rhs - Xor256(i32, i32, i32), - - /// int bitwise AND, dst = lhs & rhs - And256(i32, i32, i32), - - /// int bitwise OR, dst = lhs | rhs - Or256(i32, i32, i32), - - /// signed int less than, dst = lhs < rhs - LessThanI256(i32, i32, i32), - - /// int shift left, dst = lhs << rhs - ShiftLeft256(i32, i32, i32), - - /// int shift right logical, dst = lhs >> rhs - ShiftRightLogic256(i32, i32, i32), - - /// int shift right arithmetic, dst = lhs >> rhs - ShiftRightArith256(i32, i32, i32), - /// Jump. Jump(i32, F), @@ -165,14 +118,6 @@ pub enum AsmInstruction { /// Same as `Keccak256`, but with fixed length input (hence length is an immediate value). Keccak256FixLen(i32, i32, F), - /// (dst_ptr_ptr, p_ptr_ptr, q_ptr_ptr) are pointers to pointers to (dst, p, q). - /// Reads p,q from memory and writes p+q to dst. - /// Assumes p != +-q as secp256k1 points. - Secp256k1AddUnequal(i32, i32, i32), - /// (dst_ptr_ptr, p_ptr_ptr) are pointers to pointers to (dst, p). - /// Reads p,q from memory and writes 2*p to dst. - Secp256k1Double(i32, i32), - /// Print a variable. PrintV(i32), @@ -272,9 +217,6 @@ impl> AsmInstruction { AsmInstruction::DivFIN(dst, lhs, rhs) => { write!(f, "divi ({})fp, {}, ({})fp", dst, lhs, rhs) } - AsmInstruction::EqU256(dst, lhs, rhs) => { - write!(f, "eq ({})fp, ({})fp, ({})fp", dst, lhs, rhs) - } AsmInstruction::AddE(dst, lhs, rhs) => { write!(f, "eadd ({})fp, ({})fp, ({})fp", dst, lhs, rhs) } @@ -386,12 +328,6 @@ impl> AsmInstruction { AsmInstruction::Keccak256FixLen(dst, src, len) => { write!(f, "keccak256 ({dst})fp, ({src})fp, {len}",) } - AsmInstruction::Secp256k1AddUnequal(dst, p, q) => { - write!(f, "secp256k1_add_unequal ({})fp, ({})fp, ({})fp", dst, p, q) - } - AsmInstruction::Secp256k1Double(dst, p) => { - write!(f, "secp256k1_double ({})fp, ({})fp", dst, p) - } AsmInstruction::PrintF(dst) => { write!(f, "print_f ({})fp", dst) } @@ -414,70 +350,6 @@ impl> AsmInstruction { AsmInstruction::CycleTrackerEnd() => { write!(f, "cycle_tracker_end") } - AsmInstruction::ModularAdd(modulus, dst, src1, src2) => { - write!( - f, - "modular_add with modulus {:?}: ({})fp ({})fp ({})fp", - modulus, dst, src1, src2 - ) - } - AsmInstruction::ModularSub(modulus, dst, src1, src2) => { - write!( - f, - "modular_sub with modulus {:?}: ({})fp ({})fp ({})fp", - modulus, dst, src1, src2 - ) - } - AsmInstruction::ModularMul(modulus, dst, src1, src2) => { - write!( - f, - "modular_mul with modulus {:?}: ({})fp ({})fp ({})fp", - modulus, dst, src1, src2 - ) - } - AsmInstruction::ModularDiv(modulus, dst, src1, src2) => { - write!( - f, - "modular_div with modulus {:?}: ({})fp ({})fp ({})fp", - modulus, dst, src1, src2 - ) - } - AsmInstruction::Add256(dst, src1, src2) => { - write!(f, "add_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::Sub256(dst, src1, src2) => { - write!(f, "sub_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::Mul256(dst, src1, src2) => { - write!(f, "mul_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::LessThanU256(dst, src1, src2) => { - write!(f, "sltu_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::EqualTo256(dst, src1, src2) => { - write!(f, "eq_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::Xor256(dst, src1, src2) => { - write!(f, "xor_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::And256(dst, src1, src2) => { - write!(f, "and_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::Or256(dst, src1, src2) => { - write!(f, "or_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::LessThanI256(dst, src1, src2) => { - write!(f, "slt_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::ShiftLeft256(dst, src1, src2) => { - write!(f, "sll_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::ShiftRightLogic256(dst, src1, src2) => { - write!(f, "srl_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } - AsmInstruction::ShiftRightArith256(dst, src1, src2) => { - write!(f, "sra_256 ({})fp ({})fp ({})fp", dst, src1, src2) - } } } } diff --git a/toolchain/native-compiler/src/conversion/mod.rs b/toolchain/native-compiler/src/conversion/mod.rs index 13cfad15fe..2d94a53f8a 100644 --- a/toolchain/native-compiler/src/conversion/mod.rs +++ b/toolchain/native-compiler/src/conversion/mod.rs @@ -145,28 +145,6 @@ fn i32_f(x: i32) -> F { } } -fn convert_comparison_instruction>( - instruction: AsmInstruction, - options: &CompilerOptions, -) -> Vec> { - match instruction { - AsmInstruction::EqU256(a, b, c) => vec![inst_large( - options.opcode_with_offset(U256Opcode::EQ), - i32_f(a), - i32_f(b), - i32_f(c), - AS::Memory, - AS::Memory, - AS::Memory.to_field(), - AS::Memory.to_field(), - )], - _ => panic!( - "Illegal argument to convert_comparison_instruction: {:?}", - instruction - ), - } -} - fn convert_base_arithmetic_instruction>( instruction: AsmInstruction, options: &CompilerOptions, @@ -622,7 +600,6 @@ fn convert_instruction>( ) } } - AsmInstruction::EqU256(..) => convert_comparison_instruction(instruction, options), AsmInstruction::AddE(..) | AsmInstruction::SubE(..) | AsmInstruction::MulE(..) @@ -652,134 +629,6 @@ fn convert_instruction>( AS::Memory, AS::Memory, )], - AsmInstruction::ModularAdd(modulus, dst, src1, src2) => vec![inst( - options.modular_opcode_with_offset(ModularArithmeticOpcode::ADD, modulus), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::ModularSub(modulus, dst, src1, src2) => vec![inst( - options.modular_opcode_with_offset(ModularArithmeticOpcode::SUB, modulus), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::ModularMul(modulus, dst, src1, src2) => vec![inst( - options.modular_opcode_with_offset(ModularArithmeticOpcode::MUL, modulus), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::ModularDiv(modulus, dst, src1, src2) => vec![inst( - options.modular_opcode_with_offset(ModularArithmeticOpcode::DIV, modulus), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::Add256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::ADD), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::Sub256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::SUB), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::Mul256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::MUL), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::LessThanU256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::LT), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::EqualTo256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::EQ), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::Xor256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::XOR), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::And256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::AND), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::Or256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::OR), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::LessThanI256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::SLT), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::ShiftLeft256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::SLL), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::ShiftRightLogic256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::SRL), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], - AsmInstruction::ShiftRightArith256(dst, src1, src2) => vec![inst( - options.opcode_with_offset(U256Opcode::SRA), - i32_f(dst), - i32_f(src1), - i32_f(src2), - AS::Memory, - AS::Memory, - )], AsmInstruction::Keccak256(dst, src, len) => vec![inst_med( options.opcode_with_offset(Keccak256Opcode::KECCAK256), i32_f(dst), @@ -801,23 +650,6 @@ fn convert_instruction>( // AS::Immediate, // ) } - AsmInstruction::Secp256k1AddUnequal(dst_ptr_ptr, p_ptr_ptr, q_ptr_ptr) => vec![inst_med( - options.opcode_with_offset(EccOpcode::EC_ADD_NE), - i32_f(dst_ptr_ptr), - i32_f(p_ptr_ptr), - i32_f(q_ptr_ptr), - AS::Memory, - AS::Memory, - AS::Memory, - )], - AsmInstruction::Secp256k1Double(dst_ptr_ptr, p_ptr_ptr) => vec![inst( - options.opcode_with_offset(EccOpcode::EC_DOUBLE), - i32_f(dst_ptr_ptr), - i32_f(p_ptr_ptr), - F::zero(), - AS::Memory, - AS::Memory, - )], AsmInstruction::CycleTrackerStart() => { if options.enable_cycle_tracker { vec![Instruction::debug(PhantomInstruction::CtStart)] diff --git a/toolchain/native-compiler/src/ir/bits.rs b/toolchain/native-compiler/src/ir/bits.rs index 964116882d..1882d735d2 100644 --- a/toolchain/native-compiler/src/ir/bits.rs +++ b/toolchain/native-compiler/src/ir/bits.rs @@ -1,6 +1,6 @@ use p3_field::AbstractField; -use super::{Array, BigUintVar, Builder, Config, DslIr, Felt, MemIndex, RVar, Var}; +use super::{Array, Builder, Config, DslIr, Felt, MemIndex, RVar, Var}; pub const NUM_BITS: usize = 31; @@ -32,21 +32,6 @@ impl Builder { output } - pub fn num2bits_biguint(&mut self, biguint: &BigUintVar) -> Array> { - let repr_size = self.bigint_repr_size; - let num_limbs = (256 + repr_size - 1) / repr_size; - let bits = self.dyn_array((num_limbs * repr_size) as usize); - for i in 0..num_limbs as usize { - let limb = self.get(biguint, i); - let limb_bits = self.num2bits_v(limb, repr_size); - for j in 0..repr_size as usize { - let val = self.get(&limb_bits, j); - self.set_value(&bits, j + i * repr_size as usize, val); - } - } - bits - } - /// Converts a variable to bits inside a circuit. pub fn num2bits_v_circuit(&mut self, num: Var, bits: usize) -> Vec> { let mut output = Vec::new(); diff --git a/toolchain/native-compiler/src/ir/elliptic_curve.rs b/toolchain/native-compiler/src/ir/elliptic_curve.rs deleted file mode 100644 index 3d02e0245c..0000000000 --- a/toolchain/native-compiler/src/ir/elliptic_curve.rs +++ /dev/null @@ -1,113 +0,0 @@ -use p3_field::{AbstractField, PrimeField64}; - -use super::{Array, DslIr}; -use crate::ir::{modular_arithmetic::BigUintVar, Builder, Config, Var}; - -impl Builder -where - C::N: PrimeField64, -{ - /// Computes `p + q`, handling cases where `p` or `q` are identity. - /// - /// A point is stored as a tuple of affine coordinates, contiguously in memory as 64 bytes. - /// Identity point is represented as (0, 0). - pub fn secp256k1_add( - &mut self, - point_1: Array>, - point_2: Array>, - ) -> Array> { - // number of limbs to represent one coordinate - let num_limbs = ((256 + self.bigint_repr_size - 1) / self.bigint_repr_size) as usize; - // Assuming point_1.len() = 2 * num_limbs - let x1 = point_1.slice(self, 0, num_limbs); - let y1 = point_1.slice(self, num_limbs, 2 * num_limbs); - - let res = self.uninit(); - let x1_zero = self.secp256k1_coord_is_zero(&x1); - let y1_zero = self.secp256k1_coord_is_zero(&y1); - - // if point_1 is identity - self.if_eq(x1_zero * y1_zero, C::N::one()).then_or_else( - |builder| { - builder.assign(&res, point_2.clone()); - }, - |builder| { - let x2 = point_2.slice(builder, 0, num_limbs); - let y2 = point_2.slice(builder, num_limbs, 2 * num_limbs); - let x2_zero = builder.secp256k1_coord_is_zero(&x2); - let y2_zero = builder.secp256k1_coord_is_zero(&y2); - // else if point_2 is identity - builder.if_eq(x2_zero * y2_zero, C::N::one()).then_or_else( - |builder| { - builder.assign(&res, point_1.clone()); - }, - |builder| { - let xs_equal = builder.secp256k1_coord_eq(&x1, &x2); - builder.if_eq(xs_equal, C::N::one()).then_or_else( - |builder| { - // if x1 == x2 - let ys_equal = builder.secp256k1_coord_eq(&y1, &y2); - builder.if_eq(ys_equal, C::N::one()).then_or_else( - |builder| { - // if y1 == y2 => point_1 == point_2, do double - let res_double = builder.secp256k1_double(point_1.clone()); - builder.assign(&res, res_double); - }, - |builder| { - // else y1 != y2 => x1 = x2, y1 = - y2 so point_1 + point_2 = identity - let identity = builder.array(2 * num_limbs); - for i in 0..2 * num_limbs { - builder.set(&identity, i, C::N::zero()); - } - builder.assign(&res, identity) - }, - ) - }, - |builder| { - // if x1 != x2 - let res_ne = - builder.secp256k1_add_unequal(point_1.clone(), point_2.clone()); - builder.assign(&res, res_ne); - }, - ) - }, - ) - }, - ); - res - } - - /// Assumes that `point_1 != +- point_2` which is equivalent to `point_1.x != point_2.x`. - /// Does not handle identity points. - /// - /// A point is stored as a tuple of affine coordinates, contiguously in memory as 64 bytes. - pub fn secp256k1_add_unequal( - &mut self, - point_1: Array>, - point_2: Array>, - ) -> Array> { - // TODO: enforce this is constant length - let dst = self.array(point_1.len()); - self.push(DslIr::Secp256k1AddUnequal(dst.clone(), point_1, point_2)); - dst - } - - /// Does not handle identity points. - /// - /// A point is stored as a tuple of affine coordinates, contiguously in memory as 64 bytes. - pub fn secp256k1_double(&mut self, point: Array>) -> Array> { - let dst = self.array(point.len()); - self.push(DslIr::Secp256k1Double(dst.clone(), point)); - dst - } - - /// Assert (x, y) is on the curve. - pub fn ec_is_on_curve(&mut self, x: &BigUintVar, y: &BigUintVar) -> Var { - let x2 = self.secp256k1_coord_mul(x, x); - let x3 = self.secp256k1_coord_mul(&x2, x); - let c7 = self.eval_biguint(7u64.into()); - let x3_plus_7 = self.secp256k1_coord_add(&x3, &c7); - let y2 = self.secp256k1_coord_mul(y, y); - self.secp256k1_coord_eq(&y2, &x3_plus_7) - } -} diff --git a/toolchain/native-compiler/src/ir/instructions.rs b/toolchain/native-compiler/src/ir/instructions.rs index 119c6e8d99..42874b2d4f 100644 --- a/toolchain/native-compiler/src/ir/instructions.rs +++ b/toolchain/native-compiler/src/ir/instructions.rs @@ -1,7 +1,4 @@ -use num_bigint_dig::BigUint; - use super::{Array, Config, Ext, Felt, MemIndex, Ptr, RVar, TracedVec, Var}; -use crate::ir::modular_arithmetic::BigUintVar; /// An intermeddiate instruction set for implementing programs. /// @@ -36,10 +33,6 @@ pub enum DslIr { AddEFI(Ext, Ext, C::F), /// Add a field element and an ext field immediate (ext = felt + ext field imm). AddEFFI(Ext, Felt, C::EF), - /// Add two modular BigInts over some field. - ModularAdd(BigUint, BigUintVar, BigUintVar, BigUintVar), - /// Add two 256-bit integers - Add256(BigUintVar, BigUintVar, BigUintVar), // Subtractions. /// Subtracts two variables (var = var - var). @@ -64,10 +57,6 @@ pub enum DslIr { SubEFI(Ext, Ext, C::F), /// Subtracts an extension field element and a field element (ext = ext - felt). SubEF(Ext, Ext, Felt), - /// Subtract two modular BigInts over some field. - ModularSub(BigUint, BigUintVar, BigUintVar, BigUintVar), - /// Subtract two 256-bit integers - Sub256(BigUintVar, BigUintVar, BigUintVar), // Multiplications. /// Multiplies two variables (var = var * var). @@ -86,10 +75,6 @@ pub enum DslIr { MulEFI(Ext, Ext, C::F), /// Multiplies an extension field element and a field element (ext = ext * felt). MulEF(Ext, Ext, Felt), - /// Multiply two modular BigInts over some field. - ModularMul(BigUint, BigUintVar, BigUintVar, BigUintVar), - /// Multiply two 256-bit integers - Mul256(BigUintVar, BigUintVar, BigUintVar), // Divisions. /// Divides two variables (var = var / var). @@ -108,8 +93,6 @@ pub enum DslIr { DivEFI(Ext, Ext, C::F), /// Divides an extension field element and a field element (ext = ext / felt). DivEF(Ext, Ext, Felt), - /// Divide two modular BigInts over some field. - ModularDiv(BigUint, BigUintVar, BigUintVar, BigUintVar), // Negations. /// Negates a variable (var = -var). @@ -124,28 +107,6 @@ pub enum DslIr { LessThanV(Var, Var, Var), /// Compares a variable and an immediate LessThanVI(Var, Var, C::N), - /// Compare two u256 for < - LessThanU256(Ptr, BigUintVar, BigUintVar), - /// Compare two 256-bit integers for == - EqualTo256(Ptr, BigUintVar, BigUintVar), - /// Compare two signed 256-bit integers for < - LessThanI256(Ptr, BigUintVar, BigUintVar), - - // Bitwise operations. - /// Bitwise XOR on two 256-bit integers - Xor256(BigUintVar, BigUintVar, BigUintVar), - /// Bitwise AND on two 256-bit integers - And256(BigUintVar, BigUintVar, BigUintVar), - /// Bitwise OR on two 256-bit integers - Or256(BigUintVar, BigUintVar, BigUintVar), - - // Shifts. - /// Shift left on 256-bit integers - ShiftLeft256(BigUintVar, BigUintVar, BigUintVar), - /// Shift right logical on 256-bit integers - ShiftRightLogic256(BigUintVar, BigUintVar, BigUintVar), - /// Shift right arithmetic on 256-bit integers - ShiftRightArith256(BigUintVar, BigUintVar, BigUintVar), // ======= @@ -248,24 +209,6 @@ pub enum DslIr { /// **little-endian**. The `output` is exactly 16 limbs (32 bytes). Keccak256(Array>, Array>), - /// ```ignore - /// Secp256k1AddUnequal(dst, p, q) - /// ``` - /// Reads `p,q` from heap and writes `dst = p + q` to heap. A point is represented on the heap - /// as two affine coordinates concatenated together into a byte array. - /// Assumes that `p.x != q.x` which is equivalent to `p != +-q`. - Secp256k1AddUnequal( - Array>, - Array>, - Array>, - ), - /// ```ignore - /// Secp256k1Double(dst, p) - /// ``` - /// Reads `p` from heap and writes `dst = p + p` to heap. A point is represented on the heap - /// as two affine coordinates concatenated together into a byte array. - Secp256k1Double(Array>, Array>), - // Miscellaneous instructions. /// Prints a variable. PrintV(Var), diff --git a/toolchain/native-compiler/src/ir/mod.rs b/toolchain/native-compiler/src/ir/mod.rs index 625fab175e..a962f92344 100644 --- a/toolchain/native-compiler/src/ir/mod.rs +++ b/toolchain/native-compiler/src/ir/mod.rs @@ -1,7 +1,6 @@ pub use builder::*; pub use collections::*; pub use instructions::*; -pub use modular_arithmetic::*; use p3_field::{ExtensionField, PrimeField, TwoAdicField}; pub use poseidon::{DIGEST_SIZE, PERMUTATION_WIDTH}; pub use ptr::*; @@ -15,17 +14,14 @@ pub use var::*; mod bits; mod builder; mod collections; -mod elliptic_curve; mod instructions; mod keccak; -mod modular_arithmetic; mod poseidon; mod ptr; mod ref_ptr; mod select; mod symbolic; mod types; -mod uint; mod utils; mod var; diff --git a/toolchain/native-compiler/src/ir/modular_arithmetic.rs b/toolchain/native-compiler/src/ir/modular_arithmetic.rs deleted file mode 100644 index b5ea96fc12..0000000000 --- a/toolchain/native-compiler/src/ir/modular_arithmetic.rs +++ /dev/null @@ -1,252 +0,0 @@ -use ax_circuit_primitives::bigint::utils::big_uint_to_num_limbs; -use axvm_circuit::{ - arch::Modulus, - intrinsics::modular::{SECP256K1_COORD_PRIME, SECP256K1_SCALAR_PRIME}, -}; -use num_bigint_dig::BigUint; -use num_traits::Zero; -use p3_field::{AbstractField, PrimeField64}; - -use super::{ - utils::{LIMB_BITS, NUM_LIMBS}, - Array, Builder, Config, DslIr, IfBuilder, Var, -}; - -pub type BigUintVar = Array::N>>; - -impl BigUintVar { - pub fn ptr_fp(&self) -> i32 { - match self { - Array::Fixed(_) => panic!(), - Array::Dyn(ptr, _) => ptr.fp(), - } - } -} - -impl Builder -where - C::N: PrimeField64, -{ - pub fn eval_biguint(&mut self, biguint: BigUint) -> BigUintVar { - let array = self.dyn_array(NUM_LIMBS); - - let elems: Vec = big_uint_to_num_limbs(&biguint, LIMB_BITS, NUM_LIMBS) - .into_iter() - .map(C::N::from_canonical_usize) - .collect(); - for (i, &elem) in elems.iter().enumerate() { - self.set(&array, i, elem); - } - - array - } - - pub fn uninit_biguint(&mut self) -> BigUintVar { - self.dyn_array(NUM_LIMBS) - } - - fn mod_operation( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - modulus: BigUint, - operation: impl Fn(BigUint, BigUintVar, BigUintVar, BigUintVar) -> DslIr, - ) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations - .push(operation(modulus, dst.clone(), left.clone(), right.clone())); - dst - } - - pub fn secp256k1_coord_add( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> BigUintVar { - self.mod_operation( - left, - right, - Modulus::Secp256k1Coord.prime(), - DslIr::ModularAdd, - ) - } - - pub fn secp256k1_coord_sub( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> BigUintVar { - self.mod_operation( - left, - right, - Modulus::Secp256k1Coord.prime(), - DslIr::ModularSub, - ) - } - - pub fn secp256k1_coord_mul( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> BigUintVar { - self.mod_operation( - left, - right, - Modulus::Secp256k1Coord.prime(), - DslIr::ModularMul, - ) - } - - pub fn secp256k1_coord_div( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> BigUintVar { - self.mod_operation( - left, - right, - Modulus::Secp256k1Coord.prime(), - DslIr::ModularDiv, - ) - } - - pub fn assert_secp256k1_coord_eq(&mut self, left: &BigUintVar, right: &BigUintVar) { - let res = self.secp256k1_coord_eq(left, right); - self.assert_var_eq(res, C::N::one()); - } - - pub fn secp256k1_coord_is_zero(&mut self, biguint: &BigUintVar) -> Var { - // TODO: either EqU256 needs to support address space 0 or we just need better pointer handling here. - let ret_arr = self.array(1); - // FIXME: reuse constant zero. - let big_zero = self.eval_biguint(BigUint::zero()); - self.operations - .push(DslIr::EqualTo256(ret_arr.ptr(), biguint.clone(), big_zero)); - let ret: Var<_> = self.get(&ret_arr, 0); - self.if_ne(ret, C::N::one()).then(|builder| { - // FIXME: reuse constant. - let big_n = builder.eval_biguint(SECP256K1_COORD_PRIME.clone()); - builder - .operations - .push(DslIr::EqualTo256(ret_arr.ptr(), biguint.clone(), big_n)); - let _ret: Var<_> = builder.get(&ret_arr, 0); - builder.assign(&ret, _ret); - }); - ret - } - - pub fn secp256k1_coord_set_to_zero(&mut self, biguint: &BigUintVar) { - for i in 0..NUM_LIMBS { - self.set(biguint, i, C::N::zero()); - } - } - - pub fn secp256k1_coord_eq(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { - let diff = self.secp256k1_coord_sub(left, right); - self.secp256k1_coord_is_zero(&diff) - } - - pub fn if_secp256k1_coord_eq( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> IfBuilder { - let eq = self.secp256k1_coord_eq(left, right); - self.if_eq(eq, C::N::one()) - } - - pub fn secp256k1_scalar_add( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> BigUintVar { - self.mod_operation( - left, - right, - Modulus::Secp256k1Scalar.prime(), - DslIr::ModularAdd, - ) - } - - pub fn secp256k1_scalar_sub( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> BigUintVar { - self.mod_operation( - left, - right, - Modulus::Secp256k1Scalar.prime(), - DslIr::ModularSub, - ) - } - - pub fn secp256k1_scalar_mul( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> BigUintVar { - self.mod_operation( - left, - right, - Modulus::Secp256k1Scalar.prime(), - DslIr::ModularMul, - ) - } - - pub fn secp256k1_scalar_div( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> BigUintVar { - self.mod_operation( - left, - right, - Modulus::Secp256k1Scalar.prime(), - DslIr::ModularDiv, - ) - } - - pub fn assert_secp256k1_scalar_eq(&mut self, left: &BigUintVar, right: &BigUintVar) { - let res = self.secp256k1_scalar_eq(left, right); - self.assert_var_eq(res, C::N::one()); - } - - pub fn secp256k1_scalar_is_zero(&mut self, biguint: &BigUintVar) -> Var { - // TODO: either EqU256 needs to support address space 0 or we just need better pointer handling here. - let ret_arr = self.array(1); - // FIXME: reuse constant zero. - let big_zero = self.eval_biguint(BigUint::zero()); - self.operations - .push(DslIr::EqualTo256(ret_arr.ptr(), biguint.clone(), big_zero)); - let ret: Var<_> = self.get(&ret_arr, 0); - self.if_ne(ret, C::N::one()).then(|builder| { - // FIXME: reuse constant. - let big_n = builder.eval_biguint(SECP256K1_SCALAR_PRIME.clone()); - builder - .operations - .push(DslIr::EqualTo256(ret_arr.ptr(), biguint.clone(), big_n)); - let _ret: Var<_> = builder.get(&ret_arr, 0); - builder.assign(&ret, _ret); - }); - ret - } - - pub fn secp256k1_scalar_eq( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> Var { - let diff = self.secp256k1_scalar_sub(left, right); - self.secp256k1_scalar_is_zero(&diff) - } - - pub fn if_secp256k1_scalar_eq( - &mut self, - left: &BigUintVar, - right: &BigUintVar, - ) -> IfBuilder { - let eq = self.secp256k1_scalar_eq(left, right); - self.if_eq(eq, C::N::one()) - } -} diff --git a/toolchain/native-compiler/src/ir/uint.rs b/toolchain/native-compiler/src/ir/uint.rs deleted file mode 100644 index 0fa243b758..0000000000 --- a/toolchain/native-compiler/src/ir/uint.rs +++ /dev/null @@ -1,103 +0,0 @@ -use p3_field::PrimeField64; - -use super::{modular_arithmetic::BigUintVar, utils::NUM_LIMBS, Var}; -use crate::ir::{Builder, Config, DslIr}; - -impl Builder -where - C::N: PrimeField64, -{ - pub fn add_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations - .push(DslIr::Add256(dst.clone(), left.clone(), right.clone())); - dst - } - - pub fn sub_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations - .push(DslIr::Sub256(dst.clone(), left.clone(), right.clone())); - dst - } - - pub fn mul_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations - .push(DslIr::Mul256(dst.clone(), left.clone(), right.clone())); - dst - } - - pub fn sltu_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { - let dst = self.array(1); - self.operations - .push(DslIr::LessThanU256(dst.ptr(), left.clone(), right.clone())); - self.get(&dst, 0) - } - - pub fn eq_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { - // let dst = self.alloc(1, as MemVariable>::size_of()); - let dst = self.array(1); - self.operations - .push(DslIr::EqualTo256(dst.ptr(), left.clone(), right.clone())); - self.get(&dst, 0) - } - - pub fn xor_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations - .push(DslIr::Xor256(dst.clone(), left.clone(), right.clone())); - dst - } - - pub fn and_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations - .push(DslIr::And256(dst.clone(), left.clone(), right.clone())); - dst - } - - pub fn or_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations - .push(DslIr::Or256(dst.clone(), left.clone(), right.clone())); - dst - } - - pub fn slt_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { - let dst = self.array(1); - self.operations - .push(DslIr::LessThanI256(dst.ptr(), left.clone(), right.clone())); - self.get(&dst, 0) - } - - pub fn sll_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations.push(DslIr::ShiftLeft256( - dst.clone(), - left.clone(), - right.clone(), - )); - dst - } - - pub fn srl_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations.push(DslIr::ShiftRightLogic256( - dst.clone(), - left.clone(), - right.clone(), - )); - dst - } - - pub fn sra_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { - let dst = self.dyn_array(NUM_LIMBS); - self.operations.push(DslIr::ShiftRightArith256( - dst.clone(), - left.clone(), - right.clone(), - )); - dst - } -} diff --git a/toolchain/native-compiler/tests/modular_arithmetic.rs b/toolchain/native-compiler/tests/modular_arithmetic.rs deleted file mode 100644 index 9cb161c75d..0000000000 --- a/toolchain/native-compiler/tests/modular_arithmetic.rs +++ /dev/null @@ -1,192 +0,0 @@ -use ax_stark_sdk::utils::create_seeded_rng; -use axvm_circuit::system::program::util::execute_program; -use axvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions, ir::Var}; -use num_bigint_dig::BigUint; -use num_traits::{FromPrimitive, One, Zero}; -use p3_baby_bear::BabyBear; -use p3_field::{extension::BinomialExtensionField, AbstractField, ExtensionField, TwoAdicField}; -use rand::RngCore; - -fn secp256k1_coord_prime() -> BigUint { - let mut result = BigUint::one() << 256; - for power in [32, 9, 8, 7, 6, 4, 0] { - result -= BigUint::one() << power; - } - result -} - -fn test_modular_arithmetic_program + TwoAdicField>( - builder: AsmBuilder, -) { - let program = builder.compile_isa_with_options(CompilerOptions { - word_size: 32, - ..Default::default() - }); - execute_program(program, vec![]); -} - -#[test] -fn test_compiler_modular_arithmetic_1() { - let a = BigUint::from_isize(31).unwrap(); - let b = BigUint::from_isize(115).unwrap(); - - let r = BigUint::from_isize(31 * 115).unwrap(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - let a_var = builder.eval_biguint(a); - let b_var = builder.eval_biguint(b); - let r_var = builder.secp256k1_coord_mul(&a_var, &b_var); - let r_check_var = builder.eval_biguint(r); - builder.assert_secp256k1_coord_eq(&r_var, &r_check_var); - builder.halt(); - test_modular_arithmetic_program(builder); -} - -#[test] -fn test_compiler_modular_arithmetic_2() { - let num_digits = 8; - - let mut rng = create_seeded_rng(); - let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); - let a = BigUint::new(a_digits); - let b_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); - let b = BigUint::new(b_digits); - // if these are not true then trace is not guaranteed to be verifiable - assert!(a < secp256k1_coord_prime()); - assert!(b < secp256k1_coord_prime()); - - let r = (a.clone() * b.clone()) % secp256k1_coord_prime(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - let a_var = builder.eval_biguint(a); - let b_var = builder.eval_biguint(b); - let r_var = builder.secp256k1_coord_mul(&a_var, &b_var); - let r_check_var = builder.eval_biguint(r); - builder.assert_secp256k1_coord_eq(&r_var, &r_check_var); - builder.halt(); - - test_modular_arithmetic_program(builder); -} - -#[test] -fn test_compiler_modular_arithmetic_conditional() { - let a = BigUint::from_isize(23).unwrap(); - let b = BigUint::from_isize(41).unwrap(); - - let r = BigUint::from_isize(23 * 41).unwrap(); - let s = BigUint::from_isize(1000).unwrap(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - let a_var = builder.eval_biguint(a); - let b_var = builder.eval_biguint(b); - let product_var = builder.secp256k1_coord_mul(&a_var, &b_var); - let r_var = builder.eval_biguint(r); - let s_var = builder.eval_biguint(s); - - let should_be_1: Var = builder.uninit(); - let should_be_2: Var = builder.uninit(); - - builder - .if_secp256k1_coord_eq(&product_var, &r_var) - .then_or_else( - |builder| builder.assign(&should_be_1, F::one()), - |builder| builder.assign(&should_be_1, F::two()), - ); - builder - .if_secp256k1_coord_eq(&product_var, &s_var) - .then_or_else( - |builder| builder.assign(&should_be_2, F::one()), - |builder| builder.assign(&should_be_2, F::two()), - ); - - builder.assert_var_eq(should_be_1, F::one()); - builder.assert_var_eq(should_be_2, F::two()); - - builder.halt(); - - test_modular_arithmetic_program(builder); -} - -#[test] -#[should_panic] -fn test_compiler_modular_arithmetic_negative() { - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - let one = builder.eval_biguint(BigUint::one()); - let one_times_one = builder.secp256k1_coord_mul(&one, &one); - let zero = builder.eval_biguint(BigUint::zero()); - - builder.assert_secp256k1_coord_eq(&one_times_one, &zero); - builder.halt(); - - test_modular_arithmetic_program(builder); -} - -#[test] -fn test_compiler_modular_scalar_arithmetic_conditional() { - let a = BigUint::from_isize(23).unwrap(); - let b = BigUint::from_isize(41).unwrap(); - - let r = BigUint::from_isize(23 * 41).unwrap(); - let s = BigUint::from_isize(1000).unwrap(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - let a_var = builder.eval_biguint(a); - let b_var = builder.eval_biguint(b); - let product_var = builder.secp256k1_scalar_mul(&a_var, &b_var); - let r_var = builder.eval_biguint(r); - let s_var = builder.eval_biguint(s); - - let should_be_1: Var = builder.uninit(); - let should_be_2: Var = builder.uninit(); - - builder - .if_secp256k1_scalar_eq(&product_var, &r_var) - .then_or_else( - |builder| builder.assign(&should_be_1, F::one()), - |builder| builder.assign(&should_be_1, F::two()), - ); - builder - .if_secp256k1_scalar_eq(&product_var, &s_var) - .then_or_else( - |builder| builder.assign(&should_be_2, F::one()), - |builder| builder.assign(&should_be_2, F::two()), - ); - - builder.assert_var_eq(should_be_1, F::one()); - builder.assert_var_eq(should_be_2, F::two()); - - builder.halt(); - - test_modular_arithmetic_program(builder); -} - -#[test] -#[should_panic] -fn test_compiler_modular_scalar_arithmetic_negative() { - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - let one = builder.eval_biguint(BigUint::one()); - let one_times_one = builder.secp256k1_scalar_mul(&one, &one); - let zero = builder.eval_biguint(BigUint::zero()); - builder.assert_secp256k1_scalar_eq(&one_times_one, &zero); - builder.halt(); - - test_modular_arithmetic_program(builder); -} diff --git a/toolchain/native-compiler/tests/uint.rs b/toolchain/native-compiler/tests/uint.rs deleted file mode 100644 index 5b46b19598..0000000000 --- a/toolchain/native-compiler/tests/uint.rs +++ /dev/null @@ -1,361 +0,0 @@ -use std::iter; - -use ax_stark_sdk::utils::create_seeded_rng; -use axvm_circuit::{ - arch::{ExecutorName, VmConfig}, - system::program::util::{execute_program, execute_program_with_config}, -}; -use axvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions, ir::Var}; -use num_bigint_dig::BigUint; -use num_traits::Zero; -use p3_baby_bear::BabyBear; -use p3_field::{extension::BinomialExtensionField, AbstractField}; -use rand::{Rng, RngCore}; - -#[test] -fn test_compiler_256_add_sub() { - let num_digits = 8; - let num_ops = 15; - let mut rng = create_seeded_rng(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - let u256_modulus = BigUint::from(1u32) << 256; - - for _ in 0..num_ops { - let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); - let a = BigUint::new(a_digits); - let b_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); - let b = BigUint::new(b_digits); - - let a_var = builder.eval_biguint(a.clone()); - let b_var = builder.eval_biguint(b.clone()); - - let add_flag = rng.gen_bool(0.5); - - let c = if add_flag { - (a.clone() + b.clone()) % u256_modulus.clone() - } else { - (a.clone() + (u256_modulus.clone() - b.clone())) % u256_modulus.clone() - }; - - let c_var = if add_flag { - builder.add_256(&a_var, &b_var) - } else { - builder.sub_256(&a_var, &b_var) - }; - let c_check_var = builder.eval_biguint(c); - builder.assert_var_array_eq(&c_var, &c_check_var); - } - builder.halt(); - - let program = builder.clone().compile_isa(); - execute_program(program, vec![]); -} - -#[test] -fn test_compiler_256_mul() { - let num_digits = 8; - let num_ops = 10; - let mut rng = create_seeded_rng(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - let u256_modulus = BigUint::from(1u32) << 256; - - for _ in 0..num_ops { - let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); - let a = BigUint::new(a_digits); - let b_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); - let b = BigUint::new(b_digits); - - let a_var = builder.eval_biguint(a.clone()); - let b_var = builder.eval_biguint(b.clone()); - - let c = (a.clone() * b.clone()) % u256_modulus.clone(); - let c_var = builder.mul_256(&a_var, &b_var); - - let c_check_var = builder.eval_biguint(c); - builder.assert_var_array_eq(&c_var, &c_check_var); - } - builder.halt(); - - let program = builder.clone().compile_isa(); - execute_program_with_config( - VmConfig { - num_public_values: 4, - max_segment_len: (1 << 25) - 100, - ..Default::default() - } - .add_executor(ExecutorName::Phantom) - .add_executor(ExecutorName::LoadStore) - .add_executor(ExecutorName::BranchEqual) - .add_executor(ExecutorName::Jal) - .add_executor(ExecutorName::FieldArithmetic) - .add_executor(ExecutorName::U256Multiplication), - program, - vec![], - ); -} - -#[test] -fn test_compiler_256_sltu_eq() { - let num_digits = 8; - let num_ops = 15; - let mut rng = create_seeded_rng(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - for _ in 0..num_ops { - let lt_flag = rng.gen_bool(0.5); - - let a_digits: Vec = (0..num_digits).map(|_| rng.next_u32()).collect(); - let a = BigUint::new(a_digits.clone()); - let b_digits = if lt_flag || rng.gen_bool(0.5) { - (0..num_digits).map(|_| rng.next_u32()).collect() - } else { - a_digits.clone() - }; - let b = BigUint::new(b_digits); - - let a_var = builder.eval_biguint(a.clone()); - let b_var = builder.eval_biguint(b.clone()); - - let c = if lt_flag { - a.clone() < b.clone() - } else { - a.clone() == b.clone() - }; - - let c_var = if lt_flag { - builder.sltu_256(&a_var, &b_var) - } else { - builder.eq_256(&a_var, &b_var) - }; - - let c_check_var: Var<_> = builder.eval(F::from_bool(c)); - builder.assert_var_eq(c_var, c_check_var); - } - builder.halt(); - - let program = builder.clone().compile_isa_with_options(CompilerOptions { - word_size: 32, - ..Default::default() - }); - execute_program(program, vec![]); -} - -#[test] -fn test_compiler_256_slt_eq() { - let num_digits = 8; - let num_ops = 15; - let mut rng = create_seeded_rng(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - let msb_mask: u32 = 1 << 31; - - for _ in 0..num_ops { - let a_digits: Vec = (0..num_digits).map(|_| rng.next_u32()).collect(); - let a = BigUint::new(a_digits.clone()); - let b_digits: Vec = (0..num_digits).map(|_| rng.next_u32()).collect(); - let b = BigUint::new(b_digits.clone()); - - let a_var = builder.eval_biguint(a.clone()); - let b_var = builder.eval_biguint(b.clone()); - - let same_sign = - (a_digits[num_digits - 1] & msb_mask) == (b_digits[num_digits - 1] & msb_mask); - - let c = if same_sign { - a.clone() < b.clone() - } else { - a_digits[num_digits - 1] & msb_mask == msb_mask - }; - - let c_var = builder.slt_256(&a_var, &b_var); - let c_check_var: Var<_> = builder.eval(F::from_bool(c)); - builder.assert_var_eq(c_var, c_check_var); - } - builder.halt(); - - let program = builder.clone().compile_isa_with_options(CompilerOptions { - word_size: 32, - ..Default::default() - }); - execute_program(program, vec![]); -} - -#[test] -fn test_compiler_256_xor_and_or() { - let num_digits = 8; - let num_ops = 20; - let mut rng = create_seeded_rng(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - for _ in 0..num_ops { - let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); - let a = BigUint::new(a_digits); - let b_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); - let b = BigUint::new(b_digits); - - let a_var = builder.eval_biguint(a.clone()); - let b_var = builder.eval_biguint(b.clone()); - - // xor = 0, and = 1, or = 2 - let flag: u8 = rng.gen_range(0..=2); - - let c = if flag == 0 { - a.clone() ^ b.clone() - } else if flag == 1 { - a.clone() & b.clone() - } else { - a.clone() | b.clone() - }; - - let c_var = if flag == 0 { - builder.xor_256(&a_var, &b_var) - } else if flag == 1 { - builder.and_256(&a_var, &b_var) - } else { - builder.or_256(&a_var, &b_var) - }; - let c_check_var = builder.eval_biguint(c); - builder.assert_var_array_eq(&c_var, &c_check_var); - } - builder.halt(); - - let program = builder.clone().compile_isa(); - execute_program(program, vec![]); -} - -#[test] -fn test_compiler_256_sll_srl() { - let num_digits = 8; - let num_ops = 15; - let mut rng = create_seeded_rng(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - - for _ in 0..num_ops { - let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect::>(); - let a = BigUint::new(a_digits.clone()); - - let b_shift = rng.gen_range(0..=64); - let b_digits = iter::once(b_shift as u32) - .chain(iter::repeat(0u32)) - .take(num_digits) - .collect::>(); - let b = BigUint::new(b_digits); - - let a_var = builder.eval_biguint(a.clone()); - let b_var = builder.eval_biguint(b.clone()); - - // sll = 0, srl = 1 - let sll_flag = rng.gen_bool(0.5); - - let c = if sll_flag { - a.clone() << b_shift - } else { - a.clone() >> b_shift - }; - - let c_var = if sll_flag { - builder.sll_256(&a_var, &b_var) - } else { - builder.srl_256(&a_var, &b_var) - }; - let c_check_var = builder.eval_biguint(c); - builder.assert_var_array_eq(&c_var, &c_check_var); - } - builder.halt(); - - let program = builder.clone().compile_isa(); - execute_program_with_config( - VmConfig { - num_public_values: 4, - max_segment_len: (1 << 25) - 100, - ..Default::default() - } - .add_executor(ExecutorName::LoadStore) - .add_executor(ExecutorName::BranchEqual) - .add_executor(ExecutorName::Jal) - .add_executor(ExecutorName::Shift256) - .add_executor(ExecutorName::FieldArithmetic), - program, - vec![], - ); -} - -#[test] -fn test_compiler_256_sra() { - let num_digits = 8; - let num_ops = 10; - let mut rng = create_seeded_rng(); - - type F = BabyBear; - type EF = BinomialExtensionField; - let mut builder = AsmBuilder::::default(); - let msb_mask: u32 = 1 << 31; - - for _ in 0..num_ops { - let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect::>(); - let a_sign = a_digits[num_digits - 1] & msb_mask == msb_mask; - let a = BigUint::new(a_digits.clone()); - - let b_shift = rng.gen_range(0..=256); - let b_digits = iter::once(b_shift as u32) - .chain(iter::repeat(0u32)) - .take(num_digits) - .collect::>(); - let b = BigUint::new(b_digits); - - let a_var = builder.eval_biguint(a.clone()); - let b_var = builder.eval_biguint(b.clone()); - - let ones = iter::repeat(0) - .take((256 - b_shift) / 32) - .chain(iter::once(u32::MAX << (32 - (b_shift % 32)))) - .chain(iter::repeat(u32::MAX)) - .take(num_digits) - .collect::>(); - - let c = (a.clone() >> b_shift) - + if a_sign { - BigUint::new(ones) - } else { - BigUint::zero() - }; - - let c_var = builder.sra_256(&a_var, &b_var); - let c_check_var = builder.eval_biguint(c); - builder.assert_var_array_eq(&c_var, &c_check_var); - } - builder.halt(); - - let program = builder.clone().compile_isa(); - execute_program_with_config( - VmConfig { - num_public_values: 4, - max_segment_len: (1 << 25) - 100, - ..Default::default() - } - .add_executor(ExecutorName::LoadStore) - .add_executor(ExecutorName::BranchEqual) - .add_executor(ExecutorName::Jal) - .add_executor(ExecutorName::Shift256) - .add_executor(ExecutorName::FieldArithmetic), - program, - vec![], - ); -} diff --git a/vm/src/arch/chip_set.rs b/vm/src/arch/chip_set.rs index 8089265cd4..3afb195423 100644 --- a/vm/src/arch/chip_set.rs +++ b/vm/src/arch/chip_set.rs @@ -7,6 +7,7 @@ use std::{ sync::Arc, }; +use adapters::Rv32HeapAdapterChip; use ax_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, @@ -41,6 +42,9 @@ use crate::{ sw::{EcAddNeChip, EcDoubleChip}, }, hashes::{keccak::hasher::KeccakVmChip, poseidon2::Poseidon2Chip}, + int256::{ + Rv32BaseAlu256Chip, Rv32LessThan256Chip, Rv32Multiplication256Chip, Rv32Shift256Chip, + }, modular::{ ModularAddSubChip, ModularAddSubCoreChip, ModularMulDivChip, ModularMulDivCoreChip, }, @@ -62,9 +66,6 @@ use crate::{ modular::{KernelModularAddSubChip, KernelModularMulDivChip}, public_values::{core::PublicValuesCoreChip, PublicValuesChip}, }, - old::{ - alu::ArithmeticLogicChip, shift::ShiftChip, uint_multiplication::UintMultiplicationChip, - }, rv32im::{ adapters::{ Rv32BaseAluAdapterChip, Rv32BranchAdapterChip, Rv32CondRdWriteAdapterChip, @@ -256,7 +257,7 @@ impl VmConfig { let mut required_executors: BTreeSet<_> = self.executors.clone().into_iter().collect(); let mut chips = vec![]; - let mul_u256_enabled = required_executors.contains(&ExecutorName::U256Multiplication); + let mul_u256_enabled = required_executors.contains(&ExecutorName::Multiplication256Rv32); let range_tuple_bus = RangeTupleCheckerBus::new( RANGE_TUPLE_CHECKER_BUS, [(1 << 8), if mul_u256_enabled { 32 } else { 8 } * (1 << 8)], @@ -404,7 +405,7 @@ impl VmConfig { } chips.push(AxVmChip::Keccak256(chip)); } - ExecutorName::ArithmeticLogicUnitRv32 => { + ExecutorName::BaseAluRv32 => { let chip = Rc::new(RefCell::new(Rv32BaseAluChip::new( Rv32BaseAluAdapterChip::new( execution_bus, @@ -417,22 +418,7 @@ impl VmConfig { for opcode in range { executors.insert(opcode, chip.clone().into()); } - chips.push(AxVmChip::ArithmeticLogicUnitRv32(chip)); - } - ExecutorName::ArithmeticLogicUnit256 => { - // We probably must include this chip if we include any modular arithmetic, - // not sure if we need to enforce this here. - let chip = Rc::new(RefCell::new(ArithmeticLogicChip::new( - execution_bus, - program_bus, - memory_controller.clone(), - bitwise_lookup_chip.clone(), - offset, - ))); - for opcode in range { - executors.insert(opcode, chip.clone().into()); - } - chips.push(AxVmChip::ArithmeticLogicUnit256(chip)); + chips.push(AxVmChip::BaseAluRv32(chip)); } ExecutorName::LessThanRv32 => { let chip = Rc::new(RefCell::new(Rv32LessThanChip::new( @@ -483,19 +469,6 @@ impl VmConfig { } chips.push(AxVmChip::MultiplicationHighRv32(chip)); } - ExecutorName::U256Multiplication => { - let chip = Rc::new(RefCell::new(UintMultiplicationChip::new( - execution_bus, - program_bus, - memory_controller.clone(), - range_tuple_checker.clone(), - offset, - ))); - for opcode in range { - executors.insert(opcode, chip.clone().into()); - } - chips.push(AxVmChip::U256Multiplication(chip)); - } ExecutorName::DivRemRv32 => { let chip = Rc::new(RefCell::new(Rv32DivRemChip::new( Rv32MultAdapterChip::new( @@ -534,19 +507,6 @@ impl VmConfig { } chips.push(AxVmChip::ShiftRv32(chip)); } - ExecutorName::Shift256 => { - let chip = Rc::new(RefCell::new(ShiftChip::new( - execution_bus, - program_bus, - memory_controller.clone(), - bitwise_lookup_chip.clone(), - offset, - ))); - for opcode in range { - executors.insert(opcode, chip.clone().into()); - } - chips.push(AxVmChip::Shift256(chip)); - } ExecutorName::LoadStoreRv32 => { let chip = Rc::new(RefCell::new(Rv32LoadStoreChip::new( Rv32LoadStoreAdapterChip::new( @@ -676,6 +636,70 @@ impl VmConfig { } chips.push(AxVmChip::AuipcRv32(chip)); } + ExecutorName::BaseAlu256Rv32 => { + let chip = Rc::new(RefCell::new(Rv32BaseAlu256Chip::new( + Rv32HeapAdapterChip::new( + execution_bus, + program_bus, + memory_controller.clone(), + ), + BaseAluCoreChip::new(bitwise_lookup_chip.clone(), offset), + memory_controller.clone(), + ))); + for opcode in range { + executors.insert(opcode, chip.clone().into()); + } + chips.push(AxVmChip::BaseAlu256Rv32(chip)); + } + ExecutorName::LessThan256Rv32 => { + let chip = Rc::new(RefCell::new(Rv32LessThan256Chip::new( + Rv32HeapAdapterChip::new( + execution_bus, + program_bus, + memory_controller.clone(), + ), + LessThanCoreChip::new(bitwise_lookup_chip.clone(), offset), + memory_controller.clone(), + ))); + for opcode in range { + executors.insert(opcode, chip.clone().into()); + } + chips.push(AxVmChip::LessThan256Rv32(chip)); + } + ExecutorName::Multiplication256Rv32 => { + let chip = Rc::new(RefCell::new(Rv32Multiplication256Chip::new( + Rv32HeapAdapterChip::new( + execution_bus, + program_bus, + memory_controller.clone(), + ), + MultiplicationCoreChip::new(range_tuple_checker.clone(), offset), + memory_controller.clone(), + ))); + for opcode in range { + executors.insert(opcode, chip.clone().into()); + } + chips.push(AxVmChip::Multiplication256Rv32(chip)); + } + ExecutorName::Shift256Rv32 => { + let chip = Rc::new(RefCell::new(Rv32Shift256Chip::new( + Rv32HeapAdapterChip::new( + execution_bus, + program_bus, + memory_controller.clone(), + ), + ShiftCoreChip::new( + bitwise_lookup_chip.clone(), + range_checker.clone(), + offset, + ), + memory_controller.clone(), + ))); + for opcode in range { + executors.insert(opcode, chip.clone().into()); + } + chips.push(AxVmChip::Shift256Rv32(chip)); + } ExecutorName::CastF => { let chip = Rc::new(RefCell::new(CastFChip::new( ConvertAdapterChip::new( @@ -1224,7 +1248,7 @@ fn default_executor_range(executor: ExecutorName) -> (Range, usize) { Keccak256Opcode::COUNT, Keccak256Opcode::default_offset(), ), - ExecutorName::ArithmeticLogicUnitRv32 => ( + ExecutorName::BaseAluRv32 => ( BaseAluOpcode::default_offset(), BaseAluOpcode::COUNT, BaseAluOpcode::default_offset(), @@ -1261,16 +1285,21 @@ fn default_executor_range(executor: ExecutorName) -> (Range, usize) { Rv32AuipcOpcode::COUNT, Rv32AuipcOpcode::default_offset(), ), - ExecutorName::ArithmeticLogicUnit256 => ( - U256Opcode::default_offset(), - 8, - U256Opcode::default_offset(), + ExecutorName::BaseAlu256Rv32 => ( + Rv32BaseAlu256Opcode::default_offset(), + BaseAluOpcode::COUNT, + Rv32BaseAlu256Opcode::default_offset(), ), ExecutorName::LessThanRv32 => ( LessThanOpcode::default_offset(), LessThanOpcode::COUNT, LessThanOpcode::default_offset(), ), + ExecutorName::LessThan256Rv32 => ( + Rv32LessThan256Opcode::default_offset(), + LessThanOpcode::COUNT, + Rv32LessThan256Opcode::default_offset(), + ), ExecutorName::MultiplicationRv32 => ( MulOpcode::default_offset(), MulOpcode::COUNT, @@ -1281,10 +1310,10 @@ fn default_executor_range(executor: ExecutorName) -> (Range, usize) { MulHOpcode::COUNT, MulHOpcode::default_offset(), ), - ExecutorName::U256Multiplication => ( - U256Opcode::default_offset() + 11, - 1, - U256Opcode::default_offset(), + ExecutorName::Multiplication256Rv32 => ( + Rv32Mul256Opcode::default_offset(), + MulOpcode::COUNT, + Rv32Mul256Opcode::default_offset(), ), ExecutorName::DivRemRv32 => ( DivRemOpcode::default_offset(), @@ -1296,10 +1325,10 @@ fn default_executor_range(executor: ExecutorName) -> (Range, usize) { ShiftOpcode::COUNT, ShiftOpcode::default_offset(), ), - ExecutorName::Shift256 => ( - U256Opcode::default_offset() + 8, - 3, - U256Opcode::default_offset(), + ExecutorName::Shift256Rv32 => ( + Rv32Shift256Opcode::default_offset(), + ShiftOpcode::COUNT, + Rv32Shift256Opcode::default_offset(), ), ExecutorName::BranchEqualRv32 => ( BranchEqualOpcode::default_offset(), diff --git a/vm/src/arch/chips.rs b/vm/src/arch/chips.rs index f25b16b02e..edbfb82aba 100644 --- a/vm/src/arch/chips.rs +++ b/vm/src/arch/chips.rs @@ -20,6 +20,9 @@ use crate::{ sw::{EcAddNeChip, EcDoubleChip}, }, hashes::{keccak::hasher::KeccakVmChip, poseidon2::Poseidon2Chip}, + int256::{ + Rv32BaseAlu256Chip, Rv32LessThan256Chip, Rv32Multiplication256Chip, Rv32Shift256Chip, + }, modular::{ModularAddSubChip, ModularMulDivChip}, }, kernels::{ @@ -32,9 +35,6 @@ use crate::{ modular::{KernelModularAddSubChip, KernelModularMulDivChip}, public_values::PublicValuesChip, }, - old::{ - alu::ArithmeticLogicChip, shift::ShiftChip, uint_multiplication::UintMultiplicationChip, - }, rv32im::*, system::{phantom::PhantomChip, program::ExecutionError}, }; @@ -84,15 +84,13 @@ pub enum AxVmInstructionExecutor { PublicValues(Rc>>), Poseidon2(Rc>>), Keccak256(Rc>>), - ArithmeticLogicUnitRv32(Rc>>), - ArithmeticLogicUnit256(Rc>>), + /// Rv32 (for standard 32-bit integers): + BaseAluRv32(Rc>>), LessThanRv32(Rc>>), MultiplicationRv32(Rc>>), MultiplicationHighRv32(Rc>>), - U256Multiplication(Rc>>), DivRemRv32(Rc>>), ShiftRv32(Rc>>), - Shift256(Rc>>), LoadStoreRv32(Rc>>), LoadSignExtendRv32(Rc>>), HintStoreRv32(Rc>>), @@ -101,6 +99,11 @@ pub enum AxVmInstructionExecutor { JalLuiRv32(Rc>>), JalrRv32(Rc>>), AuipcRv32(Rc>>), + /// 256Rv32 (for 256-bit integers): + BaseAlu256Rv32(Rc>>), + LessThan256Rv32(Rc>>), + Multiplication256Rv32(Rc>>), + Shift256Rv32(Rc>>), // Intrinsics: ModularAddSubRv32_1x32(Rc>>), ModularMulDivRv32_1x32(Rc>>), @@ -135,15 +138,16 @@ pub enum AxVmChip { RangeTupleChecker(Arc>), Keccak256(Rc>>), BitwiseOperationLookup(Arc>), - ArithmeticLogicUnitRv32(Rc>>), - ArithmeticLogicUnit256(Rc>>), + BaseAluRv32(Rc>>), + BaseAlu256Rv32(Rc>>), LessThanRv32(Rc>>), + LessThan256Rv32(Rc>>), MultiplicationRv32(Rc>>), MultiplicationHighRv32(Rc>>), - U256Multiplication(Rc>>), + Multiplication256Rv32(Rc>>), DivRemRv32(Rc>>), ShiftRv32(Rc>>), - Shift256(Rc>>), + Shift256Rv32(Rc>>), LoadStoreRv32(Rc>>), LoadSignExtendRv32(Rc>>), HintStoreRv32(Rc>>), diff --git a/vm/src/arch/config.rs b/vm/src/arch/config.rs index d03396a905..179553520f 100644 --- a/vm/src/arch/config.rs +++ b/vm/src/arch/config.rs @@ -92,6 +92,16 @@ impl VmConfig { self } + pub fn add_int256_alu(self) -> Self { + self.add_executor(ExecutorName::BaseAlu256Rv32) + .add_executor(ExecutorName::LessThan256Rv32) + .add_executor(ExecutorName::Shift256Rv32) + } + + pub fn add_int256_m(self) -> Self { + self.add_executor(ExecutorName::Multiplication256Rv32) + } + pub fn add_modular_support(self, enabled_modulus: Vec) -> Self { let mut res = self; res.supported_modulus.extend(enabled_modulus); @@ -149,7 +159,7 @@ impl VmConfig { ..Default::default() } .add_executor(ExecutorName::Phantom) - .add_executor(ExecutorName::ArithmeticLogicUnitRv32) + .add_executor(ExecutorName::BaseAluRv32) .add_executor(ExecutorName::LessThanRv32) .add_executor(ExecutorName::ShiftRv32) .add_executor(ExecutorName::LoadStoreRv32) diff --git a/vm/src/arch/integration_api.rs b/vm/src/arch/integration_api.rs index ab9bc6cafa..f777868a02 100644 --- a/vm/src/arch/integration_api.rs +++ b/vm/src/arch/integration_api.rs @@ -399,25 +399,32 @@ impl< pub struct VecHeapAdapterInterface< T, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, >(PhantomData); impl< T, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > VmAdapterInterface - for VecHeapAdapterInterface + for VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > { - type Reads = [[[T; READ_SIZE]; NUM_READS]; R]; - type Writes = [[T; WRITE_SIZE]; NUM_WRITES]; + type Reads = [[[T; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS]; + type Writes = [[T; WRITE_SIZE]; BLOCKS_PER_WRITE]; type ProcessedInstruction = MinimalInstruction; } @@ -462,23 +469,37 @@ mod conversions { // AdapterAirContext: VecHeapAdapterInterface -> DynInterface impl< T, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > From< AdapterAirContext< T, - VecHeapAdapterInterface, + VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, >, > for AdapterAirContext> { fn from( ctx: AdapterAirContext< T, - VecHeapAdapterInterface, + VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, >, ) -> Self { AdapterAirContext { @@ -493,23 +514,37 @@ mod conversions { // AdapterRuntimeContext: VecHeapAdapterInterface -> DynInterface impl< T, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > From< AdapterRuntimeContext< T, - VecHeapAdapterInterface, + VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, >, > for AdapterRuntimeContext> { fn from( ctx: AdapterRuntimeContext< T, - VecHeapAdapterInterface, + VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, >, ) -> Self { AdapterRuntimeContext { @@ -522,15 +557,22 @@ mod conversions { // AdapterAirContext: DynInterface -> VecHeapAdapterInterface impl< T, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > From>> for AdapterAirContext< T, - VecHeapAdapterInterface, + VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, > { fn from(ctx: AdapterAirContext>) -> Self { @@ -546,15 +588,22 @@ mod conversions { // AdapterRuntimeContext: DynInterface -> VecHeapAdapterInterface impl< T, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > From>> for AdapterRuntimeContext< T, - VecHeapAdapterInterface, + VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, > { fn from(ctx: AdapterRuntimeContext>) -> Self { @@ -565,6 +614,131 @@ mod conversions { } } + // AdapterRuntimeContext: BasicInterface -> VecHeapAdapterInterface + impl< + T, + PI, + const BASIC_NUM_READS: usize, + const BASIC_NUM_WRITES: usize, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > + From< + AdapterRuntimeContext< + T, + BasicAdapterInterface< + T, + PI, + BASIC_NUM_READS, + BASIC_NUM_WRITES, + READ_SIZE, + WRITE_SIZE, + >, + >, + > + for AdapterRuntimeContext< + T, + VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, + > + { + fn from( + ctx: AdapterRuntimeContext< + T, + BasicAdapterInterface< + T, + PI, + BASIC_NUM_READS, + BASIC_NUM_WRITES, + READ_SIZE, + WRITE_SIZE, + >, + >, + ) -> Self { + assert_eq!(BASIC_NUM_WRITES, BLOCKS_PER_WRITE); + let mut writes_it = ctx.writes.into_iter(); + let writes = from_fn(|_| writes_it.next().unwrap()); + AdapterRuntimeContext { + to_pc: ctx.to_pc, + writes, + } + } + } + + // AdapterAirContext: BasicInterface -> VecHeapAdapterInterface + impl< + T, + PI: Into>, + const BASIC_NUM_READS: usize, + const BASIC_NUM_WRITES: usize, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > + From< + AdapterAirContext< + T, + BasicAdapterInterface< + T, + PI, + BASIC_NUM_READS, + BASIC_NUM_WRITES, + READ_SIZE, + WRITE_SIZE, + >, + >, + > + for AdapterAirContext< + T, + VecHeapAdapterInterface< + T, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, + > + { + fn from( + ctx: AdapterAirContext< + T, + BasicAdapterInterface< + T, + PI, + BASIC_NUM_READS, + BASIC_NUM_WRITES, + READ_SIZE, + WRITE_SIZE, + >, + >, + ) -> Self { + assert_eq!(BASIC_NUM_READS, NUM_READS * BLOCKS_PER_READ); + let mut reads_it = ctx.reads.into_iter(); + let reads = from_fn(|_| from_fn(|_| reads_it.next().unwrap())); + assert_eq!(BASIC_NUM_WRITES, BLOCKS_PER_WRITE); + let mut writes_it = ctx.writes.into_iter(); + let writes = from_fn(|_| writes_it.next().unwrap()); + AdapterAirContext { + to_pc: ctx.to_pc, + reads, + writes, + instruction: ctx.instruction.into(), + } + } + } + // AdapterAirContext: FlatInterface -> BasicInterface impl< T, diff --git a/vm/src/intrinsics/int256/mod.rs b/vm/src/intrinsics/int256/mod.rs new file mode 100644 index 0000000000..31b129b5b1 --- /dev/null +++ b/vm/src/intrinsics/int256/mod.rs @@ -0,0 +1,34 @@ +use crate::{ + arch::VmChipWrapper, + rv32im::{ + adapters::{Rv32HeapAdapterChip, INT256_NUM_LIMBS, RV32_CELL_BITS}, + BaseAluCoreChip, LessThanCoreChip, MultiplicationCoreChip, ShiftCoreChip, + }, +}; + +#[cfg(test)] +mod tests; + +pub type Rv32BaseAlu256Chip = VmChipWrapper< + F, + Rv32HeapAdapterChip, + BaseAluCoreChip, +>; + +pub type Rv32LessThan256Chip = VmChipWrapper< + F, + Rv32HeapAdapterChip, + LessThanCoreChip, +>; + +pub type Rv32Multiplication256Chip = VmChipWrapper< + F, + Rv32HeapAdapterChip, + MultiplicationCoreChip, +>; + +pub type Rv32Shift256Chip = VmChipWrapper< + F, + Rv32HeapAdapterChip, + ShiftCoreChip, +>; diff --git a/vm/src/intrinsics/int256/tests.rs b/vm/src/intrinsics/int256/tests.rs new file mode 100644 index 0000000000..164b0f020e --- /dev/null +++ b/vm/src/intrinsics/int256/tests.rs @@ -0,0 +1,202 @@ +use std::sync::Arc; + +use ax_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}, +}; +use ax_stark_sdk::utils::create_seeded_rng; +use axvm_instructions::{ + riscv::RV32_CELL_BITS, BaseAluOpcode, LessThanOpcode, MulOpcode, ShiftOpcode, +}; +use p3_baby_bear::BabyBear; +use p3_field::AbstractField; + +use super::{Rv32BaseAlu256Chip, Rv32LessThan256Chip, Rv32Multiplication256Chip, Rv32Shift256Chip}; +use crate::{ + arch::{ + testing::VmChipTestBuilder, InstructionExecutor, BITWISE_OP_LOOKUP_BUS, + RANGE_TUPLE_CHECKER_BUS, + }, + rv32im::{ + adapters::{Rv32HeapAdapterChip, INT256_NUM_LIMBS}, + BaseAluCoreChip, LessThanCoreChip, MultiplicationCoreChip, ShiftCoreChip, + }, + utils::{generate_long_number, rv32_write_heap_default}, +}; + +type F = BabyBear; + +fn run_int_256_rand_execute>( + opcode: usize, + executor: &mut E, + tester: &mut VmChipTestBuilder, + num_ops: usize, +) { + let mut rng = create_seeded_rng(); + for _ in 0..num_ops { + let b = generate_long_number::(&mut rng); + let c = generate_long_number::(&mut rng); + let instruction = rv32_write_heap_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + opcode, + ); + tester.execute(executor, instruction); + } +} + +fn run_alu_256_rand_test(opcode: BaseAluOpcode, num_ops: usize) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let mut tester = VmChipTestBuilder::default(); + let mut chip = Rv32BaseAlu256Chip::::new( + Rv32HeapAdapterChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_controller(), + ), + BaseAluCoreChip::new(bitwise_chip.clone(), 0), + tester.memory_controller(), + ); + + run_int_256_rand_execute(opcode as usize, &mut chip, &mut tester, num_ops); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn alu_256_add_rand_test() { + run_alu_256_rand_test(BaseAluOpcode::ADD, 24); +} + +#[test] +fn alu_256_sub_rand_test() { + run_alu_256_rand_test(BaseAluOpcode::SUB, 24); +} + +#[test] +fn alu_256_xor_rand_test() { + run_alu_256_rand_test(BaseAluOpcode::XOR, 24); +} + +#[test] +fn alu_256_or_rand_test() { + run_alu_256_rand_test(BaseAluOpcode::OR, 24); +} + +#[test] +fn alu_256_and_rand_test() { + run_alu_256_rand_test(BaseAluOpcode::AND, 24); +} + +fn run_lt_256_rand_test(opcode: LessThanOpcode, num_ops: usize) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let mut tester = VmChipTestBuilder::default(); + let mut chip = Rv32LessThan256Chip::::new( + Rv32HeapAdapterChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_controller(), + ), + LessThanCoreChip::new(bitwise_chip.clone(), 0), + tester.memory_controller(), + ); + + run_int_256_rand_execute(opcode as usize, &mut chip, &mut tester, num_ops); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn lt_256_slt_rand_test() { + run_lt_256_rand_test(LessThanOpcode::SLT, 24); +} + +#[test] +fn lt_256_sltu_rand_test() { + run_lt_256_rand_test(LessThanOpcode::SLTU, 24); +} + +fn run_mul_256_rand_test(num_ops: usize) { + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [ + 1 << RV32_CELL_BITS, + (INT256_NUM_LIMBS * (1 << RV32_CELL_BITS)) as u32, + ], + ); + let range_tuple_checker = Arc::new(RangeTupleCheckerChip::new(range_tuple_bus)); + + let mut tester = VmChipTestBuilder::default(); + let mut chip = Rv32Multiplication256Chip::::new( + Rv32HeapAdapterChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_controller(), + ), + MultiplicationCoreChip::new(range_tuple_checker.clone(), 0), + tester.memory_controller(), + ); + + run_int_256_rand_execute(MulOpcode::MUL as usize, &mut chip, &mut tester, num_ops); + let tester = tester + .build() + .load(chip) + .load(range_tuple_checker) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn mul_256_rand_test() { + run_mul_256_rand_test(24); +} + +fn run_shift_256_rand_test(opcode: ShiftOpcode, num_ops: usize) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let mut tester = VmChipTestBuilder::default(); + let mut chip = Rv32Shift256Chip::::new( + Rv32HeapAdapterChip::::new( + tester.execution_bus(), + tester.program_bus(), + tester.memory_controller(), + ), + ShiftCoreChip::new( + bitwise_chip.clone(), + tester.memory_controller().borrow().range_checker.clone(), + 0, + ), + tester.memory_controller(), + ); + + run_int_256_rand_execute(opcode as usize, &mut chip, &mut tester, num_ops); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn shift_256_sll_rand_test() { + run_shift_256_rand_test(ShiftOpcode::SLL, 24); +} + +#[test] +fn shift_256_srl_rand_test() { + run_shift_256_rand_test(ShiftOpcode::SRL, 24); +} + +#[test] +fn shift_256_sra_rand_test() { + run_shift_256_rand_test(ShiftOpcode::SRA, 24); +} diff --git a/vm/src/intrinsics/mod.rs b/vm/src/intrinsics/mod.rs index e470e445c6..a3d38742d6 100644 --- a/vm/src/intrinsics/mod.rs +++ b/vm/src/intrinsics/mod.rs @@ -1,6 +1,7 @@ pub mod ecc; pub mod field_expression; pub mod hashes; +pub mod int256; pub mod modular; #[cfg(test)] diff --git a/vm/src/lib.rs b/vm/src/lib.rs index c9f40a4c1a..4d05af629b 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -17,6 +17,3 @@ pub mod rv32im; pub mod system; /// Utility functions and test utils pub mod utils; - -// To be deleted: -pub mod old; diff --git a/vm/src/old/alu/air.rs b/vm/src/old/alu/air.rs deleted file mode 100644 index b412d801c9..0000000000 --- a/vm/src/old/alu/air.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::{array, borrow::Borrow}; - -use ax_circuit_primitives::{bitwise_op_lookup::BitwiseOperationLookupBus, utils}; -use ax_stark_backend::{ - interaction::InteractionBuilder, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, -}; -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; -use p3_matrix::Matrix; -use strum::IntoEnumIterator; - -use super::columns::ArithmeticLogicCols; -use crate::{ - arch::{instructions::U256Opcode, ExecutionBridge}, - system::memory::offline_checker::MemoryBridge, -}; - -#[derive(Copy, Clone, Debug)] -pub struct ArithmeticLogicCoreAir { - pub(super) execution_bridge: ExecutionBridge, - pub(super) memory_bridge: MemoryBridge, - pub bus: BitwiseOperationLookupBus, - - pub(super) offset: usize, -} - -impl PartitionedBaseAir - for ArithmeticLogicCoreAir -{ -} -impl BaseAir - for ArithmeticLogicCoreAir -{ - fn width(&self) -> usize { - ArithmeticLogicCols::::width() - } -} - -impl BaseAirWithPublicValues - for ArithmeticLogicCoreAir -{ -} - -impl Air - for ArithmeticLogicCoreAir -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - - let ArithmeticLogicCols::<_, NUM_LIMBS, LIMB_BITS> { io, aux } = (*local).borrow(); - builder.assert_bool(aux.is_valid); - - let flags = [ - aux.opcode_add_flag, - aux.opcode_sub_flag, - aux.opcode_sltu_flag, - aux.opcode_eq_flag, - aux.opcode_xor_flag, - aux.opcode_and_flag, - aux.opcode_or_flag, - aux.opcode_slt_flag, - ]; - for flag in flags { - builder.assert_bool(flag); - } - - builder.assert_eq( - aux.is_valid, - flags - .iter() - .fold(AB::Expr::zero(), |acc, &flag| acc + flag.into()), - ); - - let x_limbs = &io.x.data; - let y_limbs = &io.y.data; - let z_limbs = &io.z.data; - - // For ADD, define carry[i] = (x[i] + y[i] + carry[i - 1] - z[i]) / 2^LIMB_BITS. If - // each carry[i] is boolean and 0 <= z[i] < 2^NUM_LIMBS, it can be proven that - // z[i] = (x[i] + y[i]) % 256 as necessary. The same holds for SUB when carry[i] is - // (z[i] + y[i] - x[i] + carry[i - 1]) / 2^LIMB_BITS. - let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::zero()); - let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::zero()); - let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse(); - - for i in 0..NUM_LIMBS { - // We explicitly separate the constraints for ADD and SUB in order to keep degree - // cubic. Because we constrain that the carry (which is arbitrary) is bool, if - // carry has degree larger than 1 the max-degree constrain could be at least 4. - carry_add[i] = AB::Expr::from(carry_divide) - * (x_limbs[i] + y_limbs[i] - z_limbs[i] - + if i > 0 { - carry_add[i - 1].clone() - } else { - AB::Expr::zero() - }); - builder - .when(aux.opcode_add_flag) - .assert_bool(carry_add[i].clone()); - carry_sub[i] = AB::Expr::from(carry_divide) - * (z_limbs[i] + y_limbs[i] - x_limbs[i] - + if i > 0 { - carry_sub[i - 1].clone() - } else { - AB::Expr::zero() - }); - builder - .when(aux.opcode_sub_flag + aux.opcode_sltu_flag + aux.opcode_slt_flag) - .assert_bool(carry_sub[i].clone()); - } - - // For LT, cmp_result must be equal to the last carry. For SLT, cmp_result ^ x_sign ^ y_sign must - // be equal to the last carry. To ensure maximum cubic degree constraints, we set aux.x_sign and - // aux.y_sign are 0 when not computing an SLT. - builder.assert_bool(aux.x_sign); - builder.assert_bool(aux.y_sign); - builder - .when(utils::not(aux.opcode_slt_flag)) - .assert_zero(aux.x_sign); - builder - .when(utils::not(aux.opcode_slt_flag)) - .assert_zero(aux.y_sign); - - let slt_xor = - (aux.opcode_sltu_flag + aux.opcode_slt_flag) * io.cmp_result + aux.x_sign + aux.y_sign - - AB::Expr::from_canonical_u32(2) - * (io.cmp_result * aux.x_sign - + io.cmp_result * aux.y_sign - + aux.x_sign * aux.y_sign) - + AB::Expr::from_canonical_u32(4) * (io.cmp_result * aux.x_sign * aux.y_sign); - builder.assert_eq( - slt_xor, - (aux.opcode_sltu_flag + aux.opcode_slt_flag) * carry_sub[NUM_LIMBS - 1].clone(), - ); - - // For EQ, z is filled with 0 except at the lowest index i such that x[i] != y[i]. If - // such an i exists z[i] is the inverse of x[i] - y[i], meaning sum_eq should be 1. - let mut sum_eq: AB::Expr = io.cmp_result.into(); - for i in 0..NUM_LIMBS { - sum_eq += (x_limbs[i] - y_limbs[i]) * z_limbs[i]; - builder - .when(aux.opcode_eq_flag) - .assert_zero(io.cmp_result * (x_limbs[i] - y_limbs[i])); - } - builder.when(aux.opcode_eq_flag).assert_one(sum_eq); - - let expected_opcode = flags - .iter() - .zip(U256Opcode::iter()) - .fold(AB::Expr::zero(), |acc, (flag, opcode)| { - acc + (*flag).into() * AB::Expr::from_canonical_u8(opcode as u8) - }); - - self.eval_interactions(builder, io, aux, expected_opcode); - } -} diff --git a/vm/src/old/alu/bridge.rs b/vm/src/old/alu/bridge.rs deleted file mode 100644 index f8d3043b32..0000000000 --- a/vm/src/old/alu/bridge.rs +++ /dev/null @@ -1,149 +0,0 @@ -use ax_stark_backend::interaction::InteractionBuilder; -use itertools::izip; -use p3_field::AbstractField; - -use super::{ - air::ArithmeticLogicCoreAir, - columns::{ArithmeticLogicAuxCols, ArithmeticLogicIoCols}, -}; -use crate::system::memory::MemoryAddress; - -impl ArithmeticLogicCoreAir { - pub fn eval_interactions( - &self, - builder: &mut AB, - io: &ArithmeticLogicIoCols, - aux: &ArithmeticLogicAuxCols, - expected_opcode: AB::Expr, - ) { - let timestamp: AB::Var = io.from_state.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::F::from_canonical_usize(timestamp_delta - 1) - }; - - let range_check = - aux.opcode_add_flag + aux.opcode_sub_flag + aux.opcode_sltu_flag + aux.opcode_slt_flag; - let bitwise = aux.opcode_xor_flag + aux.opcode_and_flag + aux.opcode_or_flag; - - // Read the operand pointer's values, which are themselves pointers - // for the actual IO data. - for (ptr, value, mem_aux) in izip!( - [ - io.z.ptr_to_address, - io.x.ptr_to_address, - io.y.ptr_to_address - ], - [io.z.address, io.x.address, io.y.address], - &aux.read_ptr_aux_cols - ) { - self.memory_bridge - .read( - MemoryAddress::new(io.ptr_as, ptr), - [value], - timestamp_pp(), - mem_aux, - ) - .eval(builder, aux.is_valid); - } - - self.memory_bridge - .read( - MemoryAddress::new(io.address_as, io.x.address), - io.x.data, - timestamp_pp(), - &aux.read_x_aux_cols, - ) - .eval(builder, aux.is_valid); - - self.memory_bridge - .read( - MemoryAddress::new(io.address_as, io.y.address), - io.y.data, - timestamp_pp(), - &aux.read_y_aux_cols, - ) - .eval(builder, aux.is_valid); - - // Special handling for writing output z data: - self.memory_bridge - .write( - MemoryAddress::new(io.address_as, io.z.address), - io.z.data, - timestamp + AB::F::from_canonical_usize(timestamp_delta), - &aux.write_z_aux_cols, - ) - .eval( - builder, - aux.opcode_add_flag + aux.opcode_sub_flag + bitwise.clone(), - ); - - // Special handling for writing output cmp data: - self.memory_bridge - .write( - MemoryAddress::new(io.address_as, io.z.address), - [io.cmp_result], - timestamp + AB::F::from_canonical_usize(timestamp_delta), - &aux.write_cmp_aux_cols, - ) - .eval( - builder, - aux.opcode_sltu_flag + aux.opcode_eq_flag + aux.opcode_slt_flag, - ); - timestamp_delta += 1; - - self.execution_bridge - .execute_and_increment_pc( - expected_opcode + AB::Expr::from_canonical_usize(self.offset), - [ - io.z.ptr_to_address, - io.x.ptr_to_address, - io.y.ptr_to_address, - io.ptr_as, - io.address_as, - ], - io.from_state, - AB::F::from_canonical_usize(timestamp_delta), - ) - .eval(builder, aux.is_valid); - - // Check x[NUM_LIMBS - 1] & (1 << LIMB_BITS - 1) == x_sign * (1 << LIMB_BITS - 1) using XOR - let mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1)); - let x_sign_shifted = aux.x_sign * mask; - let y_sign_shifted = aux.y_sign * mask; - self.bus - .send_xor( - io.x.data[NUM_LIMBS - 1], - mask, - io.x.data[NUM_LIMBS - 1] + mask - - (AB::Expr::from_canonical_u32(2) * x_sign_shifted), - ) - .eval(builder, aux.opcode_slt_flag); - self.bus - .send_xor( - io.y.data[NUM_LIMBS - 1], - mask, - io.y.data[NUM_LIMBS - 1] + mask - - (AB::Expr::from_canonical_u32(2) * y_sign_shifted), - ) - .eval(builder, aux.opcode_slt_flag); - - // Chip-specific interactions - for i in 0..NUM_LIMBS { - let x = range_check.clone() * io.z.data[i] + bitwise.clone() * io.x.data[i]; - let y = range_check.clone() * io.z.data[i] + bitwise.clone() * io.y.data[i]; - let xor_res = aux.opcode_xor_flag * io.z.data[i] - + aux.opcode_and_flag - * (io.x.data[i] + io.y.data[i] - - (AB::Expr::from_canonical_u32(2) * io.z.data[i])) - + aux.opcode_or_flag - * ((AB::Expr::from_canonical_u32(2) * io.z.data[i]) - - io.x.data[i] - - io.y.data[i]); - self.bus - .send_xor(x, y, xor_res) - .eval(builder, range_check.clone() + bitwise.clone()); - } - } -} diff --git a/vm/src/old/alu/columns.rs b/vm/src/old/alu/columns.rs deleted file mode 100644 index 7c49cb16d6..0000000000 --- a/vm/src/old/alu/columns.rs +++ /dev/null @@ -1,52 +0,0 @@ -use ax_circuit_derive::AlignedBorrow; - -use crate::{ - arch::ExecutionState, - old::uint_multiplication::MemoryData, - system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct ArithmeticLogicCols { - pub io: ArithmeticLogicIoCols, - pub aux: ArithmeticLogicAuxCols, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct ArithmeticLogicIoCols { - pub from_state: ExecutionState, - pub x: MemoryData, - pub y: MemoryData, - pub z: MemoryData, - pub cmp_result: T, - pub ptr_as: T, - pub address_as: T, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct ArithmeticLogicAuxCols { - pub is_valid: T, - pub x_sign: T, - pub y_sign: T, - - // Opcode flags for different operations - pub opcode_add_flag: T, - pub opcode_sub_flag: T, - pub opcode_sltu_flag: T, - pub opcode_eq_flag: T, - pub opcode_xor_flag: T, - pub opcode_and_flag: T, - pub opcode_or_flag: T, - pub opcode_slt_flag: T, - - /// Pointer read auxiliary columns for [z, x, y]. - /// **Note** the ordering, which is designed to match the instruction order. - pub read_ptr_aux_cols: [MemoryReadAuxCols; 3], - pub read_x_aux_cols: MemoryReadAuxCols, - pub read_y_aux_cols: MemoryReadAuxCols, - pub write_z_aux_cols: MemoryWriteAuxCols, - pub write_cmp_aux_cols: MemoryWriteAuxCols, -} diff --git a/vm/src/old/alu/mod.rs b/vm/src/old/alu/mod.rs deleted file mode 100644 index 7ebcea03fe..0000000000 --- a/vm/src/old/alu/mod.rs +++ /dev/null @@ -1,301 +0,0 @@ -use std::sync::Arc; - -use air::ArithmeticLogicCoreAir; -use ax_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupChip; -use axvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; -use p3_field::PrimeField32; - -use crate::{ - arch::{ - instructions::{U256Opcode, UsizeOpcode}, - ExecutionBridge, ExecutionBus, ExecutionState, InstructionExecutor, - }, - system::{ - memory::{MemoryControllerRef, MemoryReadRecord, MemoryWriteRecord}, - program::{ExecutionError, ProgramBus}, - }, -}; - -mod air; -mod bridge; -mod columns; -mod trace; - -// pub use air::*; -pub use columns::*; - -#[cfg(test)] -mod tests; - -pub const ALU_CMP_INSTRUCTIONS: [U256Opcode; 3] = [U256Opcode::LT, U256Opcode::EQ, U256Opcode::SLT]; -pub const ALU_ARITHMETIC_INSTRUCTIONS: [U256Opcode; 2] = [U256Opcode::ADD, U256Opcode::SUB]; -pub const ALU_BITWISE_INSTRUCTIONS: [U256Opcode; 3] = - [U256Opcode::XOR, U256Opcode::AND, U256Opcode::OR]; - -#[derive(Clone, Debug)] -pub enum WriteRecord { - Long(MemoryWriteRecord), - Bool(MemoryWriteRecord), -} - -#[derive(Clone, Debug)] -pub struct ArithmeticLogicRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - - pub x_ptr_read: MemoryReadRecord, - pub y_ptr_read: MemoryReadRecord, - pub z_ptr_read: MemoryReadRecord, - - pub x_read: MemoryReadRecord, - pub y_read: MemoryReadRecord, - pub z_write: WriteRecord, - - // sign of x and y if SLT, else should be 0 - pub x_sign: T, - pub y_sign: T, - - // empty if not bool instruction, else contents of this vector will be stored in z - pub cmp_buffer: Vec, -} - -#[derive(Debug)] -pub struct ArithmeticLogicChip { - pub air: ArithmeticLogicCoreAir, - data: Vec>, - memory_controller: MemoryControllerRef, - pub bitwise_lookup_chip: Arc>, - - offset: usize, -} - -impl - ArithmeticLogicChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_controller: MemoryControllerRef, - bitwise_lookup_chip: Arc>, - offset: usize, - ) -> Self { - let memory_bridge = memory_controller.borrow().memory_bridge(); - Self { - air: ArithmeticLogicCoreAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - offset, - }, - data: vec![], - memory_controller, - bitwise_lookup_chip, - offset, - } - } -} - -impl InstructionExecutor - for ArithmeticLogicChip -{ - fn execute( - &mut self, - instruction: Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let Instruction { - opcode, - a, - b, - c, - d, - e, - .. - } = instruction.clone(); - let local_opcode_index = U256Opcode::from_usize(opcode - self.offset); - - let mut memory_controller = self.memory_controller.borrow_mut(); - debug_assert_eq!(from_state.timestamp, memory_controller.timestamp()); - - let [z_ptr_read, x_ptr_read, y_ptr_read] = - [a, b, c].map(|ptr_of_ptr| memory_controller.read_cell(d, ptr_of_ptr)); - let x_read = memory_controller.read::(e, x_ptr_read.value()); - let y_read = memory_controller.read::(e, y_ptr_read.value()); - - let x = x_read.data.map(|x| x.as_canonical_u32()); - let y = y_read.data.map(|x| x.as_canonical_u32()); - let (z, cmp) = run_alu::(local_opcode_index, &x, &y); - - let z_write = if ALU_CMP_INSTRUCTIONS.contains(&local_opcode_index) { - WriteRecord::Bool(memory_controller.write_cell( - e, - z_ptr_read.value(), - T::from_bool(cmp), - )) - } else { - WriteRecord::Long( - memory_controller.write::( - e, - z_ptr_read.value(), - z.clone() - .into_iter() - .map(T::from_canonical_u32) - .collect::>() - .try_into() - .unwrap(), - ), - ) - }; - - let mut x_sign = 0; - let mut y_sign = 0; - - if local_opcode_index == U256Opcode::SLT { - x_sign = x[NUM_LIMBS - 1] >> (LIMB_BITS - 1); - y_sign = y[NUM_LIMBS - 1] >> (LIMB_BITS - 1); - self.bitwise_lookup_chip - .request_xor(x[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1)); - self.bitwise_lookup_chip - .request_xor(y[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1)); - } - - if ALU_BITWISE_INSTRUCTIONS.contains(&local_opcode_index) { - for i in 0..NUM_LIMBS { - self.bitwise_lookup_chip.request_xor(x[i], y[i]); - } - } else if local_opcode_index != U256Opcode::EQ { - for z_val in &z { - self.bitwise_lookup_chip.request_xor(*z_val, *z_val); - } - } - - self.data - .push(ArithmeticLogicRecord:: { - from_state, - instruction: Instruction { - opcode: local_opcode_index as usize, - ..instruction - }, - x_ptr_read, - y_ptr_read, - z_ptr_read, - x_read, - y_read, - z_write, - x_sign: T::from_canonical_u32(x_sign), - y_sign: T::from_canonical_u32(y_sign), - cmp_buffer: if ALU_CMP_INSTRUCTIONS.contains(&local_opcode_index) { - z.into_iter().map(T::from_canonical_u32).collect() - } else { - vec![] - }, - }); - - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory_controller.timestamp(), - }) - } - - fn get_opcode_name(&self, opcode: usize) -> String { - let local_opcode_index = U256Opcode::from_usize(opcode - self.offset); - format!("{local_opcode_index:?}<{NUM_LIMBS},{LIMB_BITS}>") - } -} - -fn run_alu( - local_opcode_index: U256Opcode, - x: &[u32], - y: &[u32], -) -> (Vec, bool) { - match local_opcode_index { - U256Opcode::ADD => run_add::(x, y), - U256Opcode::SUB | U256Opcode::LT => run_subtract::(x, y), - U256Opcode::EQ => run_eq::(x, y), - U256Opcode::XOR => run_xor::(x, y), - U256Opcode::AND => run_and::(x, y), - U256Opcode::OR => run_or::(x, y), - U256Opcode::SLT => { - let (z, cmp) = run_subtract::(x, y); - ( - z, - cmp ^ (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) != 0) - ^ (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) != 0), - ) - } - _ => unreachable!(), - } -} - -fn run_add( - x: &[u32], - y: &[u32], -) -> (Vec, bool) { - let mut z = vec![0u32; NUM_LIMBS]; - let mut carry = vec![0u32; NUM_LIMBS]; - for i in 0..NUM_LIMBS { - z[i] = x[i] + y[i] + if i > 0 { carry[i - 1] } else { 0 }; - carry[i] = z[i] >> LIMB_BITS; - z[i] &= (1 << LIMB_BITS) - 1; - } - (z, false) -} - -fn run_subtract( - x: &[u32], - y: &[u32], -) -> (Vec, bool) { - let mut z = vec![0u32; NUM_LIMBS]; - let mut carry = vec![0u32; NUM_LIMBS]; - for i in 0..NUM_LIMBS { - let rhs = y[i] + if i > 0 { carry[i - 1] } else { 0 }; - if x[i] >= rhs { - z[i] = x[i] - rhs; - carry[i] = 0; - } else { - z[i] = x[i] + (1 << LIMB_BITS) - rhs; - carry[i] = 1; - } - } - (z, carry[NUM_LIMBS - 1] != 0) -} - -fn run_eq( - x: &[u32], - y: &[u32], -) -> (Vec, bool) { - let mut z = vec![0u32; NUM_LIMBS]; - for i in 0..NUM_LIMBS { - if x[i] != y[i] { - z[i] = (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])) - .inverse() - .as_canonical_u32(); - return (z, false); - } - } - (z, true) -} - -fn run_xor( - x: &[u32], - y: &[u32], -) -> (Vec, bool) { - let z = (0..NUM_LIMBS).map(|i| x[i] ^ y[i]).collect(); - (z, false) -} - -fn run_and( - x: &[u32], - y: &[u32], -) -> (Vec, bool) { - let z = (0..NUM_LIMBS).map(|i| x[i] & y[i]).collect(); - (z, false) -} - -fn run_or( - x: &[u32], - y: &[u32], -) -> (Vec, bool) { - let z = (0..NUM_LIMBS).map(|i| x[i] | y[i]).collect(); - (z, false) -} diff --git a/vm/src/old/alu/tests.rs b/vm/src/old/alu/tests.rs deleted file mode 100644 index 0e41ee628b..0000000000 --- a/vm/src/old/alu/tests.rs +++ /dev/null @@ -1,662 +0,0 @@ -use std::{array, borrow::BorrowMut, iter, sync::Arc}; - -use ax_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, -}; -use ax_stark_backend::{utils::disable_debug_builder, verifier::VerificationError, Chip}; -use ax_stark_sdk::utils::create_seeded_rng; -use axvm_instructions::instruction::Instruction; -use p3_baby_bear::BabyBear; -use p3_field::{AbstractField, PrimeField32}; -use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use rand::{rngs::StdRng, Rng}; - -use super::{ - columns::ArithmeticLogicCols, run_subtract, ArithmeticLogicChip, ALU_CMP_INSTRUCTIONS, -}; -use crate::{ - arch::{ - instructions::U256Opcode, - testing::{memory::gen_pointer, VmChipTestBuilder}, - BITWISE_OP_LOOKUP_BUS, - }, - old::alu::run_alu, -}; - -type F = BabyBear; - -const NUM_LIMBS: usize = 32; -const LIMB_BITS: usize = 8; - -fn generate_long_number( - rng: &mut StdRng, -) -> Vec { - (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..1 << LIMB_BITS)) - .collect() -} - -#[allow(clippy::too_many_arguments)] -fn run_alu_rand_write_execute( - tester: &mut VmChipTestBuilder, - chip: &mut ArithmeticLogicChip, - opcode: U256Opcode, - x: Vec, - y: Vec, - rng: &mut StdRng, -) { - let address_space_range = || 1usize..=2; - - let d = rng.gen_range(address_space_range()); - let e = rng.gen_range(address_space_range()); - - let x_address = gen_pointer(rng, 32); - let y_address = gen_pointer(rng, 32); - let res_address = gen_pointer(rng, 32); - let x_ptr_to_address = gen_pointer(rng, 1); - let y_ptr_to_address = gen_pointer(rng, 1); - let res_ptr_to_address = gen_pointer(rng, 1); - - let x_f = x - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - let y_f = y - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - - tester.write_cell(d, x_ptr_to_address, F::from_canonical_usize(x_address)); - tester.write_cell(d, y_ptr_to_address, F::from_canonical_usize(y_address)); - tester.write_cell(d, res_ptr_to_address, F::from_canonical_usize(res_address)); - tester.write::(e, x_address, x_f.as_slice().try_into().unwrap()); - tester.write::(e, y_address, y_f.as_slice().try_into().unwrap()); - - let (z, cmp) = run_alu::(opcode, &x, &y); - tester.execute( - chip, - Instruction::from_usize( - opcode as usize, - [res_ptr_to_address, x_ptr_to_address, y_ptr_to_address, d, e], - ), - ); - - if ALU_CMP_INSTRUCTIONS.contains(&opcode) { - assert_eq!([F::from_bool(cmp)], tester.read::<1>(e, res_address)) - } else { - assert_eq!( - z.into_iter().map(F::from_canonical_u32).collect::>(), - tester.read::(e, res_address) - ) - } -} - -/// Given a fake trace of a single operation, setup a chip and run the test. -/// We replace the "output" part of the trace, and we _may_ replace the interactions -/// based on the desired output. We check that it produces the error we expect. -#[allow(clippy::too_many_arguments)] -fn run_alu_negative_test( - opcode: U256Opcode, - x: Vec, - y: Vec, - z: Vec, - cmp_result: bool, - x_sign: u32, - y_sign: u32, - expected_error: VerificationError, -) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - let mut rng = create_seeded_rng(); - run_alu_rand_write_execute( - &mut tester, - &mut chip, - opcode, - x.clone(), - y.clone(), - &mut rng, - ); - - let mut chip_input = chip.generate_air_proof_input(); - let alu_trace = chip_input.raw.common_main.as_mut().unwrap(); - let mut alu_trace_row = alu_trace.row_slice(0).to_vec(); - let alu_trace_cols: &mut ArithmeticLogicCols = (*alu_trace_row).borrow_mut(); - - alu_trace_cols.io.z.data = array::from_fn(|i| F::from_canonical_u32(z[i])); - alu_trace_cols.io.cmp_result = F::from_bool(cmp_result); - alu_trace_cols.aux.x_sign = F::from_canonical_u32(x_sign); - alu_trace_cols.aux.y_sign = F::from_canonical_u32(y_sign); - *alu_trace = RowMajorMatrix::new( - alu_trace_row, - ArithmeticLogicCols::::width(), - ); - - disable_debug_builder(); - let tester = tester - .build() - .load_air_proof_input(chip_input) - .load(bitwise_chip) - .finalize(); - let msg = format!( - "Expected verification to fail with {:?}, but it didn't", - &expected_error - ); - let result = tester.simple_test(); - assert_eq!(result.err(), Some(expected_error), "{}", msg); -} - -#[test] -fn alu_add_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, U256Opcode::ADD, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn alu_add_out_of_range_negative_test() { - run_alu_negative_test( - U256Opcode::ADD, - iter::once(250) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - iter::once(250) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - iter::once(500) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - 0, - 0, - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn alu_add_wrong_negative_test() { - run_alu_negative_test( - U256Opcode::ADD, - iter::once(250) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - iter::once(250) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - iter::once(500 - (1 << 8)) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - 0, - 0, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_sub_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, U256Opcode::SUB, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn alu_sub_out_of_range_negative_test() { - run_alu_negative_test( - U256Opcode::SUB, - iter::once(1) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - iter::once(2) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - iter::once(F::neg_one().as_canonical_u32()) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - 0, - 0, - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn alu_sub_wrong_negative_test() { - run_alu_negative_test( - U256Opcode::SUB, - iter::once(1) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - iter::once(2) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - iter::once((1 << 8) - 1) - .chain(iter::repeat(0).take(NUM_LIMBS - 1)) - .collect(), - false, - 0, - 0, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_sltu_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, U256Opcode::LT, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn alu_sltu_wrong_subtraction_test() { - run_alu_negative_test( - U256Opcode::LT, - iter::once(65_000).chain(iter::repeat(0).take(31)).collect(), - iter::once(65_000).chain(iter::repeat(0).take(31)).collect(), - std::iter::once(1).chain(iter::repeat(0).take(31)).collect(), - false, - 0, - 0, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_sltu_wrong_negative_test() { - run_alu_negative_test( - U256Opcode::LT, - iter::once(1).chain(iter::repeat(0).take(31)).collect(), - iter::once(1).chain(iter::repeat(0).take(31)).collect(), - iter::once(0).chain(iter::repeat(0).take(31)).collect(), - true, - 0, - 0, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_sltu_non_zero_sign_negative_test() { - run_alu_negative_test( - U256Opcode::LT, - vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], - vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], - vec![0; NUM_LIMBS], - false, - 1, - 1, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_eq_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, U256Opcode::EQ, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn alu_eq_wrong_negative_test() { - run_alu_negative_test( - U256Opcode::EQ, - vec![0; 31].into_iter().chain(iter::once(123)).collect(), - vec![0; 31].into_iter().chain(iter::once(456)).collect(), - vec![0; 31].into_iter().chain(iter::once(0)).collect(), - true, - 0, - 0, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_xor_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, U256Opcode::XOR, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn alu_xor_wrong_negative_test() { - run_alu_negative_test( - U256Opcode::XOR, - vec![0; 31].into_iter().chain(iter::once(1)).collect(), - vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], - vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], - true, - 0, - 0, - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn alu_and_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, U256Opcode::AND, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn alu_and_wrong_negative_test() { - run_alu_negative_test( - U256Opcode::AND, - vec![0; 31].into_iter().chain(iter::once(1)).collect(), - vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], - vec![0; NUM_LIMBS], - true, - 0, - 0, - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn alu_or_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, U256Opcode::OR, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn alu_or_wrong_negative_test() { - run_alu_negative_test( - U256Opcode::OR, - vec![0; NUM_LIMBS], - vec![(1 << LIMB_BITS) - 1; NUM_LIMBS - 1] - .into_iter() - .chain(iter::once((1 << LIMB_BITS) - 2)) - .collect(), - vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], - true, - 0, - 0, - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn alu_slt_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ArithmeticLogicChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, U256Opcode::SLT, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn alu_slt_pos_neg_sign_negative_test() { - let x = [0; NUM_LIMBS]; - let y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - run_alu_negative_test( - U256Opcode::SLT, - x.to_vec(), - y.to_vec(), - run_subtract::(&x, &y).0, - true, - 0, - 1, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_slt_neg_pos_sign_negative_test() { - let x = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - let y = [0; NUM_LIMBS]; - run_alu_negative_test( - U256Opcode::SLT, - x.to_vec(), - y.to_vec(), - run_subtract::(&x, &y).0, - false, - 1, - 0, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_slt_both_pos_sign_negative_test() { - let x = [0; NUM_LIMBS]; - let mut y = [0; NUM_LIMBS]; - y[0] = 1; - run_alu_negative_test( - U256Opcode::SLT, - x.to_vec(), - y.to_vec(), - run_subtract::(&x, &y).0, - false, - 0, - 0, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_slt_both_neg_sign_negative_test() { - let x = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - let mut y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - y[0] = 1; - run_alu_negative_test( - U256Opcode::SLT, - x.to_vec(), - y.to_vec(), - run_subtract::(&x, &y).0, - true, - 1, - 1, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_slt_wrong_sign_negative_test() { - let x = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - let mut y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - y[0] = 1; - run_alu_negative_test( - U256Opcode::SLT, - x.to_vec(), - y.to_vec(), - run_subtract::(&x, &y).0, - true, - 0, - 1, - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn alu_slt_non_boolean_sign_negative_test() { - let x = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - let mut y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - y[0] = 1; - run_alu_negative_test( - U256Opcode::SLT, - x.to_vec(), - y.to_vec(), - run_subtract::(&x, &y).0, - false, - 2, - 1, - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn alu_slt_wrong_xor_test() { - let x = [(1 << (LIMB_BITS - 1)) + 1; NUM_LIMBS]; - let y = [(1 << LIMB_BITS) - 1; NUM_LIMBS]; - run_alu_negative_test( - U256Opcode::SLT, - x.to_vec(), - y.to_vec(), - run_subtract::(&x, &y).0, - false, - 0, - 1, - VerificationError::NonZeroCumulativeSum, - ); -} diff --git a/vm/src/old/alu/trace.rs b/vm/src/old/alu/trace.rs deleted file mode 100644 index bfafa6e486..0000000000 --- a/vm/src/old/alu/trace.rs +++ /dev/null @@ -1,133 +0,0 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; - -use ax_stark_backend::{ - config::{StarkGenericConfig, Val}, - prover::types::AirProofInput, - rap::{get_air_name, AnyRap}, - Chip, ChipUsageGetter, -}; -use p3_field::{AbstractField, PrimeField32}; -use p3_matrix::dense::RowMajorMatrix; - -use super::{ - columns::{ArithmeticLogicAuxCols, ArithmeticLogicCols, ArithmeticLogicIoCols}, - ArithmeticLogicChip, ArithmeticLogicRecord, WriteRecord, -}; -use crate::{ - arch::instructions::{U256Opcode, UsizeOpcode}, - old::uint_multiplication::MemoryData, - system::memory::offline_checker::MemoryWriteAuxCols, -}; - -impl Chip - for ArithmeticLogicChip, NUM_LIMBS, LIMB_BITS> -where - Val: PrimeField32, -{ - fn air(&self) -> Arc> { - Arc::new(self.air) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let air = self.air(); - let aux_cols_factory = self.memory_controller.borrow().aux_cols_factory(); - - let width = self.trace_width(); - let height = self.data.len(); - let padded_height = height.next_power_of_two(); - let mut rows = vec![Val::::zero(); width * padded_height]; - - for (row, operation) in rows.chunks_mut(width).zip(self.data) { - let ArithmeticLogicRecord::, NUM_LIMBS, LIMB_BITS> { - from_state, - instruction, - x_ptr_read, - y_ptr_read, - z_ptr_read, - x_read, - y_read, - z_write, - x_sign, - y_sign, - cmp_buffer, - } = operation; - - let row: &mut ArithmeticLogicCols, NUM_LIMBS, LIMB_BITS> = row.borrow_mut(); - - row.io = ArithmeticLogicIoCols { - from_state: from_state.map(Val::::from_canonical_u32), - x: MemoryData::, NUM_LIMBS, LIMB_BITS> { - data: x_read.data, - address: x_read.pointer, - ptr_to_address: x_ptr_read.pointer, - }, - y: MemoryData::, NUM_LIMBS, LIMB_BITS> { - data: y_read.data, - address: y_read.pointer, - ptr_to_address: y_ptr_read.pointer, - }, - z: match &z_write { - WriteRecord::Long(z) => MemoryData { - data: z.data, - address: z.pointer, - ptr_to_address: z_ptr_read.pointer, - }, - WriteRecord::Bool(z) => MemoryData { - data: array::from_fn(|i| cmp_buffer[i]), - address: z.pointer, - ptr_to_address: z_ptr_read.pointer, - }, - }, - cmp_result: match &z_write { - WriteRecord::Long(_) => Val::::zero(), - WriteRecord::Bool(z) => z.data[0], - }, - ptr_as: instruction.d, - address_as: instruction.e, - }; - - let opcode = U256Opcode::from_usize(instruction.opcode); - row.aux = ArithmeticLogicAuxCols { - is_valid: Val::::one(), - x_sign, - y_sign, - opcode_add_flag: Val::::from_bool(opcode == U256Opcode::ADD), - opcode_sub_flag: Val::::from_bool(opcode == U256Opcode::SUB), - opcode_sltu_flag: Val::::from_bool(opcode == U256Opcode::LT), - opcode_eq_flag: Val::::from_bool(opcode == U256Opcode::EQ), - opcode_xor_flag: Val::::from_bool(opcode == U256Opcode::XOR), - opcode_and_flag: Val::::from_bool(opcode == U256Opcode::AND), - opcode_or_flag: Val::::from_bool(opcode == U256Opcode::OR), - opcode_slt_flag: Val::::from_bool(opcode == U256Opcode::SLT), - read_ptr_aux_cols: [z_ptr_read, x_ptr_read, y_ptr_read] - .map(|read| aux_cols_factory.make_read_aux_cols(read)), - read_x_aux_cols: aux_cols_factory.make_read_aux_cols(x_read), - read_y_aux_cols: aux_cols_factory.make_read_aux_cols(y_read), - write_z_aux_cols: match &z_write { - WriteRecord::Long(z) => aux_cols_factory.make_write_aux_cols(*z), - WriteRecord::Bool(_) => MemoryWriteAuxCols::disabled(), - }, - write_cmp_aux_cols: match &z_write { - WriteRecord::Long(_) => MemoryWriteAuxCols::disabled(), - WriteRecord::Bool(z) => aux_cols_factory.make_write_aux_cols(*z), - }, - }; - } - AirProofInput::simple_no_pis(air, RowMajorMatrix::new(rows, width)) - } -} - -impl ChipUsageGetter - for ArithmeticLogicChip -{ - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.data.len() - } - - fn trace_width(&self) -> usize { - ArithmeticLogicCols::::width() - } -} diff --git a/vm/src/old/mod.rs b/vm/src/old/mod.rs deleted file mode 100644 index 52a924321e..0000000000 --- a/vm/src/old/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod alu; -pub mod shift; -pub mod uint_multiplication; diff --git a/vm/src/old/shift/air.rs b/vm/src/old/shift/air.rs deleted file mode 100644 index 12d2d07c31..0000000000 --- a/vm/src/old/shift/air.rs +++ /dev/null @@ -1,179 +0,0 @@ -use std::{borrow::Borrow, iter::zip}; - -use ax_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupBus, utils, var_range::VariableRangeCheckerBus, -}; -use ax_stark_backend::{ - interaction::InteractionBuilder, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, -}; -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; -use p3_matrix::Matrix; - -use super::columns::ShiftCols; -use crate::{ - arch::{instructions::U256Opcode, ExecutionBridge}, - system::memory::offline_checker::MemoryBridge, -}; - -#[derive(Clone, Copy, Debug)] -pub struct ShiftCoreAir { - pub(super) execution_bridge: ExecutionBridge, - pub(super) memory_bridge: MemoryBridge, - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - pub range_bus: VariableRangeCheckerBus, - - pub(super) offset: usize, -} - -impl PartitionedBaseAir - for ShiftCoreAir -{ -} -impl BaseAir - for ShiftCoreAir -{ - fn width(&self) -> usize { - ShiftCols::::width() - } -} - -impl BaseAirWithPublicValues - for ShiftCoreAir -{ -} - -impl Air - for ShiftCoreAir -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - - let ShiftCols::<_, NUM_LIMBS, LIMB_BITS> { io, aux } = (*local).borrow(); - builder.assert_bool(aux.is_valid); - - // Constrain that flags are valid. - let flags = [ - aux.opcode_sll_flag, - aux.opcode_srl_flag, - aux.opcode_sra_flag, - ]; - for flag in flags { - builder.assert_bool(flag); - } - - builder.assert_eq( - aux.is_valid, - flags - .iter() - .fold(AB::Expr::zero(), |acc, &flag| acc + flag.into()), - ); - - let x_limbs = &io.x.data; - let y_limbs = &io.y.data; - let z_limbs = &io.z.data; - let right_shift = aux.opcode_srl_flag + aux.opcode_sra_flag; - - // Constrain that bit_shift, bit_multiplier are correct, i.e. that bit_multiplier = - // 1 << bit_shift. We check that bit_shift is correct below if y < NUM_LIMBS * LIMB_BITS, - // otherwise we don't really care what its value is. Note that bit_shift < LIMB_BITS is - // constrained in bridge.rs via the range checker. - builder - .when(aux.opcode_sll_flag) - .assert_zero(aux.bit_multiplier_right); - builder - .when(right_shift.clone()) - .assert_zero(aux.bit_multiplier_left); - - for i in 0..LIMB_BITS { - let mut when_bit_shift = builder.when(aux.bit_shift_marker[i]); - when_bit_shift.assert_eq(aux.bit_shift, AB::F::from_canonical_usize(i)); - when_bit_shift - .when(aux.opcode_sll_flag) - .assert_eq(aux.bit_multiplier_left, AB::F::from_canonical_usize(1 << i)); - when_bit_shift.when(right_shift.clone()).assert_eq( - aux.bit_multiplier_right, - AB::F::from_canonical_usize(1 << i), - ); - } - - builder.assert_bool(aux.x_sign); - builder - .when(utils::not(aux.opcode_sra_flag)) - .assert_zero(aux.x_sign); - - let mut marker_sum = AB::Expr::zero(); - - // Check that z[i] = x[i] <> y[i] both on the bit and limb shift level if y < - // NUM_LIMBS * LIMB_BITS. - for i in 0..NUM_LIMBS { - marker_sum += aux.limb_shift_marker[i].into(); - builder.assert_bool(aux.limb_shift_marker[i]); - - let mut when_limb_shift = builder.when(aux.limb_shift_marker[i]); - when_limb_shift.assert_eq( - y_limbs[1] * AB::F::from_canonical_usize(1 << LIMB_BITS) + y_limbs[0] - - aux.bit_shift, - AB::F::from_canonical_usize(i * LIMB_BITS), - ); - - for j in 0..NUM_LIMBS { - // SLL constraints - if j < i { - when_limb_shift.assert_zero(z_limbs[j] * aux.opcode_sll_flag); - } else { - let expected_z_left = if j - i == 0 { - AB::Expr::zero() - } else { - aux.bit_shift_carry[j - i - 1].into() * aux.opcode_sll_flag - } + x_limbs[j - i] * aux.bit_multiplier_left - - AB::Expr::from_canonical_usize(1 << LIMB_BITS) - * aux.bit_shift_carry[j - i] - * aux.opcode_sll_flag; - when_limb_shift.assert_eq(z_limbs[j] * aux.opcode_sll_flag, expected_z_left); - } - - // SRL and SRA constraints. Combining with above would require an additional column. - if j + i > NUM_LIMBS - 1 { - when_limb_shift.assert_eq( - z_limbs[j] * right_shift.clone(), - aux.x_sign * AB::F::from_canonical_usize((1 << LIMB_BITS) - 1), - ); - } else { - let expected_z_right = if j + i == NUM_LIMBS - 1 { - aux.x_sign * (aux.bit_multiplier_right - AB::F::one()) - } else { - aux.bit_shift_carry[j + i + 1].into() * right_shift.clone() - } * AB::F::from_canonical_usize(1 << LIMB_BITS) - + right_shift.clone() * (x_limbs[j + i] - aux.bit_shift_carry[j + i]); - when_limb_shift - .assert_eq(z_limbs[j] * aux.bit_multiplier_right, expected_z_right); - } - - // Ensure y is defined entirely within y[0] and y[1] if limb shifting - if j > 1 { - when_limb_shift.assert_zero(y_limbs[j]); - } - } - } - - // If the shift is larger than the number of bits, check that each limb of z is filled - for z in z_limbs { - builder - .when(AB::Expr::one() - marker_sum.clone()) - .assert_eq( - *z, - aux.x_sign * AB::F::from_canonical_usize((1 << LIMB_BITS) - 1), - ); - } - - let expected_opcode = zip(flags, U256Opcode::shift_opcodes()) - .fold(AB::Expr::zero(), |acc, (flag, opcode)| { - acc + flag * AB::Expr::from_canonical_u8(opcode as u8) - }); - - self.eval_interactions(builder, io, aux, expected_opcode); - } -} diff --git a/vm/src/old/shift/bridge.rs b/vm/src/old/shift/bridge.rs deleted file mode 100644 index f9da8f86bc..0000000000 --- a/vm/src/old/shift/bridge.rs +++ /dev/null @@ -1,113 +0,0 @@ -use ax_stark_backend::interaction::InteractionBuilder; -use itertools::izip; -use p3_field::AbstractField; - -use super::{ - air::ShiftCoreAir, - columns::{ShiftAuxCols, ShiftIoCols}, -}; -use crate::system::memory::MemoryAddress; - -impl ShiftCoreAir { - pub fn eval_interactions( - &self, - builder: &mut AB, - io: &ShiftIoCols, - aux: &ShiftAuxCols, - expected_opcode: AB::Expr, - ) { - let timestamp: AB::Var = io.from_state.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::F::from_canonical_usize(timestamp_delta - 1) - }; - - for (ptr, value, mem_aux) in izip!( - [ - io.z.ptr_to_address, - io.x.ptr_to_address, - io.y.ptr_to_address - ], - [io.z.address, io.x.address, io.y.address], - &aux.read_ptr_aux_cols - ) { - self.memory_bridge - .read( - MemoryAddress::new(io.ptr_as, ptr), - [value], - timestamp_pp(), - mem_aux, - ) - .eval(builder, aux.is_valid); - } - - self.memory_bridge - .read( - MemoryAddress::new(io.address_as, io.x.address), - io.x.data, - timestamp_pp(), - &aux.read_x_aux_cols, - ) - .eval(builder, aux.is_valid); - - self.memory_bridge - .read( - MemoryAddress::new(io.address_as, io.y.address), - io.y.data, - timestamp_pp(), - &aux.read_y_aux_cols, - ) - .eval(builder, aux.is_valid); - - self.memory_bridge - .write( - MemoryAddress::new(io.address_as, io.z.address), - io.z.data, - timestamp_pp(), - &aux.write_z_aux_cols, - ) - .eval(builder, aux.is_valid); - - self.execution_bridge - .execute_and_increment_pc( - expected_opcode + AB::Expr::from_canonical_usize(self.offset), - [ - io.z.ptr_to_address, - io.x.ptr_to_address, - io.y.ptr_to_address, - io.ptr_as, - io.address_as, - ], - io.from_state, - AB::F::from_canonical_usize(timestamp_delta), - ) - .eval(builder, aux.is_valid); - - // Check that bit_shift < LIMB_BITS - self.range_bus - .range_check(aux.bit_shift, LIMB_BITS.ilog2() as usize) - .eval(builder, aux.is_valid); - - // Check x_sign & x[NUM_LIMBS - 1] == x_sign using XOR - let mask = AB::F::from_canonical_u32(1 << (LIMB_BITS - 1)); - let x_sign_shifted = aux.x_sign * mask; - self.bitwise_lookup_bus - .send_xor( - io.x.data[NUM_LIMBS - 1], - mask, - io.x.data[NUM_LIMBS - 1] + mask - - (AB::Expr::from_canonical_u32(2) * x_sign_shifted), - ) - .eval(builder, aux.opcode_sra_flag); - - for (z, carry) in io.z.data.iter().zip(aux.bit_shift_carry.iter()) { - self.range_bus - .range_check(*z, LIMB_BITS) - .eval(builder, aux.is_valid); - self.range_bus - .send(*carry, aux.bit_shift) - .eval(builder, aux.is_valid); - } - } -} diff --git a/vm/src/old/shift/columns.rs b/vm/src/old/shift/columns.rs deleted file mode 100644 index fdc22fe987..0000000000 --- a/vm/src/old/shift/columns.rs +++ /dev/null @@ -1,58 +0,0 @@ -use ax_circuit_derive::AlignedBorrow; - -use crate::{ - arch::ExecutionState, - old::uint_multiplication::MemoryData, - system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct ShiftCols { - pub io: ShiftIoCols, - pub aux: ShiftAuxCols, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct ShiftIoCols { - pub from_state: ExecutionState, - pub x: MemoryData, - pub y: MemoryData, - pub z: MemoryData, - pub ptr_as: T, - pub address_as: T, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct ShiftAuxCols { - pub is_valid: T, - - // Each limb is shifted by bit_shift, where y[0] = bit_shift + LIMB_BITS * bit_quotient and - // bit_multiplier = 2^bit_shift - pub bit_shift: T, - pub bit_multiplier_left: T, - pub bit_multiplier_right: T, - - // Sign of x for SRA - pub x_sign: T, - - // Boolean columns that are 1 exactly at the index of the bit/limb shift amount - pub bit_shift_marker: [T; LIMB_BITS], - pub limb_shift_marker: [T; NUM_LIMBS], - - // Part of each x[i] that gets bit shifted to the next limb - pub bit_shift_carry: [T; NUM_LIMBS], - - // Opcode flags for different operations - pub opcode_sll_flag: T, - pub opcode_srl_flag: T, - pub opcode_sra_flag: T, - - // Pointer read auxiliary columns for [z, x, y] - pub read_ptr_aux_cols: [MemoryReadAuxCols; 3], - pub read_x_aux_cols: MemoryReadAuxCols, - pub read_y_aux_cols: MemoryReadAuxCols, - pub write_z_aux_cols: MemoryWriteAuxCols, -} diff --git a/vm/src/old/shift/mod.rs b/vm/src/old/shift/mod.rs deleted file mode 100644 index 60a0d5cdda..0000000000 --- a/vm/src/old/shift/mod.rs +++ /dev/null @@ -1,270 +0,0 @@ -use std::{array, sync::Arc}; - -use air::ShiftCoreAir; -use ax_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupChip, var_range::VariableRangeCheckerChip, -}; -use axvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; -use p3_field::PrimeField32; - -use crate::{ - arch::{ - instructions::{U256Opcode, UsizeOpcode}, - ExecutionBridge, ExecutionBus, ExecutionState, InstructionExecutor, - }, - system::{ - memory::{MemoryControllerRef, MemoryReadRecord, MemoryWriteRecord}, - program::{ExecutionError, ProgramBus}, - }, -}; - -mod air; -mod bridge; -mod columns; -mod trace; - -#[cfg(test)] -mod tests; - -#[derive(Clone, Debug)] -pub struct ShiftRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - pub x_ptr_read: MemoryReadRecord, - pub y_ptr_read: MemoryReadRecord, - pub z_ptr_read: MemoryReadRecord, - pub x_read: MemoryReadRecord, - pub y_read: MemoryReadRecord, - pub z_write: MemoryWriteRecord, - pub bit_shift_carry: [T; NUM_LIMBS], - pub bit_shift: usize, - pub limb_shift: usize, - pub x_sign: T, -} - -#[derive(Debug)] -pub struct ShiftChip { - pub air: ShiftCoreAir, - data: Vec>, - memory_controller: MemoryControllerRef, - pub bitwise_lookup_chip: Arc>, - pub range_checker_chip: Arc, - - offset: usize, -} - -impl - ShiftChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_controller: MemoryControllerRef, - bitwise_lookup_chip: Arc>, - offset: usize, - ) -> Self { - // (1 << (2 * LIMB_BITS)) fits within a u32 - assert!(LIMB_BITS < 16, "LIMB_BITS {} >= 16", LIMB_BITS); - // For range check that bit_shift < LIMB_BITS - assert!( - LIMB_BITS.is_power_of_two(), - "LIMB_BITS {} not a power of 2", - LIMB_BITS - ); - // A non-overflow shift amount is defined entirely within y[0] and y[1] - assert!( - NUM_LIMBS * LIMB_BITS < (1 << (2 * LIMB_BITS)), - "NUM_LIMBS * LIMB_BITS {} >= 2^(2 * LIMB_BITS {})", - NUM_LIMBS * LIMB_BITS, - LIMB_BITS - ); - let memory_bridge = memory_controller.borrow().memory_bridge(); - let range_checker_chip = memory_controller.borrow().range_checker.clone(); - Self { - air: ShiftCoreAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - range_bus: range_checker_chip.bus(), - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - offset, - }, - data: vec![], - memory_controller, - range_checker_chip, - bitwise_lookup_chip, - offset, - } - } -} - -impl InstructionExecutor - for ShiftChip -{ - fn execute( - &mut self, - instruction: Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let Instruction { - opcode, - a, - b, - c, - d, - e, - .. - } = instruction.clone(); - let local_opcode_index = opcode - self.offset; - assert!(U256Opcode::shift_opcodes().any(|op| op as usize == local_opcode_index)); - - let mut memory_controller = self.memory_controller.borrow_mut(); - debug_assert_eq!(from_state.timestamp, memory_controller.timestamp()); - - let [z_ptr_read, x_ptr_read, y_ptr_read] = - [a, b, c].map(|ptr_of_ptr| memory_controller.read_cell(d, ptr_of_ptr)); - let x_read = memory_controller.read::(e, x_ptr_read.value()); - let y_read = memory_controller.read::(e, y_ptr_read.value()); - - let x = x_read.data.map(|x| x.as_canonical_u32()); - let y = y_read.data.map(|y| y.as_canonical_u32()); - let (z, limb_shift, bit_shift) = - run_shift::(&x, &y, U256Opcode::from_usize(local_opcode_index)); - - let carry = x - .into_iter() - .map( - |val: u32| match U256Opcode::from_usize(local_opcode_index) { - U256Opcode::SLL => val >> (LIMB_BITS - bit_shift), - _ => val % (1 << bit_shift), - }, - ) - .collect::>(); - - let mut x_sign = 0; - if local_opcode_index == U256Opcode::SRA as usize { - x_sign = x[NUM_LIMBS - 1] >> (LIMB_BITS - 1); - self.bitwise_lookup_chip - .request_xor(x[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1)); - } - - self.range_checker_chip - .add_count(bit_shift as u32, LIMB_BITS.ilog2() as usize); - for (z_val, carry_val) in z.iter().zip(carry.iter()) { - self.range_checker_chip.add_count(*z_val, LIMB_BITS); - self.range_checker_chip.add_count(*carry_val, bit_shift); - } - - let z_write = memory_controller.write::( - e, - z_ptr_read.value(), - z.into_iter() - .map(T::from_canonical_u32) - .collect::>() - .try_into() - .unwrap(), - ); - - self.data.push(ShiftRecord { - from_state, - instruction: Instruction { - opcode: local_opcode_index, - ..instruction - }, - x_ptr_read, - y_ptr_read, - z_ptr_read, - x_read, - y_read, - z_write, - bit_shift_carry: array::from_fn(|i| T::from_canonical_u32(carry[i])), - bit_shift, - limb_shift, - x_sign: T::from_canonical_u32(x_sign), - }); - - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory_controller.timestamp(), - }) - } - - fn get_opcode_name(&self, opcode: usize) -> String { - let local_opcode_index = U256Opcode::from_usize(opcode - self.offset); - format!("{local_opcode_index:?}<{NUM_LIMBS},{LIMB_BITS}>") - } -} - -fn run_shift( - x: &[u32], - y: &[u32], - op: U256Opcode, -) -> (Vec, usize, usize) { - match op { - U256Opcode::SLL => run_shift_left::(x, y), - U256Opcode::SRL => run_shift_right::(x, y, true), - U256Opcode::SRA => run_shift_right::(x, y, false), - _ => unreachable!(), - } -} - -fn run_shift_left( - x: &[u32], - y: &[u32], -) -> (Vec, usize, usize) { - let mut result = vec![0u32; NUM_LIMBS]; - - let (is_zero, limb_shift, bit_shift) = get_shift::(y); - if is_zero { - return (result, limb_shift, bit_shift); - } - - for i in limb_shift..NUM_LIMBS { - result[i] = if i > limb_shift { - ((x[i - limb_shift] << bit_shift) + (x[i - limb_shift - 1] >> (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) - } else { - (x[i - limb_shift] << bit_shift) % (1 << LIMB_BITS) - }; - } - (result, limb_shift, bit_shift) -} - -fn run_shift_right( - x: &[u32], - y: &[u32], - logical: bool, -) -> (Vec, usize, usize) { - let fill = if logical { - 0 - } else { - ((1 << LIMB_BITS) - 1) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) - }; - let mut result = vec![fill; NUM_LIMBS]; - - let (is_zero, limb_shift, bit_shift) = get_shift::(y); - if is_zero { - return (result, limb_shift, bit_shift); - } - - for i in 0..(NUM_LIMBS - limb_shift) { - result[i] = if i + limb_shift + 1 < NUM_LIMBS { - ((x[i + limb_shift] >> bit_shift) + (x[i + limb_shift + 1] << (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) - } else { - ((x[i + limb_shift] >> bit_shift) + (fill << (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) - } - } - (result, limb_shift, bit_shift) -} - -fn get_shift(y: &[u32]) -> (bool, usize, usize) { - // We assume `NUM_LIMBS * LIMB_BITS < 2^(2*LIMB_BITS)` so if there are any higher limbs, - // the shifted value is zero. - let shift = (y[0] + (y[1] * (1 << LIMB_BITS))) as usize; - if shift < NUM_LIMBS * LIMB_BITS && y[2..].iter().all(|&val| val == 0) { - (false, shift / LIMB_BITS, shift % LIMB_BITS) - } else { - (true, NUM_LIMBS, shift % LIMB_BITS) - } -} diff --git a/vm/src/old/shift/tests.rs b/vm/src/old/shift/tests.rs deleted file mode 100644 index 80521363bc..0000000000 --- a/vm/src/old/shift/tests.rs +++ /dev/null @@ -1,682 +0,0 @@ -use std::{array, borrow::BorrowMut, iter, sync::Arc}; - -use ax_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, -}; -use ax_stark_backend::{utils::disable_debug_builder, verifier::VerificationError, Chip}; -use ax_stark_sdk::utils::create_seeded_rng; -use axvm_instructions::instruction::Instruction; -use p3_baby_bear::BabyBear; -use p3_field::AbstractField; -use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use rand::{rngs::StdRng, Rng}; -use test_log::test; - -use super::{run_shift, ShiftChip}; -use crate::{ - arch::{ - instructions::U256Opcode, - testing::{memory::gen_pointer, VmChipTestBuilder}, - BITWISE_OP_LOOKUP_BUS, - }, - old::shift::columns::ShiftCols, -}; - -type F = BabyBear; -const NUM_LIMBS: usize = 32; -const LIMB_BITS: usize = 8; - -fn generate_long_number( - rng: &mut StdRng, -) -> Vec { - (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..1 << LIMB_BITS)) - .collect() -} - -fn generate_shift(rng: &mut StdRng) -> Vec { - iter::once(rng.gen_range(0..1 << LIMB_BITS)) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect() -} - -#[allow(clippy::too_many_arguments)] -fn run_shift_rand_write_execute( - tester: &mut VmChipTestBuilder, - chip: &mut ShiftChip, - opcode: U256Opcode, - x: Vec, - y: Vec, - rng: &mut StdRng, -) { - let address_space_range = || 1usize..=2; - - let d = rng.gen_range(address_space_range()); - let e = rng.gen_range(address_space_range()); - - let x_address = gen_pointer(rng, 64); - let y_address = gen_pointer(rng, 64); - let res_address = gen_pointer(rng, 64); - let x_ptr_to_address = gen_pointer(rng, 1); - let y_ptr_to_address = gen_pointer(rng, 1); - let res_ptr_to_address = gen_pointer(rng, 1); - - let x_f = x - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - let y_f = y - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - - tester.write_cell(d, x_ptr_to_address, F::from_canonical_usize(x_address)); - tester.write_cell(d, y_ptr_to_address, F::from_canonical_usize(y_address)); - tester.write_cell(d, res_ptr_to_address, F::from_canonical_usize(res_address)); - tester.write::(e, x_address, x_f.as_slice().try_into().unwrap()); - tester.write::(e, y_address, y_f.as_slice().try_into().unwrap()); - - let (z, _, _) = run_shift::(&x, &y, opcode); - tester.execute( - chip, - Instruction::from_usize( - opcode as usize, - [res_ptr_to_address, x_ptr_to_address, y_ptr_to_address, d, e], - ), - ); - - assert_eq!( - z.into_iter().map(F::from_canonical_u32).collect::>(), - tester.read::(e, res_address) - ) -} - -#[allow(clippy::too_many_arguments)] -fn run_shift_negative_test( - opcode: U256Opcode, - x: Vec, - y: Vec, - z: Vec, - bit_shift: u32, - bit_multiplier_left: u32, - bit_multiplier_right: u32, - x_sign: u32, - bit_shift_carry: Vec, - expected_error: VerificationError, -) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ShiftChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - let mut rng = create_seeded_rng(); - run_shift_rand_write_execute::( - &mut tester, - &mut chip, - opcode, - x, - y, - &mut rng, - ); - - if expected_error == VerificationError::NonZeroCumulativeSum { - chip.range_checker_chip.clear(); - chip.range_checker_chip - .add_count(bit_shift, LIMB_BITS.ilog2() as usize); - for (z_val, carry_val) in z.iter().zip(bit_shift_carry.iter()) { - chip.range_checker_chip.add_count(*z_val, LIMB_BITS); - chip.range_checker_chip - .add_count(*carry_val, bit_shift as usize); - } - } - let mut air_proof_input = chip.generate_air_proof_input(); - let shift_trace = air_proof_input.raw.common_main.as_mut().unwrap(); - let mut shift_trace_vec = shift_trace.row_slice(0).to_vec(); - let shift_trace_cols: &mut ShiftCols = (*shift_trace_vec).borrow_mut(); - - shift_trace_cols.io.z.data = array::from_fn(|i| F::from_canonical_u32(z[i])); - shift_trace_cols.aux.bit_shift = F::from_canonical_u32(bit_shift); - shift_trace_cols.aux.bit_multiplier_left = F::from_canonical_u32(bit_multiplier_left); - shift_trace_cols.aux.bit_multiplier_right = F::from_canonical_u32(bit_multiplier_right); - shift_trace_cols.aux.x_sign = F::from_canonical_u32(x_sign); - shift_trace_cols.aux.bit_shift_carry = - array::from_fn(|i| F::from_canonical_u32(bit_shift_carry[i])); - - *shift_trace = RowMajorMatrix::new( - shift_trace_vec, - ShiftCols::::width(), - ); - - disable_debug_builder(); - let mut tester = tester.build(); - tester.air_proof_inputs.push(air_proof_input); - let tester = tester.load(bitwise_chip).finalize(); - let msg = format!( - "Expected verification to fail with {:?}, but it didn't", - &expected_error - ); - let result = tester.simple_test(); - assert_eq!(result.err(), Some(expected_error), "{}", msg); -} - -#[test] -fn shift_sll_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ShiftChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_shift::(&mut rng); - run_shift_rand_write_execute(&mut tester, &mut chip, U256Opcode::SLL, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn shift_sll_wrong_answer_negative_test() { - run_shift_negative_test( - U256Opcode::SLL, - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(4) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - 1, - 2, - 0, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_sll_wrong_bit_shift_negative_test() { - run_shift_negative_test( - U256Opcode::SLL, - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(2) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - 2, - 2, - 0, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_sll_wrong_bit_mult_negative_test() { - run_shift_negative_test( - U256Opcode::SLL, - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(4) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - 1, - 4, - 0, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_sll_nonzero_bit_mult_right_negative_test() { - run_shift_negative_test( - U256Opcode::SLL, - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(2) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - 1, - 2, - 1, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_sll_nonzero_sign_negative_test() { - run_shift_negative_test( - U256Opcode::SLL, - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(2) - .chain(iter::repeat(0).take(NUM_LIMBS - 3)) - .chain(iter::repeat(1).take(2)) - .collect(), - 1, - 2, - 0, - 1, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_sll_out_of_range_carry_negative_test() { - run_shift_negative_test( - U256Opcode::SLL, - iter::once(1 << LIMB_BITS) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(0) - .chain(iter::once(2)) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - 1, - 2, - 0, - 0, - iter::once(2) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn shift_srl_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ShiftChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_shift::(&mut rng); - run_shift_rand_write_execute(&mut tester, &mut chip, U256Opcode::SRL, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn shift_srl_wrong_answer_negative_test() { - run_shift_negative_test( - U256Opcode::SRL, - iter::once(4) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - 1, - 0, - 2, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_srl_wrong_extension_negative_test() { - run_shift_negative_test( - U256Opcode::SRL, - iter::once(4) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(2) - .chain(iter::repeat(0).take(NUM_LIMBS - 2)) - .chain(iter::once(1 << (LIMB_BITS - 1))) - .collect(), - 1, - 0, - 2, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_srl_nonzero_bit_mult_left_negative_test() { - run_shift_negative_test( - U256Opcode::SRL, - iter::once(4) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(2) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - 1, - 2, - 2, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_srl_nonzero_sign_negative_test() { - run_shift_negative_test( - U256Opcode::SRL, - iter::once(4) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(2) - .chain(iter::repeat(0).take(NUM_LIMBS - 3)) - .chain(iter::repeat(1).take(2)) - .collect(), - 1, - 2, - 0, - 1, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_sra_rand_test() { - let num_ops: usize = 10; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ShiftChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - for _ in 0..num_ops { - let x = generate_long_number::(&mut rng); - let y = generate_shift::(&mut rng); - run_shift_rand_write_execute(&mut tester, &mut chip, U256Opcode::SRA, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn shift_sra_wrong_answer_negative_test() { - run_shift_negative_test( - U256Opcode::SRA, - iter::once(4) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - 1, - 0, - 2, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_sra_wrong_extension_negative_test() { - run_shift_negative_test( - U256Opcode::SRA, - iter::once(4) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::once(2) - .chain(iter::repeat(0).take(NUM_LIMBS - 2)) - .chain(iter::once(1 << (LIMB_BITS - 1))) - .collect(), - 1, - 0, - 2, - 0, - vec![0; NUM_LIMBS], - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn shift_sra_wrong_sign_negative_test() { - run_shift_negative_test( - U256Opcode::SRA, - vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], - iter::once(1) - .chain(iter::repeat(0)) - .take(NUM_LIMBS) - .collect(), - iter::repeat((1 << LIMB_BITS) - 1) - .take(NUM_LIMBS - 1) - .chain(iter::once((1 << (LIMB_BITS - 1)) - 1)) - .collect(), - 1, - 0, - 2, - 0, - vec![1; NUM_LIMBS], - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn shift_overflow_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new(bitwise_bus)); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ShiftChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - bitwise_chip.clone(), - 0, - ); - - let x = generate_long_number::(&mut rng); - let mut y = generate_long_number::(&mut rng); - y[1] = 100; - - run_shift_rand_write_execute( - &mut tester, - &mut chip, - U256Opcode::SLL, - x.clone(), - y.clone(), - &mut rng, - ); - run_shift_rand_write_execute( - &mut tester, - &mut chip, - U256Opcode::SRL, - x.clone(), - y.clone(), - &mut rng, - ); - run_shift_rand_write_execute( - &mut tester, - &mut chip, - U256Opcode::SRA, - x.clone(), - y.clone(), - &mut rng, - ); - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn run_sll_sanity_test() { - let x: [u32; 32] = [ - 45, 7, 61, 186, 49, 53, 119, 68, 145, 55, 102, 126, 9, 195, 23, 26, 197, 216, 251, 31, 74, - 237, 141, 92, 98, 184, 176, 106, 64, 29, 58, 246, - ]; - let y: [u32; 32] = [ - 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, - ]; - let z: [u32; 32] = [ - 0, 0, 0, 104, 57, 232, 209, 141, 169, 185, 35, 138, 188, 49, 243, 75, 24, 190, 208, 40, - 198, 222, 255, 80, 106, 111, 228, 18, 195, 133, 85, 3, - ]; - let sll_result = run_shift::<32, 8>(&x, &y, U256Opcode::SLL).0; - for i in 0..32 { - assert_eq!(z[i], sll_result[i]) - } -} - -#[test] -fn run_srl_sanity_test() { - let x: [u32; 32] = [ - 253, 247, 209, 166, 217, 253, 46, 42, 197, 8, 33, 136, 144, 148, 101, 195, 173, 150, 26, - 215, 233, 90, 213, 185, 119, 255, 238, 174, 31, 190, 221, 72, - ]; - let y: [u32; 32] = [ - 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, - ]; - let z: [u32; 32] = [ - 104, 211, 236, 126, 23, 149, 98, 132, 16, 68, 72, 202, 178, 225, 86, 75, 141, 235, 116, - 173, 234, 220, 187, 127, 119, 215, 15, 223, 110, 36, 0, 0, - ]; - let srl_result = run_shift::<32, 8>(&x, &y, U256Opcode::SRL).0; - let sra_result = run_shift::<32, 8>(&x, &y, U256Opcode::SRA).0; - for i in 0..32 { - assert_eq!(z[i], srl_result[i]); - assert_eq!(z[i], sra_result[i]); - } -} - -#[test] -fn run_sra_sanity_test() { - let x: [u32; 32] = [ - 253, 247, 209, 166, 217, 253, 46, 42, 197, 8, 33, 136, 144, 148, 101, 195, 173, 150, 26, - 215, 233, 90, 213, 185, 119, 255, 238, 174, 31, 190, 221, 200, - ]; - let y: [u32; 32] = [ - 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, - ]; - let z: [u32; 32] = [ - 104, 211, 236, 126, 23, 149, 98, 132, 16, 68, 72, 202, 178, 225, 86, 75, 141, 235, 116, - 173, 234, 220, 187, 127, 119, 215, 15, 223, 110, 228, 255, 255, - ]; - let sra_result = run_shift::<32, 8>(&x, &y, U256Opcode::SRA).0; - for i in 0..32 { - assert_eq!(z[i], sra_result[i]) - } -} diff --git a/vm/src/old/shift/trace.rs b/vm/src/old/shift/trace.rs deleted file mode 100644 index aa3a4c0d7a..0000000000 --- a/vm/src/old/shift/trace.rs +++ /dev/null @@ -1,132 +0,0 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; - -use ax_stark_backend::{ - config::{StarkGenericConfig, Val}, - prover::types::AirProofInput, - rap::{get_air_name, AnyRap}, - Chip, ChipUsageGetter, -}; -use p3_field::{AbstractField, PrimeField32}; -use p3_matrix::dense::RowMajorMatrix; - -use super::{ - columns::{ShiftAuxCols, ShiftCols, ShiftIoCols}, - ShiftChip, ShiftRecord, -}; -use crate::{ - arch::instructions::{U256Opcode, UsizeOpcode}, - old::uint_multiplication::MemoryData, -}; - -impl Chip - for ShiftChip, NUM_LIMBS, LIMB_BITS> -where - Val: PrimeField32, -{ - fn air(&self) -> Arc> { - Arc::new(self.air) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let air = self.air(); - let aux_cols_factory = self.memory_controller.borrow().aux_cols_factory(); - - let width = self.trace_width(); - let height = self.data.len(); - let padded_height = height.next_power_of_two(); - let mut rows = vec![Val::::zero(); width * padded_height]; - - for (row, operation) in rows.chunks_mut(width).zip(self.data) { - let ShiftRecord::, NUM_LIMBS, LIMB_BITS> { - from_state, - instruction, - x_ptr_read, - y_ptr_read, - z_ptr_read, - x_read, - y_read, - z_write, - bit_shift_carry, - bit_shift, - limb_shift, - x_sign, - } = operation; - - let row: &mut ShiftCols, NUM_LIMBS, LIMB_BITS> = row.borrow_mut(); - - row.io = ShiftIoCols { - from_state: from_state.map(Val::::from_canonical_u32), - x: MemoryData::, NUM_LIMBS, LIMB_BITS> { - data: x_read.data, - address: x_read.pointer, - ptr_to_address: x_ptr_read.pointer, - }, - y: MemoryData::, NUM_LIMBS, LIMB_BITS> { - data: y_read.data, - address: y_read.pointer, - ptr_to_address: y_ptr_read.pointer, - }, - z: MemoryData::, NUM_LIMBS, LIMB_BITS> { - data: z_write.data, - address: z_write.pointer, - ptr_to_address: z_ptr_read.pointer, - }, - ptr_as: instruction.d, - address_as: instruction.e, - }; - - row.aux = ShiftAuxCols { - is_valid: Val::::one(), - bit_shift: Val::::from_canonical_usize(bit_shift), - bit_multiplier_left: Val::::from_canonical_usize(match U256Opcode::from_usize( - instruction.opcode, - ) { - U256Opcode::SLL => 1 << bit_shift, - U256Opcode::SRL | U256Opcode::SRA => 0, - _ => unreachable!(), - }), - bit_multiplier_right: Val::::from_canonical_usize( - match U256Opcode::from_usize(instruction.opcode) { - U256Opcode::SLL => 0, - U256Opcode::SRL | U256Opcode::SRA => 1 << bit_shift, - _ => unreachable!(), - }, - ), - x_sign, - bit_shift_marker: array::from_fn(|val| Val::::from_bool(val == bit_shift)), - limb_shift_marker: array::from_fn(|val| Val::::from_bool(val == limb_shift)), - bit_shift_carry, - opcode_sll_flag: Val::::from_bool( - instruction.opcode == U256Opcode::SLL as usize, - ), - opcode_srl_flag: Val::::from_bool( - instruction.opcode == U256Opcode::SRL as usize, - ), - opcode_sra_flag: Val::::from_bool( - instruction.opcode == U256Opcode::SRA as usize, - ), - read_ptr_aux_cols: [z_ptr_read, x_ptr_read, y_ptr_read] - .map(|read| aux_cols_factory.make_read_aux_cols(read)), - read_x_aux_cols: aux_cols_factory.make_read_aux_cols(x_read), - read_y_aux_cols: aux_cols_factory.make_read_aux_cols(y_read), - write_z_aux_cols: aux_cols_factory.make_write_aux_cols(z_write), - }; - } - AirProofInput::simple_no_pis(air, RowMajorMatrix::new(rows, width)) - } -} - -impl ChipUsageGetter - for ShiftChip -{ - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.data.len() - } - - fn trace_width(&self) -> usize { - ShiftCols::::width() - } -} diff --git a/vm/src/old/uint_multiplication/air.rs b/vm/src/old/uint_multiplication/air.rs deleted file mode 100644 index 0f390fde76..0000000000 --- a/vm/src/old/uint_multiplication/air.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::borrow::Borrow; - -use ax_circuit_primitives::range_tuple::RangeTupleCheckerBus; -use ax_stark_backend::{ - interaction::InteractionBuilder, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, -}; -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; -use p3_matrix::Matrix; - -use crate::{ - arch::ExecutionBridge, old::uint_multiplication::columns::UintMultiplicationCols, - system::memory::offline_checker::MemoryBridge, -}; - -#[derive(Clone, Debug)] -pub struct UintMultiplicationCoreAir { - pub(super) execution_bridge: ExecutionBridge, - pub(super) memory_bridge: MemoryBridge, - pub bus: RangeTupleCheckerBus<2>, - - pub(super) offset: usize, -} - -impl PartitionedBaseAir - for UintMultiplicationCoreAir -{ -} -impl BaseAir - for UintMultiplicationCoreAir -{ - fn width(&self) -> usize { - UintMultiplicationCols::::width() - } -} - -impl BaseAirWithPublicValues - for UintMultiplicationCoreAir -{ -} - -impl Air - for UintMultiplicationCoreAir -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - - let UintMultiplicationCols::<_, NUM_LIMBS, LIMB_BITS> { io, aux } = (*local).borrow(); - builder.assert_bool(aux.is_valid); - - let x_limbs = &io.x.data; - let y_limbs = &io.y.data; - let z_limbs = &io.z.data; - let carry_limbs = &aux.carry; - - for i in 0..NUM_LIMBS { - let lhs = (0..=i).fold( - if i > 0 { - carry_limbs[i - 1].into() - } else { - AB::Expr::zero() - }, - |acc, j| acc + (x_limbs[j] * y_limbs[i - j]), - ); - let rhs = - z_limbs[i] + (carry_limbs[i] * AB::Expr::from_canonical_usize(1 << LIMB_BITS)); - builder.assert_eq(lhs, rhs); - } - - self.eval_interactions(builder, io, aux); - } -} diff --git a/vm/src/old/uint_multiplication/bridge.rs b/vm/src/old/uint_multiplication/bridge.rs deleted file mode 100644 index b59c6f289d..0000000000 --- a/vm/src/old/uint_multiplication/bridge.rs +++ /dev/null @@ -1,92 +0,0 @@ -use ax_stark_backend::interaction::InteractionBuilder; -use itertools::izip; -use p3_field::AbstractField; - -use super::{ - air::UintMultiplicationCoreAir, - columns::{UintMultiplicationAuxCols, UintMultiplicationIoCols}, -}; -use crate::{arch::instructions::U256Opcode, system::memory::MemoryAddress}; - -impl - UintMultiplicationCoreAir -{ - pub fn eval_interactions( - &self, - builder: &mut AB, - io: &UintMultiplicationIoCols, - aux: &UintMultiplicationAuxCols, - ) { - let timestamp: AB::Var = io.from_state.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::F::from_canonical_usize(timestamp_delta - 1) - }; - - for (ptr, value, mem_aux) in izip!( - [ - io.z.ptr_to_address, - io.x.ptr_to_address, - io.y.ptr_to_address - ], - [io.z.address, io.x.address, io.y.address], - &aux.read_ptr_aux_cols - ) { - self.memory_bridge - .read( - MemoryAddress::new(io.ptr_as, ptr), - [value], - timestamp_pp(), - mem_aux, - ) - .eval(builder, aux.is_valid); - } - - self.memory_bridge - .read( - MemoryAddress::new(io.address_as, io.x.address), - io.x.data, - timestamp_pp(), - &aux.read_x_aux_cols, - ) - .eval(builder, aux.is_valid); - - self.memory_bridge - .read( - MemoryAddress::new(io.address_as, io.y.address), - io.y.data, - timestamp_pp(), - &aux.read_y_aux_cols, - ) - .eval(builder, aux.is_valid); - - self.memory_bridge - .write( - MemoryAddress::new(io.address_as, io.z.address), - io.z.data, - timestamp_pp(), - &aux.write_z_aux_cols, - ) - .eval(builder, aux.is_valid); - - self.execution_bridge - .execute_and_increment_pc( - AB::Expr::from_canonical_usize(U256Opcode::MUL as usize + self.offset), - [ - io.z.ptr_to_address, - io.x.ptr_to_address, - io.y.ptr_to_address, - io.ptr_as, - io.address_as, - ], - io.from_state, - AB::F::from_canonical_usize(timestamp_delta), - ) - .eval(builder, aux.is_valid); - - for (z, carry) in io.z.data.iter().zip(aux.carry.iter()) { - self.bus.send(vec![*z, *carry]).eval(builder, aux.is_valid); - } - } -} diff --git a/vm/src/old/uint_multiplication/columns.rs b/vm/src/old/uint_multiplication/columns.rs deleted file mode 100644 index 5919df3516..0000000000 --- a/vm/src/old/uint_multiplication/columns.rs +++ /dev/null @@ -1,43 +0,0 @@ -use ax_circuit_derive::AlignedBorrow; - -use crate::{ - arch::ExecutionState, - system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct UintMultiplicationCols { - pub io: UintMultiplicationIoCols, - pub aux: UintMultiplicationAuxCols, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct UintMultiplicationIoCols { - pub from_state: ExecutionState, - pub x: MemoryData, - pub y: MemoryData, - pub z: MemoryData, - pub ptr_as: T, - pub address_as: T, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct UintMultiplicationAuxCols { - pub is_valid: T, - pub carry: [T; NUM_LIMBS], - pub read_ptr_aux_cols: [MemoryReadAuxCols; 3], - pub read_x_aux_cols: MemoryReadAuxCols, - pub read_y_aux_cols: MemoryReadAuxCols, - pub write_z_aux_cols: MemoryWriteAuxCols, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct MemoryData { - pub data: [T; NUM_LIMBS], - pub address: T, - pub ptr_to_address: T, -} diff --git a/vm/src/old/uint_multiplication/mod.rs b/vm/src/old/uint_multiplication/mod.rs deleted file mode 100644 index 655b01c109..0000000000 --- a/vm/src/old/uint_multiplication/mod.rs +++ /dev/null @@ -1,182 +0,0 @@ -use std::sync::Arc; - -use ax_circuit_primitives::range_tuple::RangeTupleCheckerChip; -use p3_field::PrimeField32; - -use crate::{ - arch::{ - instructions::U256Opcode, ExecutionBridge, ExecutionBus, ExecutionState, - InstructionExecutor, - }, - system::{ - memory::{MemoryControllerRef, MemoryReadRecord, MemoryWriteRecord}, - program::{ExecutionError, ProgramBus}, - }, -}; - -mod air; -mod bridge; -mod columns; -mod trace; - -pub use air::*; -use axvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; -pub use columns::*; - -#[cfg(test)] -pub mod tests; - -#[derive(Clone, Debug)] -pub struct UintMultiplicationRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - pub x_ptr_read: MemoryReadRecord, - pub y_ptr_read: MemoryReadRecord, - pub z_ptr_read: MemoryReadRecord, - pub x_read: MemoryReadRecord, - pub y_read: MemoryReadRecord, - pub z_write: MemoryWriteRecord, - pub carry: Vec, -} - -#[derive(Debug)] -pub struct UintMultiplicationChip { - pub air: UintMultiplicationCoreAir, - data: Vec>, - memory_controller: MemoryControllerRef, - pub range_tuple_chip: Arc>, - - offset: usize, -} - -impl - UintMultiplicationChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_controller: MemoryControllerRef, - range_tuple_chip: Arc>, - offset: usize, - ) -> Self { - assert!(LIMB_BITS < 16, "LIMB_BITS {} >= 16", LIMB_BITS); - - let bus = range_tuple_chip.bus(); - - assert_eq!(bus.sizes.len(), 2); - assert!( - bus.sizes[0] >= 1 << LIMB_BITS, - "bus.sizes[0] {} < 2^LIMB_BITS {}", - bus.sizes[0], - 1 << LIMB_BITS - ); - assert!( - bus.sizes[1] >= (NUM_LIMBS * (1 << LIMB_BITS)) as u32, - "bus.sizes[1] {} < (NUM_LIMBS * 2^LIMB_BITS) {}", - bus.sizes[1], - NUM_LIMBS * (1 << LIMB_BITS) - ); - - let memory_bridge = memory_controller.borrow().memory_bridge(); - Self { - air: UintMultiplicationCoreAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: *bus, - offset, - }, - data: vec![], - memory_controller, - range_tuple_chip, - offset, - } - } -} - -impl InstructionExecutor - for UintMultiplicationChip -{ - fn execute( - &mut self, - instruction: Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let Instruction { - opcode, - a, - b, - c, - d, - e, - .. - } = instruction; - let local_opcode_index = opcode - self.offset; - assert!(local_opcode_index == U256Opcode::MUL as usize); - - let mut memory_controller = self.memory_controller.borrow_mut(); - debug_assert_eq!(from_state.timestamp, memory_controller.timestamp()); - - let [z_ptr_read, x_ptr_read, y_ptr_read] = - [a, b, c].map(|ptr_of_ptr| memory_controller.read_cell(d, ptr_of_ptr)); - let x_read = memory_controller.read::(e, x_ptr_read.value()); - let y_read = memory_controller.read::(e, y_ptr_read.value()); - - let x = x_read.data.map(|x| x.as_canonical_u32()); - let y = y_read.data.map(|x| x.as_canonical_u32()); - let (z, carry) = run_uint_multiplication::(&x, &y); - - for (z_val, carry_val) in z.iter().zip(carry.iter()) { - self.range_tuple_chip.add_count(&[*z_val, *carry_val]); - } - - let z_write = memory_controller.write::( - e, - z_ptr_read.value(), - z.into_iter() - .map(T::from_canonical_u32) - .collect::>() - .try_into() - .unwrap(), - ); - - self.data.push(UintMultiplicationRecord { - from_state, - instruction: instruction.clone(), - x_ptr_read, - y_ptr_read, - z_ptr_read, - x_read, - y_read, - z_write, - carry: carry.into_iter().map(T::from_canonical_u32).collect(), - }); - - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory_controller.timestamp(), - }) - } - - fn get_opcode_name(&self, _: usize) -> String { - format!("{:?}<{NUM_LIMBS},{LIMB_BITS}>", U256Opcode::MUL) - } -} - -fn run_uint_multiplication( - x: &[u32], - y: &[u32], -) -> (Vec, Vec) { - let mut result = vec![0; NUM_LIMBS]; - let mut carry = vec![0; NUM_LIMBS]; - for i in 0..NUM_LIMBS { - if i > 0 { - result[i] = carry[i - 1]; - } - for j in 0..=i { - result[i] += x[j] * y[i - j]; - } - carry[i] = result[i] >> LIMB_BITS; - result[i] %= 1 << LIMB_BITS; - } - (result, carry) -} diff --git a/vm/src/old/uint_multiplication/tests.rs b/vm/src/old/uint_multiplication/tests.rs deleted file mode 100644 index 41b3f81f54..0000000000 --- a/vm/src/old/uint_multiplication/tests.rs +++ /dev/null @@ -1,246 +0,0 @@ -use std::{array, borrow::BorrowMut, iter, sync::Arc}; - -use ax_circuit_primitives::range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip}; -use ax_stark_backend::{utils::disable_debug_builder, verifier::VerificationError, Chip}; -use ax_stark_sdk::utils::create_seeded_rng; -use axvm_instructions::instruction::Instruction; -use p3_baby_bear::BabyBear; -use p3_field::AbstractField; -use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use rand::{rngs::StdRng, Rng}; - -use super::{columns::UintMultiplicationCols, run_uint_multiplication, UintMultiplicationChip}; -use crate::arch::{ - instructions::U256Opcode, - testing::{memory::gen_pointer, VmChipTestBuilder}, - RANGE_TUPLE_CHECKER_BUS, -}; - -type F = BabyBear; - -fn generate_uint_number( - rng: &mut StdRng, -) -> Vec { - (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..1 << LIMB_BITS)) - .collect() -} - -fn run_uint_multiplication_rand_write_execute( - tester: &mut VmChipTestBuilder, - chip: &mut UintMultiplicationChip, - x: Vec, - y: Vec, - rng: &mut StdRng, -) { - let address_space_range = || 1usize..=2; - - let d = rng.gen_range(address_space_range()); - let e = rng.gen_range(address_space_range()); - - let x_address = gen_pointer(rng, 64); - let y_address = gen_pointer(rng, 64); - let z_address = gen_pointer(rng, 64); - let x_ptr_to_address = gen_pointer(rng, 1); - let y_ptr_to_address = gen_pointer(rng, 1); - let z_ptr_to_address = gen_pointer(rng, 1); - - let x_f = x - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - let y_f = y - .clone() - .into_iter() - .map(F::from_canonical_u32) - .collect::>(); - - tester.write_cell(d, x_ptr_to_address, F::from_canonical_usize(x_address)); - tester.write_cell(d, y_ptr_to_address, F::from_canonical_usize(y_address)); - tester.write_cell(d, z_ptr_to_address, F::from_canonical_usize(z_address)); - tester.write::(e, x_address, x_f.as_slice().try_into().unwrap()); - tester.write::(e, y_address, y_f.as_slice().try_into().unwrap()); - - let (z, _) = run_uint_multiplication::(&x, &y); - tester.execute( - chip, - Instruction::from_usize( - U256Opcode::MUL as usize, - [z_ptr_to_address, x_ptr_to_address, y_ptr_to_address, d, e], - ), - ); - assert_eq!( - z.into_iter().map(F::from_canonical_u32).collect::>(), - tester.read::(e, z_address) - ); -} - -fn run_negative_uint_multiplication_test( - x: Vec, - y: Vec, - z: Vec, - carry: Vec, - expected_error: VerificationError, -) { - let bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << LIMB_BITS, (NUM_LIMBS * (1 << LIMB_BITS)) as u32], - ); - let range_tuple_chip: Arc> = Arc::new(RangeTupleCheckerChip::new(bus)); - - let mut tester = VmChipTestBuilder::default(); - let mut chip = UintMultiplicationChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - range_tuple_chip.clone(), - 0, - ); - - let mut rng = create_seeded_rng(); - run_uint_multiplication_rand_write_execute(&mut tester, &mut chip, x, y, &mut rng); - - let mut air_proof_input = chip.generate_air_proof_input(); - let mult_trace = air_proof_input.raw.common_main.as_mut().unwrap(); - let mut mult_trace_vec = mult_trace.row_slice(0).to_vec(); - let mult_trace_cols: &mut UintMultiplicationCols = - (*mult_trace_vec).borrow_mut(); - - mult_trace_cols.io.z.data = array::from_fn(|i| F::from_canonical_u32(z[i])); - mult_trace_cols.aux.carry = array::from_fn(|i| F::from_canonical_u32(carry[i])); - *mult_trace = RowMajorMatrix::new( - mult_trace_vec, - UintMultiplicationCols::::width(), - ); - - disable_debug_builder(); - let tester = tester - .build() - .load_air_proof_input(air_proof_input) - .load(range_tuple_chip) - .finalize(); - let msg = format!( - "Expected verification to fail with {:?}, but it didn't", - &expected_error - ); - let result = tester.simple_test(); - assert_eq!(result.err(), Some(expected_error), "{}", msg); -} - -#[test] -fn uint_multiplication_rand_air_test() { - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - let num_ops: usize = 10; - - let bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << LIMB_BITS, (NUM_LIMBS * (1 << LIMB_BITS)) as u32], - ); - let range_tuple_chip: Arc> = Arc::new(RangeTupleCheckerChip::new(bus)); - - let mut tester = VmChipTestBuilder::default(); - let mut chip = UintMultiplicationChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_controller(), - range_tuple_chip.clone(), - 0, - ); - - let mut rng = create_seeded_rng(); - - for _ in 0..num_ops { - let x = generate_uint_number::(&mut rng); - let y = generate_uint_number::(&mut rng); - run_uint_multiplication_rand_write_execute(&mut tester, &mut chip, x, y, &mut rng); - } - - let tester = tester.build().load(chip).load(range_tuple_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn negative_uint_multiplication_wrong_calc_test() { - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - // x = 00...0001 - // y = 00...0001 - // z = 00...0002 - // carry = 00...0000 - run_negative_uint_multiplication_test::( - iter::once(1).chain(iter::repeat(0).take(31)).collect(), - iter::once(1).chain(iter::repeat(0).take(31)).collect(), - iter::once(2).chain(iter::repeat(0).take(31)).collect(), - iter::repeat(0).take(32).collect(), - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn negative_uint_multiplication_wrong_carry_test() { - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - // x = 00...0001 - // y = 00...0001 - // z = 00...0001 - // carry = 00...0001 - run_negative_uint_multiplication_test::( - iter::once(1).chain(iter::repeat(0).take(31)).collect(), // 10000000... - iter::once(1).chain(iter::repeat(0).take(31)).collect(), // 10000000... - iter::once(1).chain(iter::repeat(0).take(31)).collect(), // 10000000... - iter::once(1).chain(iter::repeat(0).take(31)).collect(), - VerificationError::OodEvaluationMismatch, - ); -} - -#[test] -fn negative_uint_multiplication_out_of_range_z_test() { - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - // x = 00...000[2^8] (out of range) - // y = 00...0001 - // z = 00...000[2^8] (out of range) - // carry = 00...0000 - run_negative_uint_multiplication_test::( - iter::once(1 << LIMB_BITS) - .chain(iter::repeat(0).take(31)) - .collect(), - iter::once(1).chain(iter::repeat(0).take(31)).collect(), - iter::once(1 << LIMB_BITS) - .chain(iter::repeat(0).take(31)) - .collect(), - iter::repeat(0).take(32).collect(), - VerificationError::NonZeroCumulativeSum, - ); -} - -#[test] -fn negative_uint_multiplication_out_of_range_carry_test() { - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - // x = 00...000[2^12] (out of range) - // y = 00...000[2^12] (out of range) - // z = 00...1000 - // carry = 00...01[2^8][2^16] (out of range) - run_negative_uint_multiplication_test::( - iter::once(1 << (LIMB_BITS + (LIMB_BITS / 2))) - .chain(iter::repeat(0).take(31)) - .collect(), - iter::once(1 << (LIMB_BITS + (LIMB_BITS / 2))) - .chain(iter::repeat(0).take(31)) - .collect(), - iter::repeat(0) - .take(3) - .chain(iter::once(1)) - .chain(iter::repeat(0).take(28)) - .collect(), - iter::once(1 << (2 * LIMB_BITS)) - .chain(iter::once(1 << LIMB_BITS)) - .chain(iter::once(1)) - .chain(iter::repeat(0).take(29)) - .collect(), - VerificationError::NonZeroCumulativeSum, - ); -} diff --git a/vm/src/old/uint_multiplication/trace.rs b/vm/src/old/uint_multiplication/trace.rs deleted file mode 100644 index 71fe320929..0000000000 --- a/vm/src/old/uint_multiplication/trace.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; - -use ax_stark_backend::{ - config::{StarkGenericConfig, Val}, - prover::types::AirProofInput, - rap::{get_air_name, AnyRap}, - Chip, ChipUsageGetter, -}; -use p3_field::{AbstractField, PrimeField32}; -use p3_matrix::dense::RowMajorMatrix; - -use super::{ - columns::{ - MemoryData, UintMultiplicationAuxCols, UintMultiplicationCols, UintMultiplicationIoCols, - }, - UintMultiplicationChip, UintMultiplicationRecord, -}; - -impl Chip - for UintMultiplicationChip, NUM_LIMBS, LIMB_BITS> -where - Val: PrimeField32, -{ - fn air(&self) -> Arc> { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let air = self.air(); - let aux_cols_factory = self.memory_controller.borrow().aux_cols_factory(); - - let width = self.trace_width(); - let height = self.data.len(); - let padded_height = height.next_power_of_two(); - let mut rows = vec![Val::::zero(); width * padded_height]; - - for (row, operation) in rows.chunks_mut(width).zip(self.data) { - let UintMultiplicationRecord::, NUM_LIMBS, LIMB_BITS> { - from_state, - instruction, - x_ptr_read, - y_ptr_read, - z_ptr_read, - x_read, - y_read, - z_write, - carry, - } = operation; - - let row: &mut UintMultiplicationCols, NUM_LIMBS, LIMB_BITS> = row.borrow_mut(); - - row.io = UintMultiplicationIoCols { - from_state: from_state.map(Val::::from_canonical_u32), - x: MemoryData::, NUM_LIMBS, LIMB_BITS> { - data: x_read.data, - address: x_read.pointer, - ptr_to_address: x_ptr_read.pointer, - }, - y: MemoryData::, NUM_LIMBS, LIMB_BITS> { - data: y_read.data, - address: y_read.pointer, - ptr_to_address: y_ptr_read.pointer, - }, - z: MemoryData::, NUM_LIMBS, LIMB_BITS> { - data: z_write.data, - address: z_write.pointer, - ptr_to_address: z_ptr_read.pointer, - }, - ptr_as: instruction.d, - address_as: instruction.e, - }; - - row.aux = UintMultiplicationAuxCols { - is_valid: Val::::one(), - carry: array::from_fn(|i| carry[i]), - read_ptr_aux_cols: [z_ptr_read, x_ptr_read, y_ptr_read] - .map(|read| aux_cols_factory.make_read_aux_cols(read)), - read_x_aux_cols: aux_cols_factory.make_read_aux_cols(x_read), - read_y_aux_cols: aux_cols_factory.make_read_aux_cols(y_read), - write_z_aux_cols: aux_cols_factory.make_write_aux_cols(z_write), - }; - } - AirProofInput::simple_no_pis(air, RowMajorMatrix::new(rows, width)) - } -} - -impl ChipUsageGetter - for UintMultiplicationChip -{ - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.data.len() - } - - fn trace_width(&self) -> usize { - UintMultiplicationCols::::width() - } -} diff --git a/vm/src/rv32im/adapters/heap.rs b/vm/src/rv32im/adapters/heap.rs new file mode 100644 index 0000000000..0b57bffcc7 --- /dev/null +++ b/vm/src/rv32im/adapters/heap.rs @@ -0,0 +1,212 @@ +use std::{array::from_fn, borrow::Borrow, cell::RefCell, marker::PhantomData}; + +use ax_stark_backend::interaction::InteractionBuilder; +use axvm_instructions::instruction::Instruction; +use p3_air::BaseAir; +use p3_field::{Field, PrimeField32}; + +use super::{ + read_rv32_register, vec_heap_generate_trace_row_impl, Rv32VecHeapAdapterAir, + Rv32VecHeapAdapterCols, Rv32VecHeapReadRecord, Rv32VecHeapWriteRecord, +}; +use crate::{ + arch::{ + AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, + ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, + VmAdapterInterface, + }, + system::{ + memory::{ + offline_checker::MemoryBridge, MemoryAuxColsFactory, MemoryController, + MemoryControllerRef, + }, + program::ProgramBus, + }, +}; + +/// This adapter reads from NUM_READS <= 2 pointers and writes to 1 pointer. +/// * The data is read from the heap (address space 2), and the pointers +/// are read from registers (address space 1). +/// * Reads are from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). +/// * Writes are to the address in `rd`. + +#[derive(Clone, Copy, Debug, derive_new::new)] +pub struct Rv32HeapAdapterAir< + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pub(super) execution_bridge: ExecutionBridge, + pub(super) memory_bridge: MemoryBridge, + /// The max number of bits for an address in memory + address_bits: usize, +} + +impl BaseAir + for Rv32HeapAdapterAir +{ + fn width(&self) -> usize { + Rv32VecHeapAdapterCols::::width() + } +} + +impl< + AB: InteractionBuilder, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > VmAdapterAir for Rv32HeapAdapterAir +{ + type Interface = BasicAdapterInterface< + AB::Expr, + MinimalInstruction, + NUM_READS, + 1, + READ_SIZE, + WRITE_SIZE, + >; + + fn eval( + &self, + builder: &mut AB, + local: &[AB::Var], + ctx: AdapterAirContext, + ) { + let vec_heap_air: Rv32VecHeapAdapterAir = + Rv32VecHeapAdapterAir::new( + self.execution_bridge, + self.memory_bridge, + self.address_bits, + ); + vec_heap_air.eval(builder, local, ctx.into()); + } + + fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { + let cols: &Rv32VecHeapAdapterCols<_, NUM_READS, 1, 1, READ_SIZE, WRITE_SIZE> = + local.borrow(); + cols.from_state.pc + } +} + +#[derive(Debug)] +pub struct Rv32HeapAdapterChip< + F: Field, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pub air: Rv32HeapAdapterAir, + _marker: PhantomData, +} + +impl + Rv32HeapAdapterChip +{ + pub fn new( + execution_bus: ExecutionBus, + program_bus: ProgramBus, + memory_controller: MemoryControllerRef, + ) -> Self { + assert!(NUM_READS <= 2); + let memory_controller = RefCell::borrow(&memory_controller); + let memory_bridge = memory_controller.memory_bridge(); + let address_bits = memory_controller.mem_config.pointer_max_bits; + Self { + air: Rv32HeapAdapterAir { + execution_bridge: ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + address_bits, + }, + _marker: PhantomData, + } + } +} + +impl + VmAdapterChip for Rv32HeapAdapterChip +{ + type ReadRecord = Rv32VecHeapReadRecord; + type WriteRecord = Rv32VecHeapWriteRecord; + type Air = Rv32HeapAdapterAir; + type Interface = + BasicAdapterInterface, NUM_READS, 1, READ_SIZE, WRITE_SIZE>; + + fn preprocess( + &mut self, + memory: &mut MemoryController, + instruction: &Instruction, + ) -> Result<( + >::Reads, + Self::ReadRecord, + )> { + let Instruction { a, b, c, d, e, .. } = *instruction; + + debug_assert_eq!(d.as_canonical_u32(), 1); + debug_assert_eq!(e.as_canonical_u32(), 2); + + let mut rs_vals = [0; NUM_READS]; + let rs_records: [_; NUM_READS] = from_fn(|i| { + let addr = if i == 0 { b } else { c }; + let (record, val) = read_rv32_register(memory, d, addr); + rs_vals[i] = val; + record + }); + let (rd_record, rd_val) = read_rv32_register(memory, d, a); + + let read_records = rs_vals.map(|address| { + debug_assert!(address < (1 << self.air.address_bits)); + [memory.read::(e, F::from_canonical_u32(address))] + }); + let read_data = read_records.map(|r| r[0].data); + + let record = Rv32VecHeapReadRecord { + rs: rs_records, + rd: rd_record, + rd_val: F::from_canonical_u32(rd_val), + reads: read_records, + }; + + Ok((read_data, record)) + } + + fn postprocess( + &mut self, + memory: &mut MemoryController, + instruction: &Instruction, + from_state: ExecutionState, + output: AdapterRuntimeContext, + read_record: &Self::ReadRecord, + ) -> Result<(ExecutionState, Self::WriteRecord)> { + let e = instruction.e; + let writes = [memory.write(e, read_record.rd_val, output.writes[0])]; + + let timestamp_delta = memory.timestamp() - from_state.timestamp; + debug_assert!( + timestamp_delta == 6, + "timestamp delta is {}, expected 6", + timestamp_delta + ); + + Ok(( + ExecutionState { + pc: from_state.pc + 4, + timestamp: memory.timestamp(), + }, + Self::WriteRecord { from_state, writes }, + )) + } + + fn generate_trace_row( + &self, + row_slice: &mut [F], + read_record: Self::ReadRecord, + write_record: Self::WriteRecord, + aux_cols_factory: &MemoryAuxColsFactory, + ) { + vec_heap_generate_trace_row_impl(row_slice, &read_record, &write_record, aux_cols_factory); + } + + fn air(&self) -> &Self::Air { + &self.air + } +} diff --git a/vm/src/rv32im/adapters/mod.rs b/vm/src/rv32im/adapters/mod.rs index b8522fe525..a2e047e33f 100644 --- a/vm/src/rv32im/adapters/mod.rs +++ b/vm/src/rv32im/adapters/mod.rs @@ -1,5 +1,6 @@ mod alu; mod branch; +mod heap; mod hintstore; mod jalr; mod loadstore; @@ -10,6 +11,7 @@ mod vec_heap; pub use alu::*; pub use axvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; pub use branch::*; +pub use heap::*; pub use hintstore::*; pub use jalr::*; pub use loadstore::*; @@ -17,6 +19,9 @@ pub use mul::*; pub use rdwrite::*; pub use vec_heap::*; +/// 256-bit heap integer stored as 32 bytes (32 limbs of 8-bits) +pub const INT256_NUM_LIMBS: usize = 32; + // For soundness, should be <= 16 pub const RV_IS_TYPE_IMM_BITS: usize = 12; diff --git a/vm/src/rv32im/adapters/vec_heap.rs b/vm/src/rv32im/adapters/vec_heap.rs index f8ed12ff1b..1be3fad40a 100644 --- a/vm/src/rv32im/adapters/vec_heap.rs +++ b/vm/src/rv32im/adapters/vec_heap.rs @@ -32,38 +32,41 @@ use crate::{ /// This adapter reads from R (R <= 2) pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers /// are read from registers (address space 1). -/// * Reads take the form of `NUM_READS` consecutive reads of size `READ_SIZE` -/// from the heap, starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). -/// * Writes take the form of `NUM_WRITES` consecutive writes of size `WRITE_SIZE` -/// to the heap, starting from the address in `rd`. +/// * Reads take the form of `BLOCKS_PER_READ` consecutive reads of size +/// `READ_SIZE` from the heap, starting from the addresses in `rs[0]` +/// (and `rs[1]` if `R = 2`). +/// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of +/// size `WRITE_SIZE` to the heap, starting from the address in `rd`. #[derive(Debug)] pub struct Rv32VecHeapAdapterChip< F: Field, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > { - pub air: Rv32VecHeapAdapterAir, + pub air: + Rv32VecHeapAdapterAir, _marker: PhantomData, } impl< F: PrimeField32, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > Rv32VecHeapAdapterChip + > + Rv32VecHeapAdapterChip { pub fn new( execution_bus: ExecutionBus, program_bus: ProgramBus, memory_controller: MemoryControllerRef, ) -> Self { - assert!(R <= 2); + assert!(NUM_READS <= 2); let memory_controller = RefCell::borrow(&memory_controller); let memory_bridge = memory_controller.memory_bridge(); let address_bits = memory_controller.mem_config.pointer_max_bits; @@ -81,58 +84,59 @@ impl< #[derive(Clone, Debug)] pub struct Rv32VecHeapReadRecord< F: Field, - const R: usize, const NUM_READS: usize, + const BLOCKS_PER_READ: usize, const READ_SIZE: usize, > { /// Read register value from address space e=1 - pub rs: [MemoryReadRecord; R], + pub rs: [MemoryReadRecord; NUM_READS], /// Read register value from address space d=1 pub rd: MemoryReadRecord, pub rd_val: F, - pub reads: [[MemoryReadRecord; NUM_READS]; R], + pub reads: [[MemoryReadRecord; BLOCKS_PER_READ]; NUM_READS], } #[derive(Clone, Debug)] -pub struct Rv32VecHeapWriteRecord { +pub struct Rv32VecHeapWriteRecord +{ pub from_state: ExecutionState, - pub writes: [MemoryWriteRecord; NUM_WRITES], + pub writes: [MemoryWriteRecord; BLOCKS_PER_WRITE], } #[repr(C)] #[derive(AlignedBorrow)] pub struct Rv32VecHeapAdapterCols< T, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > { pub from_state: ExecutionState, + pub rs_ptr: [T; NUM_READS], pub rd_ptr: T, - pub rs_ptr: [T; R], + pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS], pub rd_val: [T; RV32_REGISTER_NUM_LIMBS], - pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; R], - pub rs_read_aux: [MemoryReadAuxCols; R], + pub rs_read_aux: [MemoryReadAuxCols; NUM_READS], pub rd_read_aux: MemoryReadAuxCols, - pub reads_aux: [[MemoryReadAuxCols; NUM_READS]; R], - pub writes_aux: [MemoryWriteAuxCols; NUM_WRITES], + pub reads_aux: [[MemoryReadAuxCols; BLOCKS_PER_READ]; NUM_READS], + pub writes_aux: [MemoryWriteAuxCols; BLOCKS_PER_WRITE], } #[allow(dead_code)] #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32VecHeapAdapterAir< - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > { @@ -144,29 +148,44 @@ pub struct Rv32VecHeapAdapterAir< impl< F: Field, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > BaseAir for Rv32VecHeapAdapterAir + > BaseAir + for Rv32VecHeapAdapterAir { fn width(&self) -> usize { - Rv32VecHeapAdapterCols::::width() + Rv32VecHeapAdapterCols::< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >::width() } } impl< AB: InteractionBuilder, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > VmAdapterAir for Rv32VecHeapAdapterAir + > VmAdapterAir + for Rv32VecHeapAdapterAir { - type Interface = - VecHeapAdapterInterface; + type Interface = VecHeapAdapterInterface< + AB::Expr, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >; fn eval( &self, @@ -174,8 +193,14 @@ impl< local: &[AB::Var], ctx: AdapterAirContext, ) { - let cols: &Rv32VecHeapAdapterCols<_, R, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE> = - local.borrow(); + let cols: &Rv32VecHeapAdapterCols< + _, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = local.borrow(); let timestamp = cols.from_state.timestamp; let mut timestamp_delta: usize = 0; let mut timestamp_pp = || { @@ -201,22 +226,15 @@ impl< // Compose the u32 register value into single field element, with // a range check on the highest limb. - let mut reg_val_f: Vec<_> = cols - .rs_val - .iter() - .chain(once(&cols.rd_val)) - .map(|&decomp| { - // TODO: range check - decomp - .into_iter() - .enumerate() - .fold(AB::Expr::zero(), |acc, (i, limb)| { - acc + limb * AB::Expr::from_canonical_usize(1 << (i * RV32_CELL_BITS)) - }) - }) - .collect(); - let rd_val_f = reg_val_f.pop().unwrap(); - let rs_val_f = reg_val_f; + let register_to_field = |r: [AB::Var; RV32_REGISTER_NUM_LIMBS]| { + r.into_iter() + .enumerate() + .fold(AB::Expr::zero(), |acc, (i, limb)| { + acc + limb * AB::Expr::from_canonical_usize(1 << (i * RV32_CELL_BITS)) + }) + }; + let rd_val_f = register_to_field(cols.rd_val); + let rs_val_f = cols.rs_val.map(register_to_field); let e = AB::F::from_canonical_usize(2); // Reads from heap @@ -275,26 +293,47 @@ impl< } fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { - let cols: &Rv32VecHeapAdapterCols<_, R, NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE> = - local.borrow(); + let cols: &Rv32VecHeapAdapterCols< + _, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = local.borrow(); cols.from_state.pc } } impl< F: PrimeField32, - const R: usize, const NUM_READS: usize, - const NUM_WRITES: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, > VmAdapterChip - for Rv32VecHeapAdapterChip + for Rv32VecHeapAdapterChip< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > { - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord; - type Air = Rv32VecHeapAdapterAir; - type Interface = VecHeapAdapterInterface; + type ReadRecord = Rv32VecHeapReadRecord; + type WriteRecord = Rv32VecHeapWriteRecord; + type Air = + Rv32VecHeapAdapterAir; + type Interface = VecHeapAdapterInterface< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >; fn preprocess( &mut self, @@ -309,8 +348,8 @@ impl< debug_assert_eq!(d.as_canonical_u32(), 1); debug_assert_eq!(e.as_canonical_u32(), 2); - let mut rs_vals = [0; R]; - let rs_records: [_; R] = from_fn(|i| { + let mut rs_vals = [0; NUM_READS]; + let rs_records: [_; NUM_READS] = from_fn(|i| { let addr = if i == 0 { b } else { c }; let (record, val) = read_rv32_register(memory, d, addr); rs_vals[i] = val; @@ -372,35 +411,51 @@ impl< write_record: Self::WriteRecord, aux_cols_factory: &MemoryAuxColsFactory, ) { - let row_slice: &mut Rv32VecHeapAdapterCols< - F, - R, - NUM_READS, - NUM_WRITES, - READ_SIZE, - WRITE_SIZE, - > = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - row_slice.rd_ptr = read_record.rd.pointer; - row_slice.rs_ptr = read_record.rs.map(|r| r.pointer); - - row_slice.rd_val = read_record.rd.data; - row_slice.rs_val = read_record.rs.map(|r| r.data); - - row_slice.rs_read_aux = read_record - .rs - .map(|r| aux_cols_factory.make_read_aux_cols(r)); - row_slice.rd_read_aux = aux_cols_factory.make_read_aux_cols(read_record.rd); - row_slice.reads_aux = read_record - .reads - .map(|r| r.map(|x| aux_cols_factory.make_read_aux_cols(x))); - row_slice.writes_aux = write_record - .writes - .map(|w| aux_cols_factory.make_write_aux_cols(w)); + vec_heap_generate_trace_row_impl(row_slice, &read_record, &write_record, aux_cols_factory) } fn air(&self) -> &Self::Air { &self.air } } + +pub(super) fn vec_heap_generate_trace_row_impl< + F: PrimeField32, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +>( + row_slice: &mut [F], + read_record: &Rv32VecHeapReadRecord, + write_record: &Rv32VecHeapWriteRecord, + aux_cols_factory: &MemoryAuxColsFactory, +) { + let row_slice: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = row_slice.borrow_mut(); + row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); + + row_slice.rd_ptr = read_record.rd.pointer; + row_slice.rs_ptr = read_record.rs.map(|r| r.pointer); + + row_slice.rd_val = read_record.rd.data; + row_slice.rs_val = read_record.rs.map(|r| r.data); + + row_slice.rs_read_aux = read_record + .rs + .map(|r| aux_cols_factory.make_read_aux_cols(r)); + row_slice.rd_read_aux = aux_cols_factory.make_read_aux_cols(read_record.rd); + row_slice.reads_aux = read_record + .reads + .map(|r| r.map(|x| aux_cols_factory.make_read_aux_cols(x))); + row_slice.writes_aux = write_record + .writes + .map(|w| aux_cols_factory.make_write_aux_cols(w)); +} diff --git a/vm/src/rv32im/base_alu/core.rs b/vm/src/rv32im/base_alu/core.rs index a115887cfc..9bc8d3b408 100644 --- a/vm/src/rv32im/base_alu/core.rs +++ b/vm/src/rv32im/base_alu/core.rs @@ -210,7 +210,7 @@ where let c = data[1].map(|y| y.as_canonical_u32()); let a = run_alu::(local_opcode_index, &b, &c); - let output: AdapterRuntimeContext = AdapterRuntimeContext { + let output = AdapterRuntimeContext { to_pc: None, writes: [a.map(F::from_canonical_u32)].into(), }; diff --git a/vm/src/rv32im/base_alu/tests.rs b/vm/src/rv32im/base_alu/tests.rs index 6c2352ea48..7bb92aa515 100644 --- a/vm/src/rv32im/base_alu/tests.rs +++ b/vm/src/rv32im/base_alu/tests.rs @@ -15,20 +15,22 @@ use p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }; -use rand::{rngs::StdRng, Rng}; +use rand::Rng; use super::{core::run_alu, BaseAluCoreChip, Rv32BaseAluChip}; use crate::{ arch::{ instructions::BaseAluOpcode, - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder}, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, BITWISE_OP_LOOKUP_BUS, + testing::{TestAdapterChip, VmChipTestBuilder}, + ExecutionBridge, VmAdapterChip, VmChipWrapper, BITWISE_OP_LOOKUP_BUS, }, rv32im::{ adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, base_alu::BaseAluCoreCols, }, - utils::{generate_long_number, generate_rv32_is_type_immediate}, + utils::{ + generate_long_number, generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm, + }, }; type F = BabyBear; @@ -40,42 +42,6 @@ type F = BabyBear; /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_rv32_alu_rand_write_execute>( - tester: &mut VmChipTestBuilder, - chip: &mut E, - opcode: BaseAluOpcode, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - c_imm: Option, - rng: &mut StdRng, -) { - let is_imm = c_imm.is_some(); - - let rs1 = gen_pointer(rng, 4); - let rs2 = c_imm.unwrap_or_else(|| gen_pointer(rng, 4)); - let rd = gen_pointer(rng, 4); - - tester.write::(1, rs1, b.map(F::from_canonical_u32)); - if !is_imm { - tester.write::(1, rs2, c.map(F::from_canonical_u32)); - } - - let a = run_alu::(opcode, &b, &c); - tester.execute( - chip, - Instruction::from_usize( - opcode as usize, - [rd, rs1, rs2, 1, if is_imm { 0 } else { 1 }], - ), - ); - - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) - ); -} - fn run_rv32_alu_rand_test(opcode: BaseAluOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); @@ -105,7 +71,14 @@ fn run_rv32_alu_rand_test(opcode: BaseAluOpcode, num_ops: usize) { let (imm, c) = generate_rv32_is_type_immediate(&mut rng); (Some(imm), c) }; - run_rv32_alu_rand_write_execute(&mut tester, &mut chip, opcode, b, c, c_imm, &mut rng); + + let (instruction, rd) = + rv32_rand_write_register_or_imm(&mut tester, b, c, c_imm, opcode as usize, &mut rng); + tester.execute(&mut chip, instruction); + + let a = run_alu::(opcode, &b, &c) + .map(F::from_canonical_u32); + assert_eq!(a, tester.read::(1, rd)) } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); diff --git a/vm/src/rv32im/less_than/tests.rs b/vm/src/rv32im/less_than/tests.rs index 5274c39442..7647910b30 100644 --- a/vm/src/rv32im/less_than/tests.rs +++ b/vm/src/rv32im/less_than/tests.rs @@ -15,20 +15,23 @@ use p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }; -use rand::{rngs::StdRng, Rng}; +use rand::Rng; use super::{core::run_less_than, LessThanCoreChip, Rv32LessThanChip}; use crate::{ arch::{ instructions::LessThanOpcode, - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder}, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, BITWISE_OP_LOOKUP_BUS, + testing::{TestAdapterChip, VmChipTestBuilder}, + ExecutionBridge, VmAdapterChip, VmChipWrapper, BITWISE_OP_LOOKUP_BUS, }, rv32im::{ adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, less_than::LessThanCoreCols, }, - utils::{generate_long_number, generate_rv32_is_type_immediate, i32_to_f}, + utils::{ + generate_long_number, generate_rv32_is_type_immediate, i32_to_f, + rv32_rand_write_register_or_imm, + }, }; type F = BabyBear; @@ -40,44 +43,6 @@ type F = BabyBear; /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_rv32_lt_rand_write_execute>( - tester: &mut VmChipTestBuilder, - chip: &mut E, - opcode: LessThanOpcode, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - c_imm: Option, - rng: &mut StdRng, -) { - let is_imm = c_imm.is_some(); - - let rs1 = gen_pointer(rng, 4); - let rs2 = c_imm.unwrap_or_else(|| gen_pointer(rng, 4)); - let rd = gen_pointer(rng, 4); - - tester.write::(1, rs1, b.map(F::from_canonical_u32)); - if !is_imm { - tester.write::(1, rs2, c.map(F::from_canonical_u32)); - } - - let (cmp, _, _, _) = run_less_than::(opcode, &b, &c); - tester.execute( - chip, - Instruction::from_usize( - opcode as usize, - [rd, rs1, rs2, 1, if is_imm { 0 } else { 1 }], - ), - ); - let mut a = [0; RV32_REGISTER_NUM_LIMBS]; - a[0] = cmp as u32; - - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) - ); -} - fn run_rv32_lt_rand_test(opcode: LessThanOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); @@ -107,28 +72,28 @@ fn run_rv32_lt_rand_test(opcode: LessThanOpcode, num_ops: usize) { let (imm, c) = generate_rv32_is_type_immediate(&mut rng); (Some(imm), c) }; - run_rv32_lt_rand_write_execute(&mut tester, &mut chip, opcode, b, c, c_imm, &mut rng); + + let (instruction, rd) = + rv32_rand_write_register_or_imm(&mut tester, b, c, c_imm, opcode as usize, &mut rng); + tester.execute(&mut chip, instruction); + + let (cmp, _, _, _) = + run_less_than::(opcode, &b, &c); + let mut a = [F::zero(); RV32_REGISTER_NUM_LIMBS]; + a[0] = F::from_bool(cmp); + assert_eq!(a, tester.read::(1, rd)); } // Test special case where b = c - run_rv32_lt_rand_write_execute( - &mut tester, - &mut chip, - opcode, - [101, 128, 202, 255], - [101, 128, 202, 255], - None, - &mut rng, - ); - run_rv32_lt_rand_write_execute( - &mut tester, - &mut chip, - opcode, - [36, 0, 0, 0], - [36, 0, 0, 0], - Some(36), - &mut rng, - ); + let b = [101, 128, 202, 255]; + let (instruction, _) = + rv32_rand_write_register_or_imm(&mut tester, b, b, None, opcode as usize, &mut rng); + tester.execute(&mut chip, instruction); + + let b = [36, 0, 0, 0]; + let (instruction, _) = + rv32_rand_write_register_or_imm(&mut tester, b, b, Some(36), opcode as usize, &mut rng); + tester.execute(&mut chip, instruction); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); diff --git a/vm/src/rv32im/mul/tests.rs b/vm/src/rv32im/mul/tests.rs index fe3f25f2a5..3b4d1f1673 100644 --- a/vm/src/rv32im/mul/tests.rs +++ b/vm/src/rv32im/mul/tests.rs @@ -13,20 +13,18 @@ use p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }; -use rand::rngs::StdRng; use super::core::run_mul; use crate::{ arch::{ - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder}, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, - RANGE_TUPLE_CHECKER_BUS, + testing::{TestAdapterChip, VmChipTestBuilder}, + ExecutionBridge, VmAdapterChip, VmChipWrapper, RANGE_TUPLE_CHECKER_BUS, }, rv32im::{ adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, mul::{MultiplicationCoreChip, MultiplicationCoreCols, Rv32MultiplicationChip}, }, - utils::generate_long_number, + utils::{generate_long_number, rv32_rand_write_register_or_imm}, }; type F = BabyBear; @@ -38,33 +36,6 @@ type F = BabyBear; /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_rv32_mul_rand_write_execute>( - tester: &mut VmChipTestBuilder, - chip: &mut E, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - rng: &mut StdRng, -) { - let rs1 = gen_pointer(rng, 4); - let rs2 = gen_pointer(rng, 4); - let rd = gen_pointer(rng, 4); - - tester.write::(1, rs1, b.map(F::from_canonical_u32)); - tester.write::(1, rs2, c.map(F::from_canonical_u32)); - - let (a, _) = run_mul::(&b, &c); - tester.execute( - chip, - Instruction::from_usize(MulOpcode::MUL as usize, [rd, rs1, rs2, 1, 0]), - ); - - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) - ); -} - fn run_rv32_mul_rand_test(num_ops: usize) { // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) const MAX_NUM_LIMBS: u32 = 32; @@ -90,7 +61,23 @@ fn run_rv32_mul_rand_test(num_ops: usize) { for _ in 0..num_ops { let b = generate_long_number::(&mut rng); let c = generate_long_number::(&mut rng); - run_rv32_mul_rand_write_execute(&mut tester, &mut chip, b, c, &mut rng); + + let (mut instruction, rd) = rv32_rand_write_register_or_imm( + &mut tester, + b, + c, + None, + MulOpcode::MUL as usize, + &mut rng, + ); + instruction.e = F::zero(); + tester.execute(&mut chip, instruction); + + let (a, _) = run_mul::(&b, &c); + assert_eq!( + a.map(F::from_canonical_u32), + tester.read::(1, rd) + ) } let tester = tester @@ -103,7 +90,7 @@ fn run_rv32_mul_rand_test(num_ops: usize) { #[test] fn rv32_mul_rand_test() { - run_rv32_mul_rand_test(100); + run_rv32_mul_rand_test(1); } /////////////////////////////////////////////////////////////////////////////////////// diff --git a/vm/src/rv32im/shift/tests.rs b/vm/src/rv32im/shift/tests.rs index 85d1dd7120..6631b4d005 100644 --- a/vm/src/rv32im/shift/tests.rs +++ b/vm/src/rv32im/shift/tests.rs @@ -15,20 +15,22 @@ use p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }; -use rand::{rngs::StdRng, Rng}; +use rand::Rng; use super::{core::run_shift, Rv32ShiftChip, ShiftCoreChip}; use crate::{ arch::{ instructions::ShiftOpcode, - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder}, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, BITWISE_OP_LOOKUP_BUS, + testing::{TestAdapterChip, VmChipTestBuilder}, + ExecutionBridge, VmAdapterChip, VmChipWrapper, BITWISE_OP_LOOKUP_BUS, }, rv32im::{ adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, shift::ShiftCoreCols, }, - utils::{generate_long_number, generate_rv32_is_type_immediate}, + utils::{ + generate_long_number, generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm, + }, }; type F = BabyBear; @@ -40,45 +42,8 @@ type F = BabyBear; /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_rv32_shift_rand_write_execute>( - tester: &mut VmChipTestBuilder, - chip: &mut E, - opcode: ShiftOpcode, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - c_imm: Option, - rng: &mut StdRng, -) { - let is_imm = c_imm.is_some(); - - let rs1 = gen_pointer(rng, 4); - let rs2 = c_imm.unwrap_or_else(|| gen_pointer(rng, 4)); - let rd = gen_pointer(rng, 4); - - tester.write::(1, rs1, b.map(F::from_canonical_u32)); - if !is_imm { - tester.write::(1, rs2, c.map(F::from_canonical_u32)); - } - - let (a, _, _) = run_shift::(opcode, &b, &c); - tester.execute( - chip, - Instruction::from_usize( - opcode as usize, - [rd, rs1, rs2, 1, if is_imm { 0 } else { 1 }], - ), - ); - - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) - ); -} - fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( bitwise_bus, @@ -110,7 +75,16 @@ fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { let (imm, c) = generate_rv32_is_type_immediate(&mut rng); (Some(imm), c) }; - run_rv32_shift_rand_write_execute(&mut tester, &mut chip, opcode, b, c, c_imm, &mut rng); + + let (instruction, rd) = + rv32_rand_write_register_or_imm(&mut tester, b, c, c_imm, opcode as usize, &mut rng); + tester.execute(&mut chip, instruction); + + let (a, _, _) = run_shift::(opcode, &b, &c); + assert_eq!( + a.map(F::from_canonical_u32), + tester.read::(1, rd) + ) } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); diff --git a/vm/src/system/program/util.rs b/vm/src/system/program/util.rs index 9a756a0233..e447501bda 100644 --- a/vm/src/system/program/util.rs +++ b/vm/src/system/program/util.rs @@ -28,7 +28,10 @@ pub fn execute_program(program: Program, input_stream: Vec( ) } +// Returns (instruction, rd) +pub fn rv32_rand_write_register_or_imm( + tester: &mut VmChipTestBuilder, + rs1_writes: [u32; NUM_LIMBS], + rs2_writes: [u32; NUM_LIMBS], + imm: Option, + opcode_with_offset: usize, + rng: &mut StdRng, +) -> (Instruction, usize) { + let rs2_is_imm = imm.is_some(); + + let rs1 = gen_pointer(rng, NUM_LIMBS); + let rs2 = imm.unwrap_or_else(|| gen_pointer(rng, NUM_LIMBS)); + let rd = gen_pointer(rng, NUM_LIMBS); + + tester.write::(1, rs1, rs1_writes.map(BabyBear::from_canonical_u32)); + if !rs2_is_imm { + tester.write::(1, rs2, rs2_writes.map(BabyBear::from_canonical_u32)); + } + + ( + Instruction::from_usize( + opcode_with_offset, + [rd, rs1, rs2, 1, if rs2_is_imm { 0 } else { 1 }], + ), + rd, + ) +} + pub fn generate_long_number( rng: &mut StdRng, ) -> [u32; NUM_LIMBS] {