Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added push/pop API #290

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{
#[cfg(feature = "serde-1")]
use serde::{Deserialize, Serialize};

use crate::semi_persistent::UndoLogT;
use log::*;

/** A data structure to keep track of equalities between expressions.
Expand Down Expand Up @@ -56,16 +57,24 @@ pub struct EGraph<L: Language, N: Analysis<L>> {
pub analysis: N,
/// The `Explain` used to explain equivalences in this `EGraph`.
pub(crate) explain: Option<Explain<L>>,
unionfind: UnionFind,
#[cfg_attr(
feature = "serde-1",
serde(bound(
serialize = "N::UndoLog: Serialize",
deserialize = "N::UndoLog: for<'a> Deserialize<'a>",
))
)]
pub(crate) undo_log: N::UndoLog,
pub(crate) unionfind: UnionFind,
/// Stores each enode's `Id`, not the `Id` of the eclass.
/// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new
/// unions can cause them to become out of date.
#[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
memo: HashMap<L, Id>,
pub(crate) memo: HashMap<L, Id>,
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
pending: Vec<(L, Id)>,
analysis_pending: UniqueQueue<(L, Id)>,
pub(crate) pending: Vec<(L, Id)>,
pub(crate) analysis_pending: UniqueQueue<(L, Id)>,
#[cfg_attr(
feature = "serde-1",
serde(bound(
Expand Down Expand Up @@ -103,6 +112,8 @@ impl<L: Language, N: Analysis<L>> Debug for EGraph<L, N> {
f.debug_struct("EGraph")
.field("memo", &self.memo)
.field("classes", &self.classes)
.field("undo_log", &self.undo_log)
.field("explain", &self.explain)
.finish()
}
}
Expand All @@ -120,6 +131,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
memo: Default::default(),
analysis_pending: Default::default(),
classes_by_op: Default::default(),
undo_log: Default::default(),
}
}

Expand Down Expand Up @@ -769,9 +781,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
*existing_explain
} else {
let new_id = self.unionfind.make_set();
self.undo_log.add_node(&original, new_id);
explain.add(original, new_id, new_id);
self.unionfind.union(id, new_id);
self.undo_log.union(id, new_id);
explain.union(existing_id, new_id, Justification::Congruence, true);
self.undo_log.union_explain(existing_id, new_id);
new_id
}
} else {
Expand All @@ -780,6 +795,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
} else {
let id = self.make_new_eclass(enode);
if let Some(explain) = self.explain.as_mut() {
self.undo_log.add_node(&original, id);
explain.add(original, id, id);
}

Expand Down Expand Up @@ -811,7 +827,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
self.pending.push((enode.clone(), id));

self.classes.insert(id, class);
assert!(self.memo.insert(enode, id).is_none());
let old = self.undo_log.modify_memo(&mut self.memo, enode, Some(id));
assert!(old.is_none());

id
}
Expand Down Expand Up @@ -919,7 +936,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
if id1 == id2 {
if let Some(Justification::Rule(_)) = rule {
if let Some(explain) = &mut self.explain {
explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap());
explain.alternate_rewrite(
enode_id1,
enode_id2,
rule.unwrap(),
&mut self.undo_log,
);
}
}
return false;
Expand All @@ -933,10 +955,12 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

if let Some(explain) = &mut self.explain {
explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs);
self.undo_log.union_explain(enode_id1, enode_id2);
}

// make id1 the new root
self.unionfind.union(id1, id2);
self.undo_log.union(id1, id2);

assert_ne!(id1, id2);
let class2 = self.classes.remove(&id2).unwrap();
Expand Down Expand Up @@ -1105,7 +1129,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
while !self.pending.is_empty() || !self.analysis_pending.is_empty() {
while let Some((mut node, class)) = self.pending.pop() {
node.update_children(|id| self.find_mut(id));
if let Some(memo_class) = self.memo.insert(node, class) {
if let Some(memo_class) =
self.undo_log.modify_memo(&mut self.memo, node, Some(class))
{
let did_something = self.perform_union(
memo_class,
class,
Expand Down
46 changes: 30 additions & 16 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::collections::{BinaryHeap, VecDeque};
use std::fmt::{self, Debug, Display, Formatter};
use std::rc::Rc;

use crate::semi_persistent::UndoLogT;
use symbolic_expressions::Sexp;

type ProofCost = Saturating<usize>;
Expand All @@ -29,32 +30,43 @@ pub enum Justification {

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct Connection {
next: Id,
pub(crate) struct Connection {
pub(crate) next: Id,
current: Id,
justification: Justification,
is_rewrite_forward: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct ExplainNode<L: Language> {
node: L,
pub(crate) struct ExplainNode<L: Language> {
pub(crate) node: L,
// neighbors includes parent connections
neighbors: Vec<Connection>,
parent_connection: Connection,
pub(crate) neighbors: Vec<Connection>,
pub(crate) parent_connection: Connection,
// it was inserted because of:
// 1) it's parent is inserted (points to parent enode)
// 2) a rewrite instantiated it (points to adjacent enode)
// 3) it was inserted directly (points to itself)
// if 1 is true but it's also adjacent (2) then either works and it picks 2
existance_node: Id,
pub(crate) existance_node: Id,
}

impl Connection {
pub(crate) fn dummy(set: Id) -> Self {
Connection {
justification: Justification::Congruence,
is_rewrite_forward: false,
next: set,
current: set,
}
}
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct Explain<L: Language> {
explainfind: Vec<ExplainNode<L>>,
pub(crate) explainfind: Vec<ExplainNode<L>>,
#[cfg_attr(feature = "serde-1", serde(with = "vectorize"))]
pub uncanon_memo: HashMap<L, Id>,
/// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted.
Expand All @@ -66,7 +78,7 @@ pub struct Explain<L: Language> {
// Invariant: The distance is always <= the unoptimized distance
// That is, less than or equal to the result of `distance_between`
#[cfg_attr(feature = "serde-1", serde(skip))]
shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>,
pub(crate) shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>,
}

#[derive(Default)]
Expand Down Expand Up @@ -1048,12 +1060,7 @@ impl<L: Language> Explain<L> {
self.explainfind.push(ExplainNode {
node,
neighbors: vec![],
parent_connection: Connection {
justification: Justification::Congruence,
is_rewrite_forward: false,
next: set,
current: set,
},
parent_connection: Connection::dummy(set),
existance_node,
});
set
Expand All @@ -1075,7 +1082,13 @@ impl<L: Language> Explain<L> {
}
}

pub(crate) fn alternate_rewrite(&mut self, node1: Id, node2: Id, justification: Justification) {
pub(crate) fn alternate_rewrite(
&mut self,
node1: Id,
node2: Id,
justification: Justification,
undo: &mut impl UndoLogT<L>,
) {
if node1 == node2 {
return;
}
Expand All @@ -1084,6 +1097,7 @@ impl<L: Language> Explain<L> {
return;
}
}
undo.union_explain(node1, node2);

let lconnection = Connection {
justification: justification.clone(),
Expand Down
24 changes: 24 additions & 0 deletions src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{hash::Hash, str::FromStr};

use crate::*;

use crate::semi_persistent::{UndoLog, UndoLogT};
use fmt::Formatter;
use symbolic_expressions::{Sexp, SexpError};
use thiserror::Error;
Expand Down Expand Up @@ -655,6 +656,7 @@ define_language! {
struct ConstantFolding;
impl Analysis<SimpleMath> for ConstantFolding {
type Data = Option<i32>;
type UndoLog = ();

fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
egg::merge_max(to, from)
Expand Down Expand Up @@ -700,6 +702,12 @@ pub trait Analysis<L: Language>: Sized {
/// The per-[`EClass`] data for this analysis.
type Data: Debug;

/// Determines whether the [`EGraph`] supports [`push`](EGraph::push) and [`pop`](EGraph::pop)
/// Setting this to `()` disables [`push`](EGraph::push) and [`pop`](EGraph::pop)
/// Setting this to [`UndoLog`](UndoLog) enables [`push`](EGraph::push) and [`pop`](EGraph::pop)
/// Doing this requires that the [`EGraph`] has explanations enabled
type UndoLog: UndoLogT<L>;

/// Makes a new [`Analysis`] data for a given e-node.
///
/// Note the mutable `egraph` parameter: this is needed for some
Expand Down Expand Up @@ -765,6 +773,22 @@ pub trait Analysis<L: Language>: Sized {

impl<L: Language> Analysis<L> for () {
type Data = ();

type UndoLog = ();
fn make(_egraph: &mut EGraph<L, Self>, _enode: &L) -> Self::Data {}
fn merge(&mut self, _: &mut Self::Data, _: Self::Data) -> DidMerge {
DidMerge(false, false)
}
}

/// Simple [`Analysis`], similar to `()` but enables [`push`](EGraph::push) and [`pop`](EGraph::pop)
/// Doing this requires that the [`EGraph`] has explanations enabled
pub struct WithUndo;

impl<L: Language> Analysis<L> for WithUndo {
type Data = ();

type UndoLog = UndoLog<L>;
fn make(_egraph: &mut EGraph<L, Self>, _enode: &L) -> Self::Data {}
fn merge(&mut self, _: &mut Self::Data, _: Self::Data) -> DidMerge {
DidMerge(false, false)
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mod multipattern;
mod pattern;
mod rewrite;
mod run;
mod semi_persistent;
mod subst;
mod unionfind;
mod util;
Expand Down Expand Up @@ -101,6 +102,7 @@ pub use {
pattern::{ENodeOrVar, Pattern, PatternAst, SearchMatches},
rewrite::{Applier, Condition, ConditionEqual, ConditionalApplier, Rewrite, Searcher},
run::*,
semi_persistent::UndoLog,
subst::{Subst, Var},
util::*,
};
Expand Down
1 change: 1 addition & 0 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ where
/// struct MinSize;
/// impl Analysis<Math> for MinSize {
/// type Data = usize;
/// type UndoLog = ();
/// fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
/// merge_min(to, from)
/// }
Expand Down
Loading
Loading