diff --git a/triton-constraint-circuit/src/lib.rs b/triton-constraint-circuit/src/lib.rs index 2dafecf4..2c7afff4 100644 --- a/triton-constraint-circuit/src/lib.rs +++ b/triton-constraint-circuit/src/lib.rs @@ -1230,53 +1230,53 @@ mod tests { self_expression.contains(other_expression) } + /// Produces an iter over all nodes in the multicircuit, if it is non-empty. + /// + /// Helper function for counting the number of nodes of a specific type. + fn iter_nodes( + constraints: &[Self], + ) -> std::vec::IntoIter<(usize, ConstraintCircuitMonad)> { + if let Some(first) = constraints.first() { + first + .builder + .all_nodes + .borrow() + .iter() + .map(|(n, m)| (*n, m.clone())) + .collect_vec() + .into_iter() + } else { + vec![].into_iter() + } + } + /// Counts the number of inputs from the main table fn num_nodes(constraints: &[Self]) -> usize { - constraints - .first() - .unwrap() - .builder - .all_nodes - .borrow() - .len() + Self::iter_nodes(constraints).count() + } + + /// Determine if the constraint circuit monad corresponds to a main table + /// column. + fn is_main_table_column(&self) -> bool { + if let CircuitExpression::Input(ii) = self.circuit.as_ref().borrow().expression { + ii.is_main_table_column() + } else { + false + } } /// Counts the number of inputs from the main table fn num_main_inputs(constraints: &[Self]) -> usize { - constraints - .first() - .unwrap() - .builder - .all_nodes - .borrow() - .iter() - .filter(|(_, cc)| { - if let CircuitExpression::Input(ii) = cc.circuit.as_ref().borrow().expression { - ii.is_main_table_column() - } else { - false - } - }) + Self::iter_nodes(constraints) + .filter(|(_, cc)| cc.is_main_table_column()) .filter(|(_, cc)| cc.circuit.borrow().evaluates_to_base_element()) .count() } /// Counts the number of inputs from the aux table fn num_aux_inputs(constraints: &[Self]) -> usize { - constraints - .first() - .unwrap() - .builder - .all_nodes - .borrow() - .iter() - .filter(|(_, cc)| { - if let CircuitExpression::Input(ii) = cc.circuit.as_ref().borrow().expression { - !ii.is_main_table_column() - } else { - false - } - }) + Self::iter_nodes(constraints) + .filter(|(_, cc)| !cc.is_main_table_column()) .count() } @@ -1287,13 +1287,7 @@ mod tests { /// Counts the number of challenges fn num_challenges(constraints: &[Self]) -> usize { - constraints - .first() - .unwrap() - .builder - .all_nodes - .borrow() - .iter() + Self::iter_nodes(constraints) .filter(|(_, cc)| { matches!( cc.circuit.as_ref().borrow().expression, @@ -1305,13 +1299,7 @@ mod tests { /// Counts the number of `BinOp`s fn num_binops(constraints: &[Self]) -> usize { - constraints - .first() - .unwrap() - .builder - .all_nodes - .borrow() - .iter() + Self::iter_nodes(constraints) .filter(|(_, cc)| { matches!( cc.circuit.as_ref().borrow().expression, @@ -1323,13 +1311,7 @@ mod tests { /// Counts the number of BFE constants fn num_bfield_constants(constraints: &[Self]) -> usize { - constraints - .first() - .unwrap() - .builder - .all_nodes - .borrow() - .iter() + Self::iter_nodes(constraints) .filter(|(_, cc)| { matches!( cc.circuit.as_ref().borrow().expression, @@ -1341,13 +1323,7 @@ mod tests { /// Counts the number of XFE constants fn num_xfield_constants(constraints: &[Self]) -> usize { - constraints - .first() - .unwrap() - .builder - .all_nodes - .borrow() - .iter() + Self::iter_nodes(constraints) .filter(|(_, cc)| { matches!( cc.circuit.as_ref().borrow().expression,