Skip to content

Commit

Permalink
wip: Fix arb circuit test/constructor to account for constant-folding
Browse files Browse the repository at this point in the history
The result of `let c = b_constant(5) + b_constant(4);` is not a `BinOp`
expression but rather a constant expression, as the constant folding
happens at the time of node-insertion and not at a later stage. So we
need two "do-while" loops to account for this possibility, one in the
"arbitrary" implementation and one in the proptest.
  • Loading branch information
Sword-Smith committed Sep 17, 2024
1 parent e117c84 commit 63001bd
Showing 1 changed file with 71 additions and 28 deletions.
99 changes: 71 additions & 28 deletions triton-constraint-circuit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,8 @@ mod tests {
use std::hash::Hasher;

use itertools::Itertools;
use ndarray::{Array2, Axis};
use ndarray::Array2;
use ndarray::Axis;
use proptest::arbitrary::Arbitrary;
use proptest::collection::vec;
use proptest::prelude::*;
Expand Down Expand Up @@ -1774,32 +1775,52 @@ mod tests {
}
}

for operation in operations {
match operation {
CircuitOperationChoice::Add(lhs, rhs) => {
if !all_nodes.is_empty() {
let lhs_index = lhs % all_nodes.len();
let rhs_index = rhs % all_nodes.len();

let lhs_node = all_nodes[lhs_index].clone();
let rhs_node = all_nodes[rhs_index].clone();

let node = lhs_node + rhs_node;
all_nodes.push(node);
if !all_nodes.is_empty() {
for operation in operations {
let mut i = 0;
let mut j = 0;
let new_node = loop {
let new_node = match operation {
CircuitOperationChoice::Add(lhs, rhs) => {
let lhs_index = (lhs + i) % all_nodes.len();
let rhs_index = (rhs + j) % all_nodes.len();

let lhs_node = all_nodes[lhs_index].clone();
let rhs_node = all_nodes[rhs_index].clone();

lhs_node + rhs_node
}
CircuitOperationChoice::Mul(lhs, rhs) => {
let lhs_index = (lhs + i) % all_nodes.len();
let rhs_index = (rhs + j) % all_nodes.len();

let lhs_node = all_nodes[lhs_index].clone();
let rhs_node = all_nodes[rhs_index].clone();

lhs_node * rhs_node
}
};

if matches!(
new_node.circuit.borrow().expression,
CircuitExpression::BinOp(_, _, _)
) {
break new_node;
}
}
CircuitOperationChoice::Mul(lhs, rhs) => {
if !all_nodes.is_empty() {
let lhs_index = lhs % all_nodes.len();
let rhs_index = rhs % all_nodes.len();

let lhs_node = all_nodes[lhs_index].clone();
let rhs_node = all_nodes[rhs_index].clone();
j += 1;
if j == all_nodes.len() {
i += 1;

let node = lhs_node * rhs_node;
all_nodes.push(node);
assert_ne!(
all_nodes.len(),
i,
"Must be able to construct binop from available nodes"
);
}
}
};

all_nodes.push(new_node);
}
}

Expand Down Expand Up @@ -1832,8 +1853,11 @@ mod tests {
);

prop_assert_eq!(10, ConstraintCircuitMonad::num_inputs(&multicircuit_monad));
prop_assert_eq!(10, ConstraintCircuitMonad::num_bfield_constants(&multicircuit_monad)
+ ConstraintCircuitMonad::num_xfield_constants(&multicircuit_monad));
prop_assert_eq!(
10,
ConstraintCircuitMonad::num_bfield_constants(&multicircuit_monad)
+ ConstraintCircuitMonad::num_xfield_constants(&multicircuit_monad)
);
}

/// Test the completeness and soundness of the `apply_substitution` function,
Expand All @@ -1842,7 +1866,7 @@ mod tests {
/// In this context, completeness means:
#[proptest]
fn node_substitution_is_complete_and_sound(
#[strategy(arbitrary_circuit_monad(5, 5, 0, 20, 5))] mut multicircuit_monad: Vec<
#[strategy(arbitrary_circuit_monad(10, 10, 10, 100, 10))] mut multicircuit_monad: Vec<
ConstraintCircuitMonad<SingleRowIndicator>,
>,
#[strategy(vec(arb::<BFieldElement>(), ConstraintCircuitMonad::num_main_inputs(&#multicircuit_monad)))]
Expand Down Expand Up @@ -1872,14 +1896,33 @@ mod tests {

// apply one step of degree-lowering
let num_nodes = multicircuit_monad[0].builder.all_nodes.borrow().len();
let substitution_node_id = multicircuit_monad[0]
let mut substitution_node_index = substitution_node_index;

// Find a node to substitute, and ensure that this node is a binop
// node, as a constant node might be folded into another constant
// later on.
let (mut substitution_node_id, mut node_for_substitution) = multicircuit_monad[0]
.builder
.all_nodes
.borrow()
.iter()
.nth(substitution_node_index % num_nodes)
.map(|(i, _n)| *i)
.map(|(i, n)| (*i, n.clone()))
.unwrap();
while !matches!(
&node_for_substitution.circuit.borrow().expression,
CircuitExpression::BinOp(_, _, _),
) {
(substitution_node_id, node_for_substitution) = multicircuit_monad[0]
.builder
.all_nodes
.borrow()
.iter()
.nth(substitution_node_index % num_nodes)
.map(|(i, n)| (*i, n.clone()))
.unwrap();
substitution_node_index += 1;
}

let degree_lowering_info = DegreeLoweringInfo {
target_degree: 2,
Expand Down

0 comments on commit 63001bd

Please sign in to comment.