Skip to content

Commit

Permalink
Improved raw_union interface, fixed EGraph::dump and updated edition
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Feb 11, 2024
1 parent 8370122 commit c18f6d4
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
authors = ["Max Willsey <[email protected]>"]
categories = ["data-structures"]
description = "An implementation of egraphs"
edition = "2018"
edition = "2021"
keywords = ["e-graphs"]
license = "MIT"
name = "egg"
Expand Down
2 changes: 1 addition & 1 deletion src/eclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl<L: Language, D: Debug> Debug for EClassData<L, D> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut nodes = self.nodes.clone();
nodes.sort();
writeln!(f, "({:?}): {:?}", self.data, nodes)
write!(f, "({:?}): {:?}", self.data, nodes)
}
}

Expand Down
42 changes: 18 additions & 24 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -803,11 +803,27 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
N::pre_union(self, enode_id1, enode_id2, &rule);

self.clean = false;
if let Some((id, class2)) = self.inner.raw_union(enode_id1, enode_id2) {
self.merge(id, class2);
let mut new_root = None;
self.inner
.raw_union(enode_id1, enode_id2, |class1, id1, p1, class2, _, p2| {
new_root = Some(id1);

let did_merge = self.analysis.merge(&mut class1.data, class2.data);
if did_merge.0 {
self.analysis_pending.extend(p1);
}
if did_merge.1 {
self.analysis_pending.extend(p2);
}

concat_vecs(&mut class1.nodes, class2.nodes);
});
if let Some(id) = new_root {
if let Some(explain) = &mut self.explain {
explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs);
}
N::modify(self, id);

true
} else {
if let Some(Justification::Rule(_)) = rule {
Expand All @@ -819,28 +835,6 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}
}

fn merge(&mut self, id1: Id, class2: EClass<L, N::Data>) {
let class1 = self.inner.get_class_mut_with_cannon(id1).0;
let (class2, parents) = class2.destruct();
let did_merge = self.analysis.merge(&mut class1.data, class2.data);
if did_merge.0 {
// class1.parents already contains the combined parents,
// so we only take the ones that were there before the union
self.analysis_pending.extend(
class1
.parents()
.take(class1.parents().len() - parents.len()),
);
}
if did_merge.1 {
self.analysis_pending.extend(parents);
}

concat_vecs(&mut class1.nodes, class2.nodes);

N::modify(self, id1)
}

/// Update the analysis data of an e-class.
///
/// This also propagates the changes through the e-graph,
Expand Down
46 changes: 36 additions & 10 deletions src/raw/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@ use std::ops::{Deref, DerefMut};
use std::{
borrow::BorrowMut,
fmt::{self, Debug},
iter, slice,
};

#[cfg(feature = "serde-1")]
use serde::{Deserialize, Serialize};

pub struct Parents<'a>(&'a [Id]);

impl<'a> IntoIterator for Parents<'a> {
type Item = Id;
type IntoIter = iter::Copied<slice::Iter<'a, Id>>;

fn into_iter(self) -> Self::IntoIter {
self.0.iter().copied()
}
}

/// A [`RawEGraph`] without its classes that can be obtained by dereferencing a [`RawEGraph`].
///
/// It exists as a separate type so that it can still be used while mutably borrowing a [`RawEClass`]
Expand Down Expand Up @@ -506,16 +518,18 @@ impl<L: Language, D> RawEGraph<L, D> {
///
/// The given ids need not be canonical.
///
/// Returns `None` if the two ids were already equivalent.
///
/// Returns `Some((id, class))` if two classes were merged where `id` is the id of the newly merged class
/// and `class` is the old `RawEClass` that merged into `id`
/// If a union occurs, `merge` is called with the data, id, and parents of the two eclasses being merged
#[inline]
pub fn raw_union(&mut self, enode_id1: Id, enode_id2: Id) -> Option<(Id, RawEClass<D>)> {
pub fn raw_union(
&mut self,
enode_id1: Id,
enode_id2: Id,
merge: impl FnOnce(&mut D, Id, Parents<'_>, D, Id, Parents<'_>),
) {
let mut id1 = self.find_mut(enode_id1);
let mut id2 = self.find_mut(enode_id2);
if id1 == id2 {
return None;
return;
}
// make sure class2 has fewer parents
let class1_parents = self.classes[&id1].parents.len();
Expand All @@ -531,11 +545,19 @@ impl<L: Language, D> RawEGraph<L, D> {
let class2 = self.classes.remove(&id2).unwrap();
let class1 = self.classes.get_mut(&id1).unwrap();
assert_eq!(id1, class1.id);
let (p1, p2) = (Parents(&class1.parents), Parents(&class2.parents));
merge(
&mut class1.raw_data,
class1.id,
p1,
class2.raw_data,
class2.id,
p2,
);

self.pending.extend(class2.parents());
self.pending.extend(&class2.parents);

class1.parents.extend(class2.parents());
Some((id1, class2))
class1.parents.extend(class2.parents);
}

#[inline]
Expand Down Expand Up @@ -615,7 +637,11 @@ impl<L: Language> RawEGraph<L, ()> {

/// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data
pub fn union(&mut self, id1: Id, id2: Id) -> bool {
Self::raw_union(self, id1, id2).is_some()
let mut unioned = false;
self.raw_union(id1, id2, |_, _, _, _, _, _| {
unioned = true;
});
unioned
}

/// Simplified version of [`raw_rebuild`](RawEGraph::raw_rebuild) for egraphs without eclass data
Expand Down

0 comments on commit c18f6d4

Please sign in to comment.