diff --git a/triton-constraint-circuit/src/lib.rs b/triton-constraint-circuit/src/lib.rs index 68ba3405..c3539a9f 100644 --- a/triton-constraint-circuit/src/lib.rs +++ b/triton-constraint-circuit/src/lib.rs @@ -781,25 +781,23 @@ impl ConstraintCircuitMonad { multicircuit: &[ConstraintCircuitMonad], target_degree: isize, ) -> usize { - // The relevant fields of a ConstraintCircuit. - // Avoids interior mutability in a HashSet, which is a foot gun. - #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] - struct PartialNodeInfo { - id: usize, - degree: isize, - } - assert!(!multicircuit.is_empty()); let multicircuit = multicircuit .iter() .map(|c| c.clone().consume()) .collect_vec(); + // Computing all node degree is slow; this cache de-duplicates work. + let node_degrees = Self::all_nodes_in_multicircuit(&multicircuit) + .into_iter() + .map(|node| (node.id, node.degree())) + .collect::>(); + // Only nodes with degree > target_degree need changing. let high_degree_nodes = Self::all_nodes_in_multicircuit(&multicircuit) .into_iter() + .filter(|node| node_degrees[&node.id] > target_degree) .unique() - .filter(|node| node.degree() > target_degree) .collect_vec(); // Collect all candidates for substitution, i.e., descendents of high_degree_nodes @@ -807,11 +805,8 @@ impl ConstraintCircuitMonad { // Substituting a node of degree 1 is both pointless and can lead to infinite iteration. let low_degree_nodes = Self::all_nodes_in_multicircuit(&high_degree_nodes) .into_iter() - .filter(|node| 1 < node.degree() && node.degree() <= target_degree) - .map(|node| PartialNodeInfo { - id: node.id, - degree: node.degree(), - }) + .filter(|node| 1 < node_degrees[&node.id] && node_degrees[&node.id] <= target_degree) + .map(|node| node.id) .collect_vec(); // If the resulting list is empty, there is no way forward. Stop – panic time! @@ -819,20 +814,25 @@ impl ConstraintCircuitMonad { // Of the remaining nodes, keep the ones occurring the most often. let mut nodes_and_occurrences = HashMap::new(); - for node in &low_degree_nodes { + for node in low_degree_nodes { *nodes_and_occurrences.entry(node).or_insert(0) += 1; } let max_occurrences = nodes_and_occurrences.iter().map(|(_, &c)| c).max().unwrap(); nodes_and_occurrences.retain(|_, &mut count| count == max_occurrences); - let mut candidate_nodes = nodes_and_occurrences.keys().copied().collect_vec(); + let mut candidate_node_ids = nodes_and_occurrences.keys().copied().collect_vec(); // If there are still multiple nodes, pick the one with the highest degree. - let max_degree = candidate_nodes.iter().map(|n| n.degree).max().unwrap(); - candidate_nodes.retain(|node| node.degree == max_degree); + let max_degree = candidate_node_ids + .iter() + .map(|node_id| node_degrees[node_id]) + .max() + .unwrap(); + candidate_node_ids.retain(|node_id| node_degrees[node_id] == max_degree); + + candidate_node_ids.sort_unstable(); // If there are still multiple nodes, pick any one – but deterministically so. - candidate_nodes.sort_unstable_by_key(|node| node.id); - candidate_nodes[0].id + candidate_node_ids.into_iter().min().unwrap() } /// Returns all nodes used in the multicircuit.