diff --git a/vm/src/intrinsics/ecc/fp2/muldiv.rs b/vm/src/intrinsics/ecc/fp2/muldiv.rs index 4bd24894d0..131e71698b 100644 --- a/vm/src/intrinsics/ecc/fp2/muldiv.rs +++ b/vm/src/intrinsics/ecc/fp2/muldiv.rs @@ -3,7 +3,7 @@ use std::{cell::RefCell, rc::Rc}; use ax_circuit_derive::{Chip, ChipUsageGetter}; use ax_circuit_primitives::var_range::VariableRangeCheckerBus; use ax_ecc_primitives::{ - field_expression::{ExprBuilder, ExprBuilderConfig, FieldExpr}, + field_expression::{ExprBuilder, ExprBuilderConfig, FieldExpr, SymbolicExpr}, field_extension::Fp2, }; use axvm_circuit_derive::InstructionExecutor; @@ -58,14 +58,9 @@ pub fn fp2_muldiv_expr( let builder = ExprBuilder::new(config, range_bus.range_max_bits); let builder = Rc::new(RefCell::new(builder)); - let mut x = Fp2::new(builder.clone()); + let x = Fp2::new(builder.clone()); let mut y = Fp2::new(builder.clone()); - let flag = builder.borrow_mut().new_flag(); - // Order matters: we must do compute first as this triggers autosave and introduces new variables. - // These intermediate variables have to be computed before are result variables. - let fp2_compute = Fp2::select(flag, &x.mul(&mut y), &x.div(&mut y)); - let (z_idx, mut z) = Fp2::new_var(builder.clone()); let mut lvar = Fp2::select(flag, &x, &z); let mut rvar = Fp2::select(flag, &z, &x); @@ -78,12 +73,19 @@ pub fn fp2_muldiv_expr( builder .borrow_mut() .set_constraint(z_idx.1, fp2_constraint.c1.expr); - builder - .borrow_mut() - .set_compute(z_idx.0, fp2_compute.c0.expr); - builder - .borrow_mut() - .set_compute(z_idx.1, fp2_compute.c1.expr); + + // Compute expression has to be done manually at the SymbolicExpr level. + // Otherwise it saves the quotient and introduces new variables. + let compute_z0_div = (&x.c0.expr * &y.c0.expr + &x.c1.expr * &y.c1.expr) + / (&y.c0.expr * &y.c0.expr + &y.c1.expr * &y.c1.expr); + let compute_z0_mul = &x.c0.expr * &y.c0.expr - &x.c1.expr * &y.c1.expr; + let compute_z0 = SymbolicExpr::Select(flag, Box::new(compute_z0_mul), Box::new(compute_z0_div)); + let compute_z1_div = (&x.c1.expr * &y.c0.expr - &x.c0.expr * &y.c1.expr) + / (&y.c0.expr * &y.c0.expr + &y.c1.expr * &y.c1.expr); + let compute_z1_mul = &x.c1.expr * &y.c0.expr + &x.c0.expr * &y.c1.expr; + let compute_z1 = SymbolicExpr::Select(flag, Box::new(compute_z1_mul), Box::new(compute_z1_div)); + builder.borrow_mut().set_compute(z_idx.0, compute_z0); + builder.borrow_mut().set_compute(z_idx.1, compute_z1); let builder = builder.borrow().clone(); (FieldExpr::new(builder, range_bus), flag)