From e117c84e88187fd64aea40e3f1da043e34716e2c Mon Sep 17 00:00:00 2001 From: Alan Szepieniec Date: Tue, 17 Sep 2024 17:48:54 +0200 Subject: [PATCH] wip: Add proptest for compleness and soundness of `apply_substitution` --- triton-constraint-circuit/src/lib.rs | 339 +++++++++++++++++++++------ triton-vm/src/table/master_table.rs | 8 +- 2 files changed, 267 insertions(+), 80 deletions(-) diff --git a/triton-constraint-circuit/src/lib.rs b/triton-constraint-circuit/src/lib.rs index 09ee75eb..2db4f5b6 100644 --- a/triton-constraint-circuit/src/lib.rs +++ b/triton-constraint-circuit/src/lib.rs @@ -784,8 +784,8 @@ impl ConstraintCircuitMonad { multicircuit: &mut [Self], info: DegreeLoweringInfo, chosen_node_id: usize, - main_constraints_count: usize, - aux_constraints_count: usize, + new_main_constraints_count: usize, + new_aux_constraints_count: usize, ) -> ConstraintCircuitMonad { let builder = multicircuit[0].builder.clone(); @@ -793,10 +793,10 @@ impl ConstraintCircuitMonad { let chosen_node = builder.all_nodes.borrow()[&chosen_node_id].clone(); let chosen_node_is_main_col = chosen_node.circuit.borrow().evaluates_to_base_element(); let new_input_indicator = if chosen_node_is_main_col { - let new_main_col_idx = info.num_main_cols + main_constraints_count; + let new_main_col_idx = info.num_main_cols + new_main_constraints_count; II::main_table_input(new_main_col_idx) } else { - let new_aux_col_idx = info.num_aux_cols + aux_constraints_count; + let new_aux_col_idx = info.num_aux_cols + new_aux_constraints_count; II::aux_table_input(new_aux_col_idx) }; let new_variable = builder.input(new_input_indicator); @@ -904,7 +904,7 @@ impl ConstraintCircuitMonad { /// Counts the number of nodes in this multicircuit. Only counts nodes that /// are used; not nodes that have been forgotten. - pub fn num_nodes(constraints: &[Self]) -> usize { + pub fn num_visible_nodes(constraints: &[Self]) -> usize { constraints .iter() .flat_map(|ccm| Self::all_nodes_in_circuit(&ccm.circuit.borrow())) @@ -1104,6 +1104,7 @@ mod tests { use std::hash::Hasher; use itertools::Itertools; + use ndarray::{Array2, Axis}; use proptest::arbitrary::Arbitrary; use proptest::collection::vec; use proptest::prelude::*; @@ -1211,6 +1212,133 @@ mod tests { let other_expression = &other.circuit.borrow().expression; self_expression.contains(other_expression) } + + /// Counts the number of inputs from the main table + fn num_nodes(constraints: &[Self]) -> usize { + constraints + .first() + .unwrap() + .builder + .all_nodes + .borrow() + .len() + } + + /// Counts the number of inputs from the main table + fn num_main_inputs(constraints: &[Self]) -> usize { + constraints + .first() + .unwrap() + .builder + .all_nodes + .borrow() + .iter() + .filter(|(_, cc)| { + if let CircuitExpression::Input(ii) = cc.circuit.as_ref().borrow().expression { + ii.is_main_table_column() + } else { + false + } + }) + .filter(|(_, cc)| cc.circuit.borrow().evaluates_to_base_element()) + .count() + } + + /// Counts the number of inputs from the aux table + fn num_aux_inputs(constraints: &[Self]) -> usize { + constraints + .first() + .unwrap() + .builder + .all_nodes + .borrow() + .iter() + .filter(|(_, cc)| { + if let CircuitExpression::Input(ii) = cc.circuit.as_ref().borrow().expression { + !ii.is_main_table_column() + } else { + false + } + }) + .count() + } + + /// Counts the number of total (*i.e.*, main + aux) inputs + fn num_inputs(constraints: &[Self]) -> usize { + Self::num_main_inputs(constraints) + Self::num_aux_inputs(constraints) + } + + /// Counts the number of challenges + fn num_challenges(constraints: &[Self]) -> usize { + constraints + .first() + .unwrap() + .builder + .all_nodes + .borrow() + .iter() + .filter(|(_, cc)| { + matches!( + cc.circuit.as_ref().borrow().expression, + CircuitExpression::Challenge(_) + ) + }) + .count() + } + + /// Counts the number of `BinOp`s + fn num_binops(constraints: &[Self]) -> usize { + constraints + .first() + .unwrap() + .builder + .all_nodes + .borrow() + .iter() + .filter(|(_, cc)| { + matches!( + cc.circuit.as_ref().borrow().expression, + CircuitExpression::BinOp(_, _, _) + ) + }) + .count() + } + + /// Counts the number of BFE constants + fn num_bfield_constants(constraints: &[Self]) -> usize { + constraints + .first() + .unwrap() + .builder + .all_nodes + .borrow() + .iter() + .filter(|(_, cc)| { + matches!( + cc.circuit.as_ref().borrow().expression, + CircuitExpression::BConst(_) + ) + }) + .count() + } + + /// Counts the number of XFE constants + fn num_xfield_constants(constraints: &[Self]) -> usize { + constraints + .first() + .unwrap() + .builder + .all_nodes + .borrow() + .iter() + .filter(|(_, cc)| { + matches!( + cc.circuit.as_ref().borrow().expression, + CircuitExpression::XConst(_) + ) + }) + .count() + } } impl CircuitExpression { @@ -1590,13 +1718,13 @@ mod tests { Extension(#[strategy(arb())] XFieldElement), } - fn arbitrary_circuit( + fn arbitrary_circuit_monad( num_inputs: usize, num_challenges: usize, num_constants: usize, num_nodes: usize, num_outputs: usize, - ) -> BoxedStrategy>> { + ) -> BoxedStrategy>> { ( vec(CircuitInputType::arbitrary(), num_inputs), vec(CircuitConstantType::arbitrary(), num_constants), @@ -1683,87 +1811,146 @@ mod tests { } output_nodes - .into_iter() - .map(|node| node.consume()) - .collect_vec() }) .boxed() } - fn all_main_inputs(multicircuit: &[ConstraintCircuit]) -> Vec { - multicircuit - .iter() - .flat_map(|c| match &c.expression { - CircuitExpression::Input(ii) => { - if ii.is_main_table_column() { - vec![ii.column()] - } else { - vec![] - } - } - CircuitExpression::BinOp(_, lhs, rhs) => { - all_main_inputs(&[lhs.borrow().clone(), rhs.borrow().clone()]) - } - _ => { - vec![] - } - }) - .unique() - .collect_vec() - } - - fn all_aux_inputs(multicircuit: &[ConstraintCircuit]) -> Vec { - multicircuit - .iter() - .flat_map(|c| match &c.expression { - CircuitExpression::Input(ii) => { - if ii.is_main_table_column() { - vec![] - } else { - vec![ii.column()] - } - } - CircuitExpression::BinOp(_, lhs, rhs) => { - all_aux_inputs(&[lhs.borrow().clone(), rhs.borrow().clone()]) - } - _ => { - vec![] - } - }) - .unique() - .collect_vec() - } + #[proptest] + fn node_type_counts_add_up( + #[strategy(arbitrary_circuit_monad(10, 10, 10, 100, 10))] multicircuit_monad: Vec< + ConstraintCircuitMonad, + >, + ) { + prop_assert_eq!( + ConstraintCircuitMonad::num_nodes(&multicircuit_monad), + ConstraintCircuitMonad::num_main_inputs(&multicircuit_monad) + + ConstraintCircuitMonad::num_aux_inputs(&multicircuit_monad) + + ConstraintCircuitMonad::num_challenges(&multicircuit_monad) + + ConstraintCircuitMonad::num_bfield_constants(&multicircuit_monad) + + ConstraintCircuitMonad::num_xfield_constants(&multicircuit_monad) + + ConstraintCircuitMonad::num_binops(&multicircuit_monad) + ); - fn all_challenges(multicircuit: &[ConstraintCircuit]) -> Vec { - multicircuit - .iter() - .flat_map(|c| match &c.expression { - CircuitExpression::Challenge(ch) => { - vec![*ch] - } - CircuitExpression::BinOp(_, lhs, rhs) => { - all_challenges(&[lhs.borrow().clone(), rhs.borrow().clone()]) - } - _ => { - vec![] - } - }) - .unique() - .collect_vec() + prop_assert_eq!(10, ConstraintCircuitMonad::num_inputs(&multicircuit_monad)); + prop_assert_eq!(10, ConstraintCircuitMonad::num_bfield_constants(&multicircuit_monad) + + ConstraintCircuitMonad::num_xfield_constants(&multicircuit_monad)); } + /// Test the completeness and soundness of the `apply_substitution` function, + /// which substitutes a single node. + /// + /// In this context, completeness means: #[proptest] - fn node_substitution_is_complete( - #[strategy(arbitrary_circuit(10, 10, 10, 100, 10))] multicircuit: Vec< - ConstraintCircuit, + fn node_substitution_is_complete_and_sound( + #[strategy(arbitrary_circuit_monad(5, 5, 0, 20, 5))] mut multicircuit_monad: Vec< + ConstraintCircuitMonad, >, - #[strategy(vec(arb::(), 1+all_main_inputs(&#multicircuit).into_iter().max().unwrap()))] + #[strategy(vec(arb::(), ConstraintCircuitMonad::num_main_inputs(&#multicircuit_monad)))] + #[filter(!#main_input.is_empty())] main_input: Vec, - #[strategy(vec(arb::(), 1+all_aux_inputs(&#multicircuit).into_iter().max().unwrap()))] + #[strategy(vec(arb::(), ConstraintCircuitMonad::num_aux_inputs(&#multicircuit_monad)))] + #[filter(!#aux_input.is_empty())] aux_input: Vec, - #[strategy(vec(arb::(), 1+all_challenges(&#multicircuit).into_iter().max().unwrap()))] + #[strategy(vec(arb::(), ConstraintCircuitMonad::num_challenges(&#multicircuit_monad)))] challenges: Vec, + #[strategy(arb())] substitution_node_index: usize, ) { - todo!() + let mut main_input = Array2::from_shape_vec((1, main_input.len()), main_input).unwrap(); + let mut aux_input = Array2::from_shape_vec((1, aux_input.len()), aux_input).unwrap(); + + // compute circuit output before degree-lowering + let output_before_lowering = multicircuit_monad + .iter() + .map(|constraint| { + constraint.circuit.borrow().evaluate( + main_input.view(), + aux_input.view(), + &challenges, + ) + }) + .collect_vec(); + + // apply one step of degree-lowering + let num_nodes = multicircuit_monad[0].builder.all_nodes.borrow().len(); + let substitution_node_id = multicircuit_monad[0] + .builder + .all_nodes + .borrow() + .iter() + .nth(substitution_node_index % num_nodes) + .map(|(i, _n)| *i) + .unwrap(); + + let degree_lowering_info = DegreeLoweringInfo { + target_degree: 2, + num_main_cols: main_input.len(), + num_aux_cols: aux_input.len(), + }; + let substitution_constraint = ConstraintCircuitMonad::apply_substitution( + &mut multicircuit_monad, + degree_lowering_info, + substitution_node_id, + 0, + 0, + ); + + // extract substituted constraint + let CircuitExpression::BinOp(BinOp::Add, variable, neg_expression) = + &substitution_constraint.circuit.as_ref().borrow().expression + else { + unreachable!(); + }; + let CircuitExpression::BinOp(BinOp::Mul, _neg_one, expression) = + &neg_expression.as_ref().borrow().expression + else { + unreachable!(); + }; + + // extend input consistently with the introduced variable + let extra_input = + expression + .borrow() + .evaluate(main_input.view(), aux_input.view(), &challenges); + if variable.borrow().evaluates_to_base_element() { + main_input + .append( + Axis(1), + Array2::from_shape_vec([1, 1], vec![extra_input.coefficients[0]]) + .unwrap() + .view(), + ) + .unwrap(); + } else { + aux_input + .append( + Axis(1), + Array2::from_shape_vec([1, 1], vec![extra_input]) + .unwrap() + .view(), + ) + .unwrap(); + } + + // evaluate again + let output_after_lowering = multicircuit_monad + .iter() + .map(|constraint| { + constraint.circuit.borrow().evaluate( + main_input.view(), + aux_input.view(), + &challenges, + ) + }) + .collect_vec(); + + // assert same value in original constraints + prop_assert_eq!(output_before_lowering, output_after_lowering); + + // assert zero in substitution constraint + prop_assert!(substitution_constraint + .circuit + .borrow() + .evaluate(main_input.view(), aux_input.view(), &challenges) + .is_zero()); } } diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index e8916a54..4b6782c6 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -1760,13 +1760,13 @@ mod tests { ft = format!("{ft}\n"); let num_nodes_in_all_initial_constraints = - ConstraintCircuitMonad::num_nodes(&all_initial_constraints); + ConstraintCircuitMonad::num_visible_nodes(&all_initial_constraints); let num_nodes_in_all_consistency_constraints = - ConstraintCircuitMonad::num_nodes(&all_consistency_constraints); + ConstraintCircuitMonad::num_visible_nodes(&all_consistency_constraints); let num_nodes_in_all_transition_constraints = - ConstraintCircuitMonad::num_nodes(&all_transition_constraints); + ConstraintCircuitMonad::num_visible_nodes(&all_transition_constraints); let num_nodes_in_all_terminal_constraints = - ConstraintCircuitMonad::num_nodes(&all_terminal_constraints); + ConstraintCircuitMonad::num_visible_nodes(&all_terminal_constraints); ft = format!( "{ft}| {:<46} | {:>8} | {:>12} | {:>11} | {:>9} |", "(# nodes)",