diff --git a/triton-constraint-circuit/src/lib.rs b/triton-constraint-circuit/src/lib.rs index 2db4f5b6..1af633fa 100644 --- a/triton-constraint-circuit/src/lib.rs +++ b/triton-constraint-circuit/src/lib.rs @@ -1104,7 +1104,8 @@ mod tests { use std::hash::Hasher; use itertools::Itertools; - use ndarray::{Array2, Axis}; + use ndarray::Array2; + use ndarray::Axis; use proptest::arbitrary::Arbitrary; use proptest::collection::vec; use proptest::prelude::*; @@ -1774,32 +1775,52 @@ mod tests { } } - for operation in operations { - match operation { - CircuitOperationChoice::Add(lhs, rhs) => { - if !all_nodes.is_empty() { - let lhs_index = lhs % all_nodes.len(); - let rhs_index = rhs % all_nodes.len(); - - let lhs_node = all_nodes[lhs_index].clone(); - let rhs_node = all_nodes[rhs_index].clone(); - - let node = lhs_node + rhs_node; - all_nodes.push(node); + if !all_nodes.is_empty() { + for operation in operations { + let mut i = 0; + let mut j = 0; + let new_node = loop { + let new_node = match operation { + CircuitOperationChoice::Add(lhs, rhs) => { + let lhs_index = (lhs + i) % all_nodes.len(); + let rhs_index = (rhs + j) % all_nodes.len(); + + let lhs_node = all_nodes[lhs_index].clone(); + let rhs_node = all_nodes[rhs_index].clone(); + + lhs_node + rhs_node + } + CircuitOperationChoice::Mul(lhs, rhs) => { + let lhs_index = (lhs + i) % all_nodes.len(); + let rhs_index = (rhs + j) % all_nodes.len(); + + let lhs_node = all_nodes[lhs_index].clone(); + let rhs_node = all_nodes[rhs_index].clone(); + + lhs_node * rhs_node + } + }; + + if matches!( + new_node.circuit.borrow().expression, + CircuitExpression::BinOp(_, _, _) + ) { + break new_node; } - } - CircuitOperationChoice::Mul(lhs, rhs) => { - if !all_nodes.is_empty() { - let lhs_index = lhs % all_nodes.len(); - let rhs_index = rhs % all_nodes.len(); - let lhs_node = all_nodes[lhs_index].clone(); - let rhs_node = all_nodes[rhs_index].clone(); + j += 1; + if j == all_nodes.len() { + i += 1; - let node = lhs_node * rhs_node; - all_nodes.push(node); + assert_ne!( + all_nodes.len(), + i, + "Must be able to construct binop from available nodes" + ); } - } + }; + + all_nodes.push(new_node); } } @@ -1832,8 +1853,11 @@ mod tests { ); prop_assert_eq!(10, ConstraintCircuitMonad::num_inputs(&multicircuit_monad)); - prop_assert_eq!(10, ConstraintCircuitMonad::num_bfield_constants(&multicircuit_monad) - + ConstraintCircuitMonad::num_xfield_constants(&multicircuit_monad)); + prop_assert_eq!( + 10, + ConstraintCircuitMonad::num_bfield_constants(&multicircuit_monad) + + ConstraintCircuitMonad::num_xfield_constants(&multicircuit_monad) + ); } /// Test the completeness and soundness of the `apply_substitution` function, @@ -1842,7 +1866,7 @@ mod tests { /// In this context, completeness means: #[proptest] fn node_substitution_is_complete_and_sound( - #[strategy(arbitrary_circuit_monad(5, 5, 0, 20, 5))] mut multicircuit_monad: Vec< + #[strategy(arbitrary_circuit_monad(10, 10, 10, 100, 10))] mut multicircuit_monad: Vec< ConstraintCircuitMonad, >, #[strategy(vec(arb::(), ConstraintCircuitMonad::num_main_inputs(&#multicircuit_monad)))] @@ -1872,14 +1896,33 @@ mod tests { // apply one step of degree-lowering let num_nodes = multicircuit_monad[0].builder.all_nodes.borrow().len(); - let substitution_node_id = multicircuit_monad[0] + let mut substitution_node_index = substitution_node_index; + + // Find a node to substitute, and ensure that this node is a binop + // node, as a constant node might be folded into another constant + // later on. + let (mut substitution_node_id, mut node_for_substitution) = multicircuit_monad[0] .builder .all_nodes .borrow() .iter() .nth(substitution_node_index % num_nodes) - .map(|(i, _n)| *i) + .map(|(i, n)| (*i, n.clone())) .unwrap(); + while !matches!( + &node_for_substitution.circuit.borrow().expression, + CircuitExpression::BinOp(_, _, _), + ) { + (substitution_node_id, node_for_substitution) = multicircuit_monad[0] + .builder + .all_nodes + .borrow() + .iter() + .nth(substitution_node_index % num_nodes) + .map(|(i, n)| (*i, n.clone())) + .unwrap(); + substitution_node_index += 1; + } let degree_lowering_info = DegreeLoweringInfo { target_degree: 2,