Skip to content

Commit

Permalink
Switched from storing eclasses in a HashMap to a Vec
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Dec 10, 2024
1 parent 0eb35e4 commit 81cc122
Show file tree
Hide file tree
Showing 8 changed files with 336 additions and 128 deletions.
8 changes: 4 additions & 4 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1506,9 +1506,9 @@ impl<'x, L: Language, D, U> ExplainWith<'x, L, &'x RawEGraph<L, D, U>> {
self.calculate_parent_distance(right, ancestor, distance_memo);

// now all three share an ancestor
let a = self.calculate_parent_distance(ancestor, Id::from(usize::MAX), distance_memo);
let b = self.calculate_parent_distance(left, Id::from(usize::MAX), distance_memo);
let c = self.calculate_parent_distance(right, Id::from(usize::MAX), distance_memo);
let a = self.calculate_parent_distance(ancestor, Id::MAX, distance_memo);
let b = self.calculate_parent_distance(left, Id::MAX, distance_memo);
let c = self.calculate_parent_distance(right, Id::MAX, distance_memo);

assert!(
distance_memo.parent_distance[usize::from(ancestor)].0
Expand Down Expand Up @@ -1582,7 +1582,7 @@ impl<'x, L: Language, D, U> ExplainWith<'x, L, &'x RawEGraph<L, D, U>> {
let new_dist = dist + distance_memo.parent_distance[usize::from(parent)].1;
distance_memo.parent_distance[usize::from(enode)] = (parent_parent, new_dist);
} else {
if ancestor == Id::from(usize::MAX) {
if ancestor == Id::MAX {
break;
}
if distance_memo.tree_depth.get(&parent).unwrap()
Expand Down
38 changes: 36 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(missing_docs)]
#![forbid(unsafe_code)]
#![no_std]
/*!
`egg` (**e**-**g**raphs **g**ood) is an e-graph library optimized for equality saturation.
Expand Down Expand Up @@ -32,6 +30,7 @@ for less or more logging.
#![doc = include_str!("../tests/simple.rs")]
#![doc = "\n```"]

extern crate core;
extern crate no_std_compat as std;

#[cfg(feature = "egg_compat")]
Expand Down Expand Up @@ -70,6 +69,8 @@ mod multipattern;
#[cfg(feature = "egg_compat")]
mod pattern;

/// Largest value that fits in 31 bits (`0x7FFF_FFFF`): a `u32` with every bit set except the top one.
const U31_MAX: u32 = u32::MAX >> 1;

/// Lower level egraph API
pub mod raw;

Expand All @@ -92,8 +93,14 @@ mod util;
#[cfg_attr(feature = "serde-1", serde(transparent))]
pub struct Id(u32);

impl Id {
/// Dummy id value
pub const MAX: Id = Id(U31_MAX);
}

impl From<usize> for Id {
fn from(n: usize) -> Id {
assert!(n <= U31_MAX as usize);
Id(n as u32)
}
}
Expand All @@ -116,6 +123,33 @@ impl std::fmt::Display for Id {
}
}

mod cid {
    /// Index into the `classes` vector of a `RawEGraph`
    #[derive(Hash, Clone, Copy, Eq, PartialEq)]
    pub struct ClassId(pub(crate) u32);

    impl std::fmt::Debug for ClassId {
        /// Prints only the wrapped index, without the `ClassId(..)` wrapper.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let Self(raw) = self;
            write!(f, "{}", raw)
        }
    }

    impl From<usize> for ClassId {
        /// Builds a [`ClassId`] from a `usize` index.
        ///
        /// # Panics
        ///
        /// Panics if `n` is greater than `crate::U31_MAX`; the check guarantees
        /// the narrowing cast below cannot truncate.
        fn from(n: usize) -> Self {
            assert!(n <= crate::U31_MAX as usize);
            Self(n as u32)
        }
    }

    impl ClassId {
        /// Returns the wrapped value as a `usize`, ready for indexing.
        pub(crate) fn idx(self) -> usize {
            let Self(raw) = self;
            raw as usize
        }
    }
}

use cid::ClassId;

#[cfg(feature = "egg_compat")]
pub(crate) use {explain::Explain, raw::UnionFind};

Expand Down
89 changes: 51 additions & 38 deletions src/raw/egraph.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::raw::{util::HashMap, Language, RawEClass, RecExpr, UnionFind};
use crate::{dot::Dot, Id};
use crate::{dot::Dot, ClassId, Id};
use no_std_compat::prelude::v1::*;
use std::collections::BTreeMap;
use std::convert::Infallible;
Expand Down Expand Up @@ -364,7 +364,7 @@ pub struct RawEGraph<L: Language, D, U = ()> {
pub(super) pending: Vec<Id>,
/// `Id`s that are congruently equivalent to another `Id` that is not in this set
pub(super) congruence_duplicates: BitSet,
pub(super) classes: HashMap<Id, RawEClass<D>>,
pub(super) classes: Vec<RawEClass<D>>,
pub(super) undo_log: U,
}

Expand Down Expand Up @@ -407,11 +407,11 @@ impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
let classes: BTreeMap<_, _> = self
.classes
.iter()
.map(|(x, y)| {
.map(|y| {
let mut parents = y.parents.clone();
parents.sort_unstable();
(
*x,
y.id,
RawEClass {
id: y.id,
raw_data: &y.raw_data,
Expand All @@ -431,10 +431,7 @@ impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
impl<L: Language, D, U> RawEGraph<L, D, U> {
/// Returns an iterator over the eclasses in the egraph.
pub fn classes(&self) -> impl ExactSizeIterator<Item = &RawEClass<D>> {
self.classes.iter().map(|(id, class)| {
debug_assert_eq!(*id, class.id);
class
})
self.classes.iter()
}

/// Returns a mutating iterator over the eclasses in the egraph.
Expand All @@ -445,10 +442,7 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
impl ExactSizeIterator<Item = &mut RawEClass<D>>,
&mut EGraphResidual<L>,
) {
let iter = self.classes.iter_mut().map(|(id, class)| {
debug_assert_eq!(*id, class.id);
class
});
let iter = self.classes.iter_mut();
(iter, &mut self.residual)
}

Expand All @@ -460,15 +454,15 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
/// Returns the eclass corresponding to `id`
pub fn get_class<I: BorrowMut<Id>>(&self, mut id: I) -> &RawEClass<D> {
let id = id.borrow_mut();
*id = self.find(*id);
self.get_class_with_cannon(*id)
let (nid, cid) = self.unionfind.find_full(*id);
*id = nid;
&self.classes[cid.idx()]
}

/// Like [`get_class`](RawEGraph::get_class) but panics if `id` is not canonical
pub fn get_class_with_cannon(&self, id: Id) -> &RawEClass<D> {
self.classes
.get(&id)
.unwrap_or_else(|| panic!("Invalid id {}", id))
let cid = self.unionfind.find_canon(id);
&self.classes[cid.idx()]
}

/// Returns the eclass corresponding to `id`
Expand All @@ -478,21 +472,18 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
mut id: I,
) -> (&mut RawEClass<D>, &mut EGraphResidual<L>) {
let id = id.borrow_mut();
*id = self.find_mut(*id);
self.get_class_mut_with_cannon(*id)
let (nid, cid) = self.unionfind.find_mut_full(*id);
*id = nid;
(&mut self.classes[cid.idx()], &mut self.residual)
}

/// Like [`get_class_mut`](RawEGraph::get_class_mut) but panics if `id` is not canonical
pub fn get_class_mut_with_cannon(
&mut self,
id: Id,
) -> (&mut RawEClass<D>, &mut EGraphResidual<L>) {
(
self.classes
.get_mut(&id)
.unwrap_or_else(|| panic!("Invalid id {}", id)),
&mut self.residual,
)
let cid = self.unionfind.find_canon(id);
(&mut self.classes[cid.idx()], &mut self.residual)
}

/// Returns whether `self` is congruently closed
Expand Down Expand Up @@ -618,22 +609,26 @@ impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
} else {
let this = get_self(outer);
let canon_id = this.find_mut(existing_id);
let new_id = this.residual.unionfind.make_set();
this.undo_log.add_node(&original, &[], new_id);
this.undo_log.union(canon_id, new_id, Vec::new());
let cid = ClassId::from(this.classes.len());
let new_id = this.residual.unionfind.make_set_with_id(cid);
this.undo_log.add_node(&original, &[], new_id, cid);
this.undo_log.union(canon_id, new_id, Vec::new(), cid);
debug_assert_eq!(Id::from(this.nodes.len()), new_id);
this.residual.nodes.push(original);
this.residual.unionfind.union(canon_id, new_id);
handle_union(outer, pre, existing_id, new_id);
new_id
}
} else {
let id = this.residual.unionfind.make_set();
let classes_len = this.classes.len();
let end_cid = ClassId::from(classes_len);
let id = this.residual.unionfind.make_set_with_id(end_cid);
let mut dedup_children = SmallVec::<[Id; 8]>::from_slice(enode.children());
dedup_children.sort();
dedup_children.dedup();

this.undo_log.add_node(&original, &dedup_children, id);
this.undo_log
.add_node(&original, &dedup_children, id, end_cid);
debug_assert_eq!(Id::from(this.nodes.len()), id);
this.residual.nodes.push(original);

Expand All @@ -650,7 +645,12 @@ impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
this.get_class_mut_with_cannon(child).0.parents.push(id);
}

this.classes.insert(id, class);
assert_eq!(
this.classes.len(),
classes_len,
"classes can't be added from callback"
);
this.classes.push(class);
this.residual.memo.insert_with_hash(hash, enode, id);
this.undo_log.insert_memo(hash);

Expand All @@ -670,26 +670,39 @@ impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
enode_id2: Id,
merge: impl FnOnce(MergeInfo<'_, D>),
) {
let mut id1 = self.find_mut(enode_id1);
let mut id2 = self.find_mut(enode_id2);
let (mut id1, mut cid1) = self.unionfind.find_mut_full(enode_id1);
let (mut id2, mut cid2) = self.unionfind.find_mut_full(enode_id2);
if id1 == id2 {
return;
}
// make sure class2 has fewer parents
let class1_parents = self.classes[&id1].parents.len();
let class2_parents = self.classes[&id2].parents.len();
let class1_parents = self.classes[cid1.idx()].parents.len();
let class2_parents = self.classes[cid2.idx()].parents.len();
let mut swapped = false;
if class1_parents < class2_parents {
swapped = true;
std::mem::swap(&mut id1, &mut id2);
std::mem::swap(&mut cid1, &mut cid2);
}

// make id1 the new root
self.residual.unionfind.union(id1, id2);

assert_ne!(id1, id2);
let class2 = self.classes.remove(&id2).unwrap();
let class1 = self.classes.get_mut(&id1).unwrap();
let class2 = if cid2.idx() == self.classes.len() - 1 {
self.classes.pop().unwrap()
} else {
let class2 = self.classes.swap_remove(cid2.idx());
let fixup_id = self.classes[cid2.idx()].id;
self.unionfind.reset_root(fixup_id, cid2);
self.undo_log.fix_id(fixup_id, cid2);
if cid1.idx() == self.classes.len() {
cid1 = cid2;
}
class2
};

let class1 = &mut self.classes[cid1.idx()];
assert_eq!(id1, class1.id);

let info = MergeInfo {
Expand All @@ -707,7 +720,7 @@ impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {

class1.parents.extend(&class2.parents);

self.undo_log.union(id1, id2, class2.parents);
self.undo_log.union(id1, id2, class2.parents, cid2);
}

/// Rebuild the [`RawEGraph`] to restore congruence closure
Expand Down
29 changes: 20 additions & 9 deletions src/raw/semi_persistent.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::raw::{Language, RawEGraph};
use crate::Id;
use crate::{ClassId, Id};
use no_std_compat::prelude::v1::*;
use std::fmt::Debug;

Expand All @@ -11,10 +11,13 @@ impl<U: Sealed> Sealed for Option<U> {}
/// It is trivially implemented for `()`
pub trait UndoLogT<L, D>: Default + Debug + Sealed {
#[doc(hidden)]
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id);
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId);

#[doc(hidden)]
fn union(&mut self, id1: Id, id2: Id, id2_parents: Vec<Id>);
fn union(&mut self, id1: Id, id2: Id, id2_parents: Vec<Id>, old_cid: ClassId);

#[doc(hidden)]
fn fix_id(&mut self, fixup_id: Id, cid: ClassId);

#[doc(hidden)]
fn insert_memo(&mut self, hash: u64);
Expand All @@ -31,10 +34,12 @@ pub trait UndoLogT<L, D>: Default + Debug + Sealed {

impl<L, D> UndoLogT<L, D> for () {
#[inline]
fn add_node(&mut self, _: &L, _: &[Id], _: Id) {}
fn add_node(&mut self, _: &L, _: &[Id], _: Id, _: ClassId) {}

#[inline]
fn union(&mut self, _: Id, _: Id, _: Vec<Id>) {}
fn union(&mut self, _: Id, _: Id, _: Vec<Id>, _: ClassId) {}

fn fix_id(&mut self, _: Id, _: ClassId) {}

#[inline]
fn insert_memo(&mut self, _: u64) {}
Expand All @@ -51,16 +56,22 @@ impl<L, D> UndoLogT<L, D> for () {

impl<L, D, U: UndoLogT<L, D>> UndoLogT<L, D> for Option<U> {
#[inline]
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id) {
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId) {
if let Some(undo) = self {
undo.add_node(node, canon_children, node_id)
undo.add_node(node, canon_children, node_id, cid)
}
}

#[inline]
fn union(&mut self, id1: Id, id2: Id, id2_parents: Vec<Id>) {
fn union(&mut self, id1: Id, id2: Id, id2_parents: Vec<Id>, old_cid: ClassId) {
if let Some(undo) = self {
undo.union(id1, id2, id2_parents, old_cid)
}
}

fn fix_id(&mut self, fixup_id: Id, cid: ClassId) {
if let Some(undo) = self {
undo.union(id1, id2, id2_parents)
undo.fix_id(fixup_id, cid)
}
}

Expand Down
Loading

0 comments on commit 81cc122

Please sign in to comment.