Skip to content

Commit

Permalink
eliminated node field of ExplainNode (used EGraph.nodes instead)
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Jan 3, 2024
1 parent 1f838c6 commit 3145a30
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 181 deletions.
107 changes: 78 additions & 29 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,11 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

/// Make a copy of the egraph with the same nodes, but no unions between them.
pub fn copy_without_unions(&self, analysis: N) -> Self {
if let Some(explain) = &self.explain {
let egraph = Self::new(analysis);
explain.populate_enodes(egraph)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions");
let mut egraph = Self::new(analysis);
for node in &self.nodes {
egraph.add(node.clone());
}
egraph
}

/// Performs the union between two egraphs.
Expand Down Expand Up @@ -342,32 +341,70 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// was obtained (see [`add_uncanoncial`](EGraph::add_uncanonical),
/// [`add_expr_uncanonical`](EGraph::add_expr_uncanonical))
pub fn id_to_expr(&self, id: Id) -> RecExpr<L> {
if let Some(explain) = &self.explain {
explain.node_to_recexpr(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id");
let mut res = Default::default();
let mut cache = Default::default();
self.id_to_expr_internal(&mut res, id, &mut cache);
res
}

fn id_to_expr_internal(
&self,
res: &mut RecExpr<L>,
node_id: Id,
cache: &mut HashMap<Id, Id>,
) -> Id {
if let Some(existing) = cache.get(&node_id) {
return *existing;
}
let new_node = self
.id_to_node(node_id)
.clone()
.map_children(|child| self.id_to_expr_internal(res, child, cache));
let res_id = res.add(new_node);
cache.insert(node_id, res_id);
res_id
}

/// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep
pub fn id_to_node(&self, id: Id) -> &L {
if let Some(explain) = &self.explain {
explain.node(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id");
}
&self.nodes[usize::from(id)]
}

/// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term.
/// When an eclass listed in the given substitutions is found, it creates a variable.
/// It also adds this variable and the corresponding Id value to the resulting [`Subst`]
/// Otherwise it behaves like [`id_to_expr`](EGraph::id_to_expr).
pub fn id_to_pattern(&self, id: Id, substitutions: &HashMap<Id, Id>) -> (Pattern<L>, Subst) {
if let Some(explain) = &self.explain {
explain.node_to_pattern(id, substitutions)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique patterns per id");
let mut res = Default::default();
let mut subst = Default::default();
let mut cache = Default::default();
self.id_to_pattern_internal(&mut res, id, substitutions, &mut subst, &mut cache);
(Pattern::new(res), subst)
}

fn id_to_pattern_internal(
&self,
res: &mut PatternAst<L>,
node_id: Id,
var_substitutions: &HashMap<Id, Id>,
subst: &mut Subst,
cache: &mut HashMap<Id, Id>,
) -> Id {
if let Some(existing) = cache.get(&node_id) {
return *existing;
}
let res_id = if let Some(existing) = var_substitutions.get(&node_id) {
let var = format!("?{}", node_id).parse().unwrap();
subst.insert(var, *existing);
res.add(ENodeOrVar::Var(var))
} else {
let new_node = self.id_to_node(node_id).clone().map_children(|child| {
self.id_to_pattern_internal(res, child, var_substitutions, subst, cache)
});
res.add(ENodeOrVar::ENode(new_node))
};
cache.insert(node_id, res_id);
res_id
}

/// Get all the unions ever found in the egraph in terms of enode ids.
Expand All @@ -393,17 +430,19 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Get the number of congruences between nodes in the egraph.
/// Only available when explanations are enabled.
pub fn get_num_congr(&mut self) -> usize {
if let Some(explain) = &self.explain {
explain.get_num_congr::<N>(&self.classes, &self.unionfind)
if let Some(explain) = &mut self.explain {
explain
.with_nodes(&self.nodes)
.get_num_congr::<N>(&self.classes, &self.unionfind)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
}

/// Get the number of nodes in the egraph used for explanations.
pub fn get_explanation_num_nodes(&mut self) -> usize {
if let Some(explain) = &self.explain {
explain.get_num_nodes()
if let Some(explain) = &mut self.explain {
explain.with_nodes(&self.nodes).get_num_nodes()
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand Down Expand Up @@ -441,7 +480,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
);
}
if let Some(explain) = &mut self.explain {
explain.explain_equivalence::<N>(left, right, &mut self.unionfind, &self.classes)
explain.with_nodes(&self.nodes).explain_equivalence::<N>(
left,
right,
&mut self.unionfind,
&self.classes,
)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -464,7 +508,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// but more efficient
fn explain_existance_id(&mut self, id: Id) -> Explanation<L> {
if let Some(explain) = &mut self.explain {
explain.explain_existance(id)
explain.with_nodes(&self.nodes).explain_existance(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -478,7 +522,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
) -> Explanation<L> {
let id = self.add_instantiation_noncanonical(pattern, subst);
if let Some(explain) = &mut self.explain {
explain.explain_existance(id)
explain.with_nodes(&self.nodes).explain_existance(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.")
}
Expand All @@ -501,7 +545,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
);
}
if let Some(explain) = &mut self.explain {
explain.explain_equivalence::<N>(left, right, &mut self.unionfind, &self.classes)
explain.with_nodes(&self.nodes).explain_equivalence::<N>(
left,
right,
&mut self.unionfind,
&self.classes,
)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get explanations.");
}
Expand Down Expand Up @@ -1213,9 +1262,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
n_unions
}

pub(crate) fn check_each_explain(&self, rules: &[&Rewrite<L, N>]) -> bool {
if let Some(explain) = &self.explain {
explain.check_each_explain(rules)
pub(crate) fn check_each_explain(&mut self, rules: &[&Rewrite<L, N>]) -> bool {
if let Some(explain) = &mut self.explain {
explain.with_nodes(&self.nodes).check_each_explain(rules)
} else {
panic!("Can't check explain when explanations are off");
}
Expand Down
Loading

0 comments on commit 3145a30

Please sign in to comment.