diff --git a/compiler/src/asm/compiler.rs b/compiler/src/asm/compiler.rs index c37f66052f..9cde4f03b1 100644 --- a/compiler/src/asm/compiler.rs +++ b/compiler/src/asm/compiler.rs @@ -349,21 +349,21 @@ impl + TwoAdicField> AsmCo debug_info, ); } - DslIr::AddU256(dst, lhs, rhs) => { + DslIr::Add256(dst, lhs, rhs) => { self.push( - AsmInstruction::AddU256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), + AsmInstruction::Add256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), debug_info, ); } - DslIr::SubU256(dst, lhs, rhs) => { + DslIr::Sub256(dst, lhs, rhs) => { self.push( - AsmInstruction::SubU256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), + AsmInstruction::Sub256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), debug_info, ); } - DslIr::MulU256(dst, lhs, rhs) => { + DslIr::Mul256(dst, lhs, rhs) => { self.push( - AsmInstruction::MulU256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), + AsmInstruction::Mul256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), debug_info, ); } @@ -373,9 +373,59 @@ impl + TwoAdicField> AsmCo debug_info, ); } - DslIr::EqualToU256(dst, lhs, rhs) => { + DslIr::EqualTo256(dst, lhs, rhs) => { self.push( - AsmInstruction::EqualToU256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()), + AsmInstruction::EqualTo256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()), + debug_info, + ); + } + DslIr::Xor256(dst, lhs, rhs) => { + self.push( + AsmInstruction::Xor256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), + debug_info, + ); + } + DslIr::And256(dst, lhs, rhs) => { + self.push( + AsmInstruction::And256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), + debug_info, + ); + } + DslIr::Or256(dst, lhs, rhs) => { + self.push( + AsmInstruction::Or256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), + debug_info, + ); + } + DslIr::LessThanI256(dst, lhs, rhs) => { + self.push( + AsmInstruction::LessThanI256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()), + debug_info, + ); + } + DslIr::ShiftLeft256(dst, lhs, rhs) => { + self.push( + AsmInstruction::ShiftLeft256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()), + debug_info, + ); + } + DslIr::ShiftRightLogic256(dst, lhs, rhs) => { + self.push( + AsmInstruction::ShiftRightLogic256( + dst.ptr_fp(), + lhs.ptr_fp(), + rhs.ptr_fp(), + ), + debug_info, + ); + } + DslIr::ShiftRightArith256(dst, lhs, rhs) => { + self.push( + AsmInstruction::ShiftRightArith256( + dst.ptr_fp(), + lhs.ptr_fp(), + rhs.ptr_fp(), + ), debug_info, ); } diff --git a/compiler/src/asm/instruction.rs b/compiler/src/asm/instruction.rs index be488407ab..011d467407 100644 --- a/compiler/src/asm/instruction.rs +++ b/compiler/src/asm/instruction.rs @@ -96,20 +96,41 @@ pub enum AsmInstruction { /// Modular divide, dst = lhs / rhs. DivSecp256k1Scalar(i32, i32, i32), - /// uint add, dst = lhs + rhs. - AddU256(i32, i32, i32), + /// int add, dst = lhs + rhs. + Add256(i32, i32, i32), - /// uint subtract, dst = lhs - rhs. - SubU256(i32, i32, i32), + /// int subtract, dst = lhs - rhs. + Sub256(i32, i32, i32), - /// uint multiply, dst = lhs * rhs. - MulU256(i32, i32, i32), + /// int multiply, dst = lhs * rhs. + Mul256(i32, i32, i32), /// uint less than, dst = lhs < rhs. LessThanU256(i32, i32, i32), - /// uint equal to, dst = lhs == rhs. - EqualToU256(i32, i32, i32), + /// int equal to, dst = lhs == rhs. + EqualTo256(i32, i32, i32), + + /// int bitwise XOR, dst = lhs ^ rhs + Xor256(i32, i32, i32), + + /// int bitwise AND, dst = lhs & rhs + And256(i32, i32, i32), + + /// int bitwise OR, dst = lhs | rhs + Or256(i32, i32, i32), + + /// signed int less than, dst = lhs < rhs + LessThanI256(i32, i32, i32), + + /// int shift left, dst = lhs << rhs + ShiftLeft256(i32, i32, i32), + + /// int shift right logical, dst = lhs >> rhs + ShiftRightLogic256(i32, i32, i32), + + /// int shift right arithmetic, dst = lhs >> rhs + ShiftRightArith256(i32, i32, i32), /// Jump. Jump(i32, F), @@ -473,20 +494,41 @@ impl> AsmInstruction { dst, src1, src2 ) } - AsmInstruction::AddU256(dst, src1, src2) => { - write!(f, "add_u256 ({})fp ({})fp ({})fp", dst, src1, src2) + AsmInstruction::Add256(dst, src1, src2) => { + write!(f, "add_256 ({})fp ({})fp ({})fp", dst, src1, src2) } - AsmInstruction::SubU256(dst, src1, src2) => { - write!(f, "sub_u256 ({})fp ({})fp ({})fp", dst, src1, src2) + AsmInstruction::Sub256(dst, src1, src2) => { + write!(f, "sub_256 ({})fp ({})fp ({})fp", dst, src1, src2) } - AsmInstruction::MulU256(dst, src1, src2) => { - write!(f, "mul_u256 ({})fp ({})fp ({})fp", dst, src1, src2) + AsmInstruction::Mul256(dst, src1, src2) => { + write!(f, "mul_256 ({})fp ({})fp ({})fp", dst, src1, src2) } AsmInstruction::LessThanU256(dst, src1, src2) => { - write!(f, "lt_u256 ({})fp ({})fp ({})fp", dst, src1, src2) + write!(f, "sltu_256 ({})fp ({})fp ({})fp", dst, src1, src2) + } + AsmInstruction::EqualTo256(dst, src1, src2) => { + write!(f, "eq_256 ({})fp ({})fp ({})fp", dst, src1, src2) + } + AsmInstruction::Xor256(dst, src1, src2) => { + write!(f, "xor_256 ({})fp ({})fp ({})fp", dst, src1, src2) + } + AsmInstruction::And256(dst, src1, src2) => { + write!(f, "and_256 ({})fp ({})fp ({})fp", dst, src1, src2) + } + AsmInstruction::Or256(dst, src1, src2) => { + write!(f, "or_256 ({})fp ({})fp ({})fp", dst, src1, src2) + } + AsmInstruction::LessThanI256(dst, src1, src2) => { + write!(f, "slt_256 ({})fp ({})fp ({})fp", dst, src1, src2) + } + AsmInstruction::ShiftLeft256(dst, src1, src2) => { + write!(f, "sll_256 ({})fp ({})fp ({})fp", dst, src1, src2) + } + AsmInstruction::ShiftRightLogic256(dst, src1, src2) => { + write!(f, "srl_256 ({})fp ({})fp ({})fp", dst, src1, src2) } - AsmInstruction::EqualToU256(dst, src1, src2) => { - write!(f, "eq_u256 ({})fp ({})fp ({})fp", dst, src1, src2) + AsmInstruction::ShiftRightArith256(dst, src1, src2) => { + write!(f, "sra_256 ({})fp ({})fp ({})fp", dst, src1, src2) } } } diff --git a/compiler/src/conversion/mod.rs b/compiler/src/conversion/mod.rs index c048a98afc..bc33d4e9b6 100644 --- a/compiler/src/conversion/mod.rs +++ b/compiler/src/conversion/mod.rs @@ -759,27 +759,23 @@ fn convert_instruction>( AS::Memory, AS::Memory, )], - AsmInstruction::AddU256(dst, src1, src2) => vec![inst_large( + AsmInstruction::Add256(dst, src1, src2) => vec![inst( ADD256, i32_f(dst), i32_f(src1), i32_f(src2), AS::Memory, AS::Memory, - AS::Memory.to_field(), - AS::Memory.to_field(), )], - AsmInstruction::SubU256(dst, src1, src2) => vec![inst_large( + AsmInstruction::Sub256(dst, src1, src2) => vec![inst( SUB256, i32_f(dst), i32_f(src1), i32_f(src2), AS::Memory, AS::Memory, - AS::Memory.to_field(), - AS::Memory.to_field(), )], - AsmInstruction::MulU256(dst, src1, src2) => vec![inst( + AsmInstruction::Mul256(dst, src1, src2) => vec![inst( MUL256, i32_f(dst), i32_f(src1), @@ -787,25 +783,77 @@ fn convert_instruction>( AS::Memory, AS::Memory, )], - AsmInstruction::LessThanU256(dst, src1, src2) => vec![inst_large( - LT256, + AsmInstruction::LessThanU256(dst, src1, src2) => vec![inst( + SLTU256, i32_f(dst), i32_f(src1), i32_f(src2), AS::Memory, AS::Memory, - AS::Memory.to_field(), - AS::Memory.to_field(), )], - AsmInstruction::EqualToU256(dst, src1, src2) => vec![inst_large( + AsmInstruction::EqualTo256(dst, src1, src2) => vec![inst( EQ256, i32_f(dst), i32_f(src1), i32_f(src2), AS::Memory, AS::Memory, - AS::Memory.to_field(), - AS::Memory.to_field(), + )], + AsmInstruction::Xor256(dst, src1, src2) => vec![inst( + XOR256, + i32_f(dst), + i32_f(src1), + i32_f(src2), + AS::Memory, + AS::Memory, + )], + AsmInstruction::And256(dst, src1, src2) => vec![inst( + AND256, + i32_f(dst), + i32_f(src1), + i32_f(src2), + AS::Memory, + AS::Memory, + )], + AsmInstruction::Or256(dst, src1, src2) => vec![inst( + OR256, + i32_f(dst), + i32_f(src1), + i32_f(src2), + AS::Memory, + AS::Memory, + )], + AsmInstruction::LessThanI256(dst, src1, src2) => vec![inst( + SLT256, + i32_f(dst), + i32_f(src1), + i32_f(src2), + AS::Memory, + AS::Memory, + )], + AsmInstruction::ShiftLeft256(dst, src1, src2) => vec![inst( + SLL256, + i32_f(dst), + i32_f(src1), + i32_f(src2), + AS::Memory, + AS::Memory, + )], + AsmInstruction::ShiftRightLogic256(dst, src1, src2) => vec![inst( + SRL256, + i32_f(dst), + i32_f(src1), + i32_f(src2), + AS::Memory, + AS::Memory, + )], + AsmInstruction::ShiftRightArith256(dst, src1, src2) => vec![inst( + SRA256, + i32_f(dst), + i32_f(src1), + i32_f(src2), + AS::Memory, + AS::Memory, )], AsmInstruction::Keccak256(dst, src, len) => vec![inst_med( KECCAK256, diff --git a/compiler/src/ir/instructions.rs b/compiler/src/ir/instructions.rs index e9da35e0d6..a5b52af20f 100644 --- a/compiler/src/ir/instructions.rs +++ b/compiler/src/ir/instructions.rs @@ -38,8 +38,8 @@ pub enum DslIr { AddSecp256k1Coord(BigUintVar, BigUintVar, BigUintVar), /// Add two modular BigInts over scalar field. AddSecp256k1Scalar(BigUintVar, BigUintVar, BigUintVar), - /// Add two u256 - AddU256(BigUintVar, BigUintVar, BigUintVar), + /// Add two 256-bit integers + Add256(BigUintVar, BigUintVar, BigUintVar), // Subtractions. /// Subtracts two variables (var = var - var). @@ -68,8 +68,8 @@ pub enum DslIr { SubSecp256k1Coord(BigUintVar, BigUintVar, BigUintVar), /// Subtracts two modular BigInts over scalar field. SubSecp256k1Scalar(BigUintVar, BigUintVar, BigUintVar), - /// Subtract two u256 - SubU256(BigUintVar, BigUintVar, BigUintVar), + /// Subtract two 256-bit integers + Sub256(BigUintVar, BigUintVar, BigUintVar), // Multiplications. /// Multiplies two variables (var = var * var). @@ -92,8 +92,8 @@ pub enum DslIr { MulSecp256k1Coord(BigUintVar, BigUintVar, BigUintVar), /// Multiplies two modular BigInts over scalar field. MulSecp256k1Scalar(BigUintVar, BigUintVar, BigUintVar), - /// Multiply two u256 - MulU256(BigUintVar, BigUintVar, BigUintVar), + /// Multiply two 256-bit integers + Mul256(BigUintVar, BigUintVar, BigUintVar), // Divisions. /// Divides two variables (var = var / var). @@ -132,8 +132,27 @@ pub enum DslIr { LessThanVI(Var, Var, C::N), /// Compare two u256 for < LessThanU256(Ptr, BigUintVar, BigUintVar), - /// Compare two u256 for == - EqualToU256(Ptr, BigUintVar, BigUintVar), + /// Compare two 256-bit integers for == + EqualTo256(Ptr, BigUintVar, BigUintVar), + /// Compare two signed 256-bit integers for < + LessThanI256(Ptr, BigUintVar, BigUintVar), + + // Bitwise operations. + /// Bitwise XOR on two 256-bit integers + Xor256(BigUintVar, BigUintVar, BigUintVar), + /// Bitwise AND on two 256-bit integers + And256(BigUintVar, BigUintVar, BigUintVar), + /// Bitwise OR on two 256-bit integers + Or256(BigUintVar, BigUintVar, BigUintVar), + + // Shifts. + /// Shift left on 256-bit integers + ShiftLeft256(BigUintVar, BigUintVar, BigUintVar), + /// Shift right logical on 256-bit integers + ShiftRightLogic256(BigUintVar, BigUintVar, BigUintVar), + /// Shift right arithmetic on 256-bit integers + ShiftRightArith256(BigUintVar, BigUintVar, BigUintVar), + // ======= // Control flow. diff --git a/compiler/src/ir/modular_arithmetic.rs b/compiler/src/ir/modular_arithmetic.rs index 8c15632338..9859b1e155 100644 --- a/compiler/src/ir/modular_arithmetic.rs +++ b/compiler/src/ir/modular_arithmetic.rs @@ -96,14 +96,14 @@ where // FIXME: reuse constant zero. let big_zero = self.eval_biguint(BigUint::zero()); self.operations - .push(DslIr::EqualToU256(ret_arr.ptr(), biguint.clone(), big_zero)); + .push(DslIr::EqualTo256(ret_arr.ptr(), biguint.clone(), big_zero)); let ret: Var<_> = self.get(&ret_arr, 0); self.if_ne(ret, C::N::one()).then(|builder| { // FIXME: reuse constant. let big_n = builder.eval_biguint(SECP256K1_COORD_PRIME.clone()); builder .operations - .push(DslIr::EqualToU256(ret_arr.ptr(), biguint.clone(), big_n)); + .push(DslIr::EqualTo256(ret_arr.ptr(), biguint.clone(), big_n)); let _ret: Var<_> = builder.get(&ret_arr, 0); builder.assign(&ret, _ret); }); @@ -173,14 +173,14 @@ where // FIXME: reuse constant zero. let big_zero = self.eval_biguint(BigUint::zero()); self.operations - .push(DslIr::EqualToU256(ret_arr.ptr(), biguint.clone(), big_zero)); + .push(DslIr::EqualTo256(ret_arr.ptr(), biguint.clone(), big_zero)); let ret: Var<_> = self.get(&ret_arr, 0); self.if_ne(ret, C::N::one()).then(|builder| { // FIXME: reuse constant. let big_n = builder.eval_biguint(SECP256K1_SCALAR_PRIME.clone()); builder .operations - .push(DslIr::EqualToU256(ret_arr.ptr(), biguint.clone(), big_n)); + .push(DslIr::EqualTo256(ret_arr.ptr(), biguint.clone(), big_n)); let _ret: Var<_> = builder.get(&ret_arr, 0); builder.assign(&ret, _ret); }); diff --git a/compiler/src/ir/uint.rs b/compiler/src/ir/uint.rs index 4e0114b326..9f9e1e4c64 100644 --- a/compiler/src/ir/uint.rs +++ b/compiler/src/ir/uint.rs @@ -7,39 +7,97 @@ impl Builder where C::N: PrimeField64, { - pub fn u256_add(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + pub fn add_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { let dst = self.dyn_array(NUM_LIMBS); self.operations - .push(DslIr::AddU256(dst.clone(), left.clone(), right.clone())); + .push(DslIr::Add256(dst.clone(), left.clone(), right.clone())); dst } - pub fn u256_sub(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + pub fn sub_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { let dst = self.dyn_array(NUM_LIMBS); self.operations - .push(DslIr::SubU256(dst.clone(), left.clone(), right.clone())); + .push(DslIr::Sub256(dst.clone(), left.clone(), right.clone())); dst } - pub fn u256_mul(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + pub fn mul_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { let dst = self.dyn_array(NUM_LIMBS); self.operations - .push(DslIr::MulU256(dst.clone(), left.clone(), right.clone())); + .push(DslIr::Mul256(dst.clone(), left.clone(), right.clone())); dst } - pub fn u256_lt(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { + pub fn sltu_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { let dst = self.array(1); self.operations .push(DslIr::LessThanU256(dst.ptr(), left.clone(), right.clone())); self.get(&dst, 0) } - pub fn u256_eq(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { + pub fn eq_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { // let dst = self.alloc(1, as MemVariable>::size_of()); let dst = self.array(1); self.operations - .push(DslIr::EqualToU256(dst.ptr(), left.clone(), right.clone())); + .push(DslIr::EqualTo256(dst.ptr(), left.clone(), right.clone())); self.get(&dst, 0) } + + pub fn xor_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + let dst = self.dyn_array(NUM_LIMBS); + self.operations + .push(DslIr::Xor256(dst.clone(), left.clone(), right.clone())); + dst + } + + pub fn and_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + let dst = self.dyn_array(NUM_LIMBS); + self.operations + .push(DslIr::And256(dst.clone(), left.clone(), right.clone())); + dst + } + + pub fn or_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + let dst = self.dyn_array(NUM_LIMBS); + self.operations + .push(DslIr::Or256(dst.clone(), left.clone(), right.clone())); + dst + } + + pub fn slt_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> Var { + let dst = self.array(1); + self.operations + .push(DslIr::LessThanI256(dst.ptr(), left.clone(), right.clone())); + self.get(&dst, 0) + } + + pub fn sll_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + let dst = self.dyn_array(NUM_LIMBS); + self.operations.push(DslIr::ShiftLeft256( + dst.clone(), + left.clone(), + right.clone(), + )); + dst + } + + pub fn srl_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + let dst = self.dyn_array(NUM_LIMBS); + self.operations.push(DslIr::ShiftRightLogic256( + dst.clone(), + left.clone(), + right.clone(), + )); + dst + } + + pub fn sra_256(&mut self, left: &BigUintVar, right: &BigUintVar) -> BigUintVar { + let dst = self.dyn_array(NUM_LIMBS); + self.operations.push(DslIr::ShiftRightArith256( + dst.clone(), + left.clone(), + right.clone(), + )); + dst + } } diff --git a/compiler/tests/uint.rs b/compiler/tests/uint.rs index 075215aa40..de15477b83 100644 --- a/compiler/tests/uint.rs +++ b/compiler/tests/uint.rs @@ -1,17 +1,21 @@ +use std::iter; + use afs_compiler::{ asm::AsmBuilder, + conversion::CompilerOptions, ir::Var, util::{execute_program, execute_program_with_config}, }; use ax_sdk::utils::create_seeded_rng; use num_bigint_dig::BigUint; +use num_traits::Zero; use p3_baby_bear::BabyBear; use p3_field::{extension::BinomialExtensionField, AbstractField}; use rand::{Rng, RngCore}; use stark_vm::vm::config::VmConfig; #[test] -fn test_compiler_u256_add_sub() { +fn test_compiler_256_add_sub() { let num_digits = 8; let num_ops = 15; let mut rng = create_seeded_rng(); @@ -39,9 +43,9 @@ fn test_compiler_u256_add_sub() { }; let c_var = if add_flag { - builder.u256_add(&a_var, &b_var) + builder.add_256(&a_var, &b_var) } else { - builder.u256_sub(&a_var, &b_var) + builder.sub_256(&a_var, &b_var) }; let c_check_var = builder.eval_biguint(c); builder.assert_var_array_eq(&c_var, &c_check_var); @@ -53,7 +57,7 @@ fn test_compiler_u256_add_sub() { } #[test] -fn test_compiler_u256_mul() { +fn test_compiler_256_mul() { let num_digits = 8; let num_ops = 10; let mut rng = create_seeded_rng(); @@ -73,7 +77,7 @@ fn test_compiler_u256_mul() { let b_var = builder.eval_biguint(b.clone()); let c = (a.clone() * b.clone()) % u256_modulus.clone(); - let c_var = builder.u256_mul(&a_var, &b_var); + let c_var = builder.mul_256(&a_var, &b_var); let c_check_var = builder.eval_biguint(c); builder.assert_var_array_eq(&c_var, &c_check_var); @@ -95,9 +99,9 @@ fn test_compiler_u256_mul() { } #[test] -fn test_compiler_u256_lt_eq() { +fn test_compiler_256_sltu_eq() { let num_digits = 8; - let num_ops = 1; + let num_ops = 15; let mut rng = create_seeded_rng(); type F = BabyBear; @@ -126,9 +130,9 @@ fn test_compiler_u256_lt_eq() { }; let c_var = if lt_flag { - builder.u256_lt(&a_var, &b_var) + builder.sltu_256(&a_var, &b_var) } else { - builder.u256_eq(&a_var, &b_var) + builder.eq_256(&a_var, &b_var) }; let c_check_var: Var<_> = builder.eval(F::from_bool(c)); @@ -136,6 +140,214 @@ fn test_compiler_u256_lt_eq() { } builder.halt(); + let program = builder.clone().compile_isa_with_options(CompilerOptions { + word_size: 32, + ..Default::default() + }); + execute_program(program, vec![]); +} + +#[test] +fn test_compiler_256_slt_eq() { + let num_digits = 8; + let num_ops = 15; + let mut rng = create_seeded_rng(); + + type F = BabyBear; + type EF = BinomialExtensionField; + let mut builder = AsmBuilder::::default(); + let msb_mask: u32 = 1 << 31; + + for _ in 0..num_ops { + let a_digits: Vec = (0..num_digits).map(|_| rng.next_u32()).collect(); + let a = BigUint::new(a_digits.clone()); + let b_digits: Vec = (0..num_digits).map(|_| rng.next_u32()).collect(); + let b = BigUint::new(b_digits.clone()); + + let a_var = builder.eval_biguint(a.clone()); + let b_var = builder.eval_biguint(b.clone()); + + let same_sign = + (a_digits[num_digits - 1] & msb_mask) == (b_digits[num_digits - 1] & msb_mask); + + let c = if same_sign { + a.clone() < b.clone() + } else { + a_digits[num_digits - 1] & msb_mask == msb_mask + }; + + let c_var = builder.slt_256(&a_var, &b_var); + let c_check_var: Var<_> = builder.eval(F::from_bool(c)); + builder.assert_var_eq(c_var, c_check_var); + } + builder.halt(); + + let program = builder.clone().compile_isa_with_options(CompilerOptions { + word_size: 32, + ..Default::default() + }); + execute_program(program, vec![]); +} + +#[test] +fn test_compiler_256_xor_and_or() { + let num_digits = 8; + let num_ops = 20; + let mut rng = create_seeded_rng(); + + type F = BabyBear; + type EF = BinomialExtensionField; + let mut builder = AsmBuilder::::default(); + + for _ in 0..num_ops { + let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); + let a = BigUint::new(a_digits); + let b_digits = (0..num_digits).map(|_| rng.next_u32()).collect(); + let b = BigUint::new(b_digits); + + let a_var = builder.eval_biguint(a.clone()); + let b_var = builder.eval_biguint(b.clone()); + + // xor = 0, and = 1, or = 2 + let flag: u8 = rng.gen_range(0..=2); + + let c = if flag == 0 { + a.clone() ^ b.clone() + } else if flag == 1 { + a.clone() & b.clone() + } else { + a.clone() | b.clone() + }; + + let c_var = if flag == 0 { + builder.xor_256(&a_var, &b_var) + } else if flag == 1 { + builder.and_256(&a_var, &b_var) + } else { + builder.or_256(&a_var, &b_var) + }; + let c_check_var = builder.eval_biguint(c); + builder.assert_var_array_eq(&c_var, &c_check_var); + } + builder.halt(); + let program = builder.clone().compile_isa(); execute_program(program, vec![]); } + +#[test] +fn test_compiler_256_sll_srl() { + let num_digits = 8; + let num_ops = 15; + let mut rng = create_seeded_rng(); + + type F = BabyBear; + type EF = BinomialExtensionField; + let mut builder = AsmBuilder::::default(); + + for _ in 0..num_ops { + let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect::>(); + let a = BigUint::new(a_digits.clone()); + + let b_shift = rng.gen_range(0..=64); + let b_digits = iter::once(b_shift as u32) + .chain(iter::repeat(0u32)) + .take(num_digits) + .collect::>(); + let b = BigUint::new(b_digits); + + let a_var = builder.eval_biguint(a.clone()); + let b_var = builder.eval_biguint(b.clone()); + + // sll = 0, srl = 1 + let sll_flag = rng.gen_bool(0.5); + + let c = if sll_flag { + a.clone() << b_shift + } else { + a.clone() >> b_shift + }; + + let c_var = if sll_flag { + builder.sll_256(&a_var, &b_var) + } else { + builder.srl_256(&a_var, &b_var) + }; + let c_check_var = builder.eval_biguint(c); + builder.assert_var_array_eq(&c_var, &c_check_var); + } + builder.halt(); + + let program = builder.clone().compile_isa(); + execute_program_with_config( + VmConfig { + num_public_values: 4, + max_segment_len: (1 << 25) - 100, + shift_256_enabled: true, + bigint_limb_size: 8, + ..Default::default() + }, + program, + vec![], + ); +} + +#[test] +fn test_compiler_256_sra() { + let num_digits = 8; + let num_ops = 10; + let mut rng = create_seeded_rng(); + + type F = BabyBear; + type EF = BinomialExtensionField; + let mut builder = AsmBuilder::::default(); + let msb_mask: u32 = 1 << 31; + + for _ in 0..num_ops { + let a_digits = (0..num_digits).map(|_| rng.next_u32()).collect::>(); + let a_sign = a_digits[num_digits - 1] & msb_mask == msb_mask; + let a = BigUint::new(a_digits.clone()); + + let b_shift = rng.gen_range(0..=256); + let b_digits = iter::once(b_shift as u32) + .chain(iter::repeat(0u32)) + .take(num_digits) + .collect::>(); + let b = BigUint::new(b_digits); + + let a_var = builder.eval_biguint(a.clone()); + let b_var = builder.eval_biguint(b.clone()); + + let ones = iter::repeat(0) + .take((256 - b_shift) / 32) + .chain(iter::once(u32::MAX << (32 - (b_shift % 32)))) + .chain(iter::repeat(u32::MAX)) + .take(num_digits) + .collect::>(); + + let c = (a.clone() >> b_shift) + + if a_sign { + BigUint::new(ones) + } else { + BigUint::zero() + }; + + let c_var = builder.sra_256(&a_var, &b_var); + let c_check_var = builder.eval_biguint(c); + builder.assert_var_array_eq(&c_var, &c_check_var); + } + builder.halt(); + + let program = builder.clone().compile_isa(); + execute_program_with_config( + VmConfig { + num_public_values: 4, + max_segment_len: (1 << 25) - 100, + shift_256_enabled: true, + bigint_limb_size: 8, + ..Default::default() + }, + program, + vec![], + ); +} diff --git a/vm/src/alu/air.rs b/vm/src/alu/air.rs index d935b6bbe0..cae05b0e36 100644 --- a/vm/src/alu/air.rs +++ b/vm/src/alu/air.rs @@ -52,7 +52,7 @@ impl Air let flags = [ aux.opcode_add_flag, aux.opcode_sub_flag, - aux.opcode_lt_flag, + aux.opcode_sltu_flag, aux.opcode_eq_flag, aux.opcode_xor_flag, aux.opcode_and_flag, @@ -104,7 +104,7 @@ impl Air AB::Expr::zero() }); builder - .when(aux.opcode_sub_flag + aux.opcode_lt_flag + aux.opcode_slt_flag) + .when(aux.opcode_sub_flag + aux.opcode_sltu_flag + aux.opcode_slt_flag) .assert_bool(carry_sub[i].clone()); } @@ -121,7 +121,7 @@ impl Air .assert_zero(aux.y_sign); let slt_xor = - (aux.opcode_lt_flag + aux.opcode_slt_flag) * io.cmp_result + aux.x_sign + aux.y_sign + (aux.opcode_sltu_flag + aux.opcode_slt_flag) * io.cmp_result + aux.x_sign + aux.y_sign - AB::Expr::from_canonical_u32(2) * (io.cmp_result * aux.x_sign + io.cmp_result * aux.y_sign @@ -129,7 +129,7 @@ impl Air + AB::Expr::from_canonical_u32(4) * (io.cmp_result * aux.x_sign * aux.y_sign); builder.assert_eq( slt_xor, - (aux.opcode_lt_flag + aux.opcode_slt_flag) * carry_sub[NUM_LIMBS - 1].clone(), + (aux.opcode_sltu_flag + aux.opcode_slt_flag) * carry_sub[NUM_LIMBS - 1].clone(), ); // For EQ, z is filled with 0 except at the lowest index i such that x[i] != y[i]. If diff --git a/vm/src/alu/bridge.rs b/vm/src/alu/bridge.rs index 813570e3ab..9d8f98d65d 100644 --- a/vm/src/alu/bridge.rs +++ b/vm/src/alu/bridge.rs @@ -24,7 +24,7 @@ impl ArithmeticLogicAir ArithmeticLogicAir( ) -> (Vec, bool) { match opcode { Opcode::ADD256 => solve_add::(x, y), - Opcode::SUB256 | Opcode::LT256 => solve_subtract::(x, y), + Opcode::SUB256 | Opcode::SLTU256 => solve_subtract::(x, y), Opcode::EQ256 => solve_eq::(x, y), Opcode::XOR256 => solve_xor::(x, y), Opcode::AND256 => solve_and::(x, y), diff --git a/vm/src/alu/tests.rs b/vm/src/alu/tests.rs index dc4abe5080..62b6dbee0f 100644 --- a/vm/src/alu/tests.rs +++ b/vm/src/alu/tests.rs @@ -281,7 +281,7 @@ fn alu_sub_wrong_negative_test() { } #[test] -fn alu_lt_rand_test() { +fn alu_sltu_rand_test() { let num_ops: usize = 10; let mut rng = create_seeded_rng(); @@ -297,7 +297,7 @@ fn alu_lt_rand_test() { for _ in 0..num_ops { let x = generate_long_number::(&mut rng); let y = generate_long_number::(&mut rng); - run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::LT256, x, y, &mut rng); + run_alu_rand_write_execute(&mut tester, &mut chip, Opcode::SLTU256, x, y, &mut rng); } let tester = tester.build().load(chip).load(xor_lookup_chip).finalize(); @@ -305,9 +305,9 @@ fn alu_lt_rand_test() { } #[test] -fn alu_lt_wrong_subtraction_test() { +fn alu_sltu_wrong_subtraction_test() { run_alu_negative_test( - Opcode::LT256, + Opcode::SLTU256, iter::once(65_000).chain(iter::repeat(0).take(31)).collect(), iter::once(65_000).chain(iter::repeat(0).take(31)).collect(), std::iter::once(1).chain(iter::repeat(0).take(31)).collect(), @@ -319,9 +319,9 @@ fn alu_lt_wrong_subtraction_test() { } #[test] -fn alu_lt_wrong_negative_test() { +fn alu_sltu_wrong_negative_test() { run_alu_negative_test( - Opcode::LT256, + Opcode::SLTU256, iter::once(1).chain(iter::repeat(0).take(31)).collect(), iter::once(1).chain(iter::repeat(0).take(31)).collect(), iter::once(0).chain(iter::repeat(0).take(31)).collect(), @@ -333,9 +333,9 @@ fn alu_lt_wrong_negative_test() { } #[test] -fn alu_lt_non_zero_sign_negative_test() { +fn alu_sltu_non_zero_sign_negative_test() { run_alu_negative_test( - Opcode::LT256, + Opcode::SLTU256, vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], vec![(1 << LIMB_BITS) - 1; NUM_LIMBS], vec![0; NUM_LIMBS], diff --git a/vm/src/alu/trace.rs b/vm/src/alu/trace.rs index e86720c3d5..f4ba528265 100644 --- a/vm/src/alu/trace.rs +++ b/vm/src/alu/trace.rs @@ -82,7 +82,7 @@ impl MachineChi y_sign, opcode_add_flag: F::from_bool(instruction.opcode == Opcode::ADD256), opcode_sub_flag: F::from_bool(instruction.opcode == Opcode::SUB256), - opcode_lt_flag: F::from_bool(instruction.opcode == Opcode::LT256), + opcode_sltu_flag: F::from_bool(instruction.opcode == Opcode::SLTU256), opcode_eq_flag: F::from_bool(instruction.opcode == Opcode::EQ256), opcode_xor_flag: F::from_bool(instruction.opcode == Opcode::XOR256), opcode_and_flag: F::from_bool(instruction.opcode == Opcode::AND256), diff --git a/vm/src/arch/instructions.rs b/vm/src/arch/instructions.rs index b2a580e1e4..2f5fcfe195 100644 --- a/vm/src/arch/instructions.rs +++ b/vm/src/arch/instructions.rs @@ -64,7 +64,7 @@ pub enum Opcode { ADD256 = 80, SUB256 = 81, MUL256 = 82, - LT256 = 83, + SLTU256 = 83, EQ256 = 84, XOR256 = 85, AND256 = 86, @@ -95,8 +95,9 @@ pub const CORE_INSTRUCTIONS: [Opcode; 17] = [ ]; pub const FIELD_ARITHMETIC_INSTRUCTIONS: [Opcode; 4] = [FADD, FSUB, FMUL, FDIV]; pub const FIELD_EXTENSION_INSTRUCTIONS: [Opcode; 4] = [FE4ADD, FE4SUB, BBE4MUL, BBE4DIV]; -pub const ALU_256_INSTRUCTIONS: [Opcode; 8] = - [ADD256, SUB256, LT256, EQ256, XOR256, AND256, OR256, SLT256]; +pub const ALU_256_INSTRUCTIONS: [Opcode; 8] = [ + ADD256, SUB256, SLTU256, EQ256, XOR256, AND256, OR256, SLT256, +]; pub const SHIFT_256_INSTRUCTIONS: [Opcode; 3] = [SLL256, SRL256, SRA256]; pub const UI_32_INSTRUCTIONS: [Opcode; 2] = [LUI, AUIPC];