Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow path compression to be disabled by undo log #11

Merged
merged 4 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Use the [`Dot`] struct to visualize an [`EGraph`](crate::EGraph)
use no_std_compat::prelude::v1::*;
use std::fmt::{self, Debug, Display, Formatter};

use crate::raw::reflect_const::PathCompressT;
use crate::{raw::EGraphResidual, raw::Language};

/**
Expand Down Expand Up @@ -48,16 +49,16 @@ instead of to its own eclass.

[GraphViz]: https://graphviz.gitlab.io/
**/
pub struct Dot<'a, L: Language> {
pub(crate) egraph: &'a EGraphResidual<L>,
pub struct Dot<'a, L: Language, P: PathCompressT> {
pub(crate) egraph: &'a EGraphResidual<L, P>,
/// A list of strings to be output top part of the dot file.
pub config: Vec<String>,
/// Whether or not to anchor the edges in the output.
/// True by default.
pub use_anchors: bool,
}

impl<'a, L> Dot<'a, L>
impl<'a, L, P: PathCompressT> Dot<'a, L, P>
where
L: Language + Display,
{
Expand Down Expand Up @@ -100,7 +101,7 @@ mod std_only {
use std::io::{Error, ErrorKind, Result, Write};
use std::path::Path;

impl<'a, L: Language + Display> Dot<'a, L> {
impl<'a, L: Language + Display, P: PathCompressT> Dot<'a, L, P> {
/// Writes the `Dot` to a .dot file with the given filename.
/// Does _not_ require a `dot` binary.
pub fn to_dot(&self, filename: impl AsRef<Path>) -> Result<()> {
Expand Down Expand Up @@ -177,13 +178,13 @@ mod std_only {
}
}

impl<'a, L: Language> Debug for Dot<'a, L> {
impl<'a, L: Language, P: PathCompressT> Debug for Dot<'a, L, P> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_tuple("Dot").field(self.egraph).finish()
}
}

impl<'a, L> Display for Dot<'a, L>
impl<'a, L, P: PathCompressT> Display for Dot<'a, L, P>
where
L: Language + Display,
{
Expand Down
3 changes: 2 additions & 1 deletion src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use raw::semi_persistent1 as sp;
#[cfg(not(feature = "push-pop-alt"))]
use raw::semi_persistent2 as sp;

use crate::raw::UndoLogPC;
use sp::UndoLog;
type PushInfo = (sp::PushInfo, explain::PushInfo, usize);
/** A data structure to keep track of equalities between expressions.
Expand Down Expand Up @@ -108,7 +109,7 @@ impl<L: Language, N: Analysis<L>> Debug for EGraph<L, N> {
}

impl<L: Language, N: Analysis<L>> Deref for EGraph<L, N> {
type Target = EGraphResidual<L>;
type Target = EGraphResidual<L, <UndoLog as UndoLogPC>::AllowPathCompress>;

#[inline]
fn deref(&self) -> &Self::Target {
Expand Down
4 changes: 2 additions & 2 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::mem;
use std::ops::{Deref, DerefMut};
use std::rc::Rc;

use crate::raw::RawEGraph;
use crate::raw::{RawEGraph, UndoLogT};
use symbolic_expressions::Sexp;

type ProofCost = Saturating<usize>;
Expand Down Expand Up @@ -1094,7 +1094,7 @@ impl<'a, L: Language, X> DerefMut for ExplainWith<'a, L, X> {
}
}

impl<'x, L: Language, D, U> ExplainWith<'x, L, &'x RawEGraph<L, D, U>> {
impl<'x, L: Language, D, U: UndoLogT<L, D>> ExplainWith<'x, L, &'x RawEGraph<L, D, U>> {
pub(crate) fn node(&self, node_id: Id) -> &L {
self.raw.id_to_node(node_id)
}
Expand Down
5 changes: 3 additions & 2 deletions src/explain/semi_persistent.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::explain::{Connection, Explain};
use crate::raw::reflect_const::PathCompressT;
use crate::raw::EGraphResidual;
use crate::{Id, Language};
use no_std_compat::prelude::v1::*;
Expand Down Expand Up @@ -28,11 +29,11 @@ impl<L: Language> Explain<L> {
PushInfo(self.undo_log.as_ref().unwrap().len())
}

pub(crate) fn pop(
pub(crate) fn pop<P: PathCompressT>(
&mut self,
info: PushInfo,
number_of_uncanon_nodes: usize,
egraph: &EGraphResidual<L>,
egraph: &EGraphResidual<L, P>,
) {
for id in self.undo_log.as_mut().unwrap().drain(info.0..).rev() {
let node1 = &mut self.explainfind[usize::from(id)];
Expand Down
5 changes: 4 additions & 1 deletion src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ pub mod semi_persistent2;
mod unionfind;
pub(crate) mod util;

/// Types and traits for specify whether path compression is supported
pub mod reflect_const;

pub use eclass::RawEClass;
pub use egraph::{EGraphResidual, RawEGraph, UnionInfo};
pub use language::*;
use semi_persistent::Sealed;
pub use semi_persistent::{AsUnwrap, UndoLogT};
pub use semi_persistent::{AsUnwrap, UndoLogPC, UndoLogT};
pub use unionfind::UnionFind;
58 changes: 38 additions & 20 deletions src/raw/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::{
};

use crate::raw::dhashmap::*;
use crate::raw::reflect_const::{PathCompress, PathCompressT};
use crate::raw::UndoLogT;
use default_vec2::BitSet;
#[cfg(feature = "serde-1")]
Expand All @@ -35,8 +36,12 @@ impl<'a> IntoIterator for Parents<'a> {
/// See [`RawEGraph::classes_mut`], [`RawEGraph::get_class_mut`]
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct EGraphResidual<L: Language> {
pub(super) unionfind: UnionFind,
#[cfg_attr(
feature = "serde-1",
serde(bound(serialize = "L: Serialize", deserialize = "L: Deserialize<'de>"))
)]
pub struct EGraphResidual<L: Language, P: PathCompressT = PathCompress<true>> {
pub(super) unionfind: UnionFind<P>,
/// Stores the original node represented by each non-canonical id
pub(super) nodes: Vec<L>,
/// Stores each enode's `Id`, not the `Id` of the eclass.
Expand All @@ -46,7 +51,7 @@ pub struct EGraphResidual<L: Language> {
pub(super) memo: DHashMap<L, Id>,
}

impl<L: Language> EGraphResidual<L> {
impl<L: Language, P: PathCompressT> EGraphResidual<L, P> {
/// Pick a representative term for a given Id.
///
/// Calling this function on an uncanonical `Id` returns a representative based on how it
Expand Down Expand Up @@ -308,7 +313,7 @@ impl<L: Language> EGraphResidual<L> {
}

/// Creates a [`Dot`] to visualize this egraph. See [`Dot`].
pub fn dot(&self) -> Dot<'_, L> {
pub fn dot(&self) -> Dot<'_, L, P> {
Dot {
egraph: self,
config: vec![],
Expand All @@ -317,8 +322,15 @@ impl<L: Language> EGraphResidual<L> {
}
}

impl<L: Language> EGraphResidual<L, PathCompress<false>> {
/// Return the direct parent from the union find without path compression
pub fn find_direct_parent(&self, id: Id) -> Id {
self.unionfind.parent_id(id)
}
}

// manual debug impl to avoid L: Language bound on EGraph defn
impl<L: Language> Debug for EGraphResidual<L> {
impl<L: Language, P: PathCompressT> Debug for EGraphResidual<L, P> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("EGraphResidual")
.field("unionfind", &self.unionfind)
Expand Down Expand Up @@ -356,9 +368,9 @@ to properly handle this data
**/
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub struct RawEGraph<L: Language, D, U = ()> {
pub struct RawEGraph<L: Language, D, U: UndoLogT<L, D> = ()> {
#[cfg_attr(feature = "serde-1", serde(flatten))]
pub(super) residual: EGraphResidual<L>,
pub(super) residual: EGraphResidual<L, U::AllowPathCompress>,
/// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode,
/// not the canonical id of the eclass.
pub(super) pending: Vec<Id>,
Expand All @@ -368,7 +380,7 @@ pub struct RawEGraph<L: Language, D, U = ()> {
pub(super) undo_log: U,
}

impl<L: Language, D, U: Default> Default for RawEGraph<L, D, U> {
impl<L: Language, D, U: Default + UndoLogT<L, D>> Default for RawEGraph<L, D, U> {
fn default() -> Self {
let residual = EGraphResidual {
unionfind: Default::default(),
Expand All @@ -385,24 +397,24 @@ impl<L: Language, D, U: Default> Default for RawEGraph<L, D, U> {
}
}

impl<L: Language, D, U> Deref for RawEGraph<L, D, U> {
type Target = EGraphResidual<L>;
impl<L: Language, D, U: UndoLogT<L, D>> Deref for RawEGraph<L, D, U> {
type Target = EGraphResidual<L, U::AllowPathCompress>;

#[inline]
fn deref(&self) -> &Self::Target {
&self.residual
}
}

impl<L: Language, D, U> DerefMut for RawEGraph<L, D, U> {
impl<L: Language, D, U: UndoLogT<L, D>> DerefMut for RawEGraph<L, D, U> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.residual
}
}

// manual debug impl to avoid L: Language bound on EGraph defn
impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
impl<L: Language, D: Debug, U: UndoLogT<L, D>> Debug for RawEGraph<L, D, U> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let classes: BTreeMap<_, _> = self
.classes
Expand All @@ -428,7 +440,7 @@ impl<L: Language, D: Debug, U> Debug for RawEGraph<L, D, U> {
}
}

impl<L: Language, D, U> RawEGraph<L, D, U> {
impl<L: Language, D, U: UndoLogT<L, D>> RawEGraph<L, D, U> {
/// Returns an iterator over the eclasses in the egraph.
pub fn classes(&self) -> impl ExactSizeIterator<Item = &RawEClass<D>> {
self.classes.iter()
Expand All @@ -440,7 +452,7 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
&mut self,
) -> (
impl ExactSizeIterator<Item = &mut RawEClass<D>>,
&mut EGraphResidual<L>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let iter = self.classes.iter_mut();
(iter, &mut self.residual)
Expand Down Expand Up @@ -470,7 +482,10 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
pub fn get_class_mut<I: BorrowMut<Id>>(
&mut self,
mut id: I,
) -> (&mut RawEClass<D>, &mut EGraphResidual<L>) {
) -> (
&mut RawEClass<D>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let id = id.borrow_mut();
let (nid, cid) = self.unionfind.find_mut_full(*id);
*id = nid;
Expand All @@ -481,7 +496,10 @@ impl<L: Language, D, U> RawEGraph<L, D, U> {
pub fn get_class_mut_with_cannon(
&mut self,
id: Id,
) -> (&mut RawEClass<D>, &mut EGraphResidual<L>) {
) -> (
&mut RawEClass<D>,
&mut EGraphResidual<L, U::AllowPathCompress>,
) {
let cid = self.unionfind.find_canon(id);
(&mut self.classes[cid.idx()], &mut self.residual)
}
Expand Down Expand Up @@ -900,9 +918,9 @@ impl<L: Language, U: UndoLogT<L, ()>> RawEGraph<L, (), U> {
}
}

struct EGraphUncanonicalDump<'a, L: Language>(&'a EGraphResidual<L>);
struct EGraphUncanonicalDump<'a, L: Language, P: PathCompressT>(&'a EGraphResidual<L, P>);

impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> {
impl<'a, L: Language, P: PathCompressT> Debug for EGraphUncanonicalDump<'a, L, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for (id, node) in self.0.uncanonical_nodes() {
writeln!(f, "{}: {:?} (root={})", id, node, self.0.find(id))?
Expand All @@ -911,9 +929,9 @@ impl<'a, L: Language> Debug for EGraphUncanonicalDump<'a, L> {
}
}

struct EGraphDump<'a, L: Language, D, U>(&'a RawEGraph<L, D, U>);
struct EGraphDump<'a, L: Language, D, U: UndoLogT<L, D>>(&'a RawEGraph<L, D, U>);

impl<'a, L: Language, D: Debug, U> Debug for EGraphDump<'a, L, D, U> {
impl<'a, L: Language, D: Debug, U: UndoLogT<L, D>> Debug for EGraphDump<'a, L, D, U> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut ids: Vec<Id> = self.0.classes().map(|c| c.id).collect();
ids.sort();
Expand Down
14 changes: 14 additions & 0 deletions src/raw/reflect_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#![allow(missing_docs)]
use core::fmt::Debug;

#[derive(Copy, Clone, Eq, PartialEq, Default, Debug)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct PathCompress<const B: bool>;

impl<const B: bool> PathCompressT for PathCompress<B> {
const PATH_COMPRESS: bool = B;
}

pub trait PathCompressT: Copy + Clone + Eq + PartialEq + Default + Debug {
const PATH_COMPRESS: bool;
}
18 changes: 17 additions & 1 deletion src/raw/semi_persistent.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::raw::reflect_const::{PathCompress, PathCompressT};
use crate::raw::{Language, RawEGraph};
use crate::{ClassId, Id};
use no_std_compat::prelude::v1::*;
Expand All @@ -9,7 +10,14 @@ impl<U: Sealed> Sealed for Option<U> {}

/// A sealed trait for types that can be used for `push`/`pop` APIs
/// It is trivially implemented for `()`
pub trait UndoLogT<L, D>: Default + Debug + Sealed {
pub trait UndoLogPC {
/// When this type of undo log allows for path compression
type AllowPathCompress: PathCompressT;
}

/// A sealed trait for types that can be used for `push`/`pop` APIs
/// It is trivially implemented for `()`
pub trait UndoLogT<L, D>: Default + Debug + Sealed + UndoLogPC {
#[doc(hidden)]
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId);

Expand All @@ -32,6 +40,10 @@ pub trait UndoLogT<L, D>: Default + Debug + Sealed {
fn is_enabled(&self) -> bool;
}

impl UndoLogPC for () {
type AllowPathCompress = PathCompress<true>;
}

impl<L, D> UndoLogT<L, D> for () {
#[inline]
fn add_node(&mut self, _: &L, _: &[Id], _: Id, _: ClassId) {}
Expand All @@ -54,6 +66,10 @@ impl<L, D> UndoLogT<L, D> for () {
}
}

impl<U: UndoLogPC> UndoLogPC for Option<U> {
type AllowPathCompress = U::AllowPathCompress;
}

impl<L, D, U: UndoLogT<L, D>> UndoLogT<L, D> for Option<U> {
#[inline]
fn add_node(&mut self, node: &L, canon_children: &[Id], node_id: Id, cid: ClassId) {
Expand Down
Loading
Loading