diff --git a/src/egraph.rs b/src/egraph.rs index 23d90bc2..9612396d 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -852,7 +852,6 @@ impl> EGraph { /// /// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled pub fn add_expr_uncanonical(&mut self, expr: &RecExpr) -> Id { - eprintln!("Adding {:?} directly", expr); self.add_expr_uncanonical_with_reason(expr, ExistsOrReason::Reason(ExistenceReason::Direct)) } @@ -1228,27 +1227,6 @@ impl> EGraph { (self.find(id1), did_union) } - /// Like `union_instantiations`, but assumes that the `from_pat` and substitution - /// is guaranteed to match the egraph already. - /// Using this method makes existence explanations more precise. - pub fn union_instantiations_guaranteed_match( - &mut self, - from_pat: &PatternAst, - to_pat: &PatternAst, - subst: &Subst, - rule_name: impl Into, - ) -> (Id, bool) { - // add the lhs without an existence reason, - // assuming it matches - let id1 = self.add_instantiation_noncanonical(from_pat, subst, None); - // add the rhs, making it equal to the lhs - let id2 = - self.add_instantiation_noncanonical(to_pat, subst, Some(ExistenceReason::EqualTo(id1))); - - let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into()))); - (self.find(id1), did_union) - } - /// Like [`EGraph::union_instantiations`] but assumes that the `from_term` is a /// term that the `rule_name` rule matched. /// diff --git a/src/explain.rs b/src/explain.rs index a9512f11..c362e70a 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -423,6 +423,7 @@ impl Explanation { assert!(has_forward ^ has_backward); if has_forward { + eprintln!("Checking rewrite forward from {:?} to {:?}", current, next); assert!(self.check_rewrite_at(current, next, &rule_table, true)); } else { assert!(self.check_rewrite_at(current, next, &rule_table, false)); @@ -1326,22 +1327,19 @@ impl<'x, L: Language> ExplainNodes<'x, L> { } ExistenceReason::EqualTo(adjacent_id) => { let adjacent_node = &self.explainfind[usize::from(adjacent_id)]; - // The node should be directly adjacent to another node let connection = if node.parent_connection.next == adjacent_id { let mut connection = node.parent_connection.clone(); connection.is_rewrite_forward = !connection.is_rewrite_forward; std::mem::swap(&mut connection.next, &mut connection.current); connection } else { - assert!( - adjacent_node.parent_connection.next == term, - "existence reason between two nodes failed: not directly adjacent." - ); + assert_eq!(node.parent_connection.next, adjacent_id); adjacent_node.parent_connection.clone() }; let adj = self.explain_adjacent(connection, cache, enode_cache, false); let mut exp = self.explain_term_existence(adjacent_id, adj, cache, enode_cache); + exp.push(rest_of_proof); exp } diff --git a/src/test.rs b/src/test.rs index fe98bf26..0b15ef3d 100644 --- a/src/test.rs +++ b/src/test.rs @@ -290,7 +290,7 @@ macro_rules! test_fn { &[$( $goal.parse().unwrap() ),+], None $(.or(Some($check_fn)))?, check, - true $(&& $check_existence_explanations)?, + false $(|| $check_existence_explanations)?, ) }}; } diff --git a/tests/math.rs b/tests/math.rs index 4e3cf820..e4d370ed 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -85,7 +85,7 @@ impl Analysis for ConstantFold { let data = egraph[id].data.clone(); if let Some((c, pat)) = data { if egraph.are_explanations_enabled() { - egraph.union_instantiations_guaranteed_match( + egraph.union_instantiations( &pat, &format!("{}", c).parse().unwrap(), &Default::default(), @@ -227,7 +227,8 @@ egg::test_fn! { egg::test_fn! { #[should_panic(expected = "Could not prove goal 0")] math_fail, rules(), - "(+ x y)" => "(/ x y)" + "(+ x y)" => "(/ x y)", + @existence false } egg::test_fn! {math_simplify_add, rules(), "(+ x (+ x (+ x x)))" => "(* 4 x)" } @@ -235,7 +236,8 @@ egg::test_fn! {math_powers, rules(), "(* (pow 2 x) (pow 2 y))" => "(pow 2 (+ x y egg::test_fn! { math_simplify_const, rules(), - "(+ 1 (- a (* (- 2 1) a)))" => "1" + "(+ 1 (- a (* (- 2 1) a)))" => "1", + @existence false } egg::test_fn! { @@ -249,6 +251,7 @@ egg::test_fn! { 2)))"# => "(/ 1 (sqrt five))" + @existence false } egg::test_fn! { @@ -256,17 +259,19 @@ egg::test_fn! { "(* (+ x 3) (+ x 1))" => "(+ (+ (* x x) (* 4 x)) 3)" + @existence false } -egg::test_fn! {math_diff_same, rules(), "(d x x)" => "1"} +// Existence proofs don't support analysis, so we turn tests for them off +egg::test_fn! {math_diff_same, rules(), "(d x x)" => "1"} egg::test_fn! {math_diff_different, rules(), "(d x y)" => "0"} -egg::test_fn! {math_diff_simple1, rules(), "(d x (+ 1 (* 2 x)))" => "2"} egg::test_fn! {math_diff_simple2, rules(), "(d x (+ 1 (* y x)))" => "y"} egg::test_fn! {math_diff_ln, rules(), "(d x (ln x))" => "(/ 1 x)"} egg::test_fn! { diff_power_simple, rules(), - "(d x (pow x 3))" => "(* 3 (pow x 2))" + "(d x (pow x 3))" => "(* 3 (pow x 2))", + @existence false } egg::test_fn! { @@ -280,11 +285,13 @@ egg::test_fn! { .with_expr(&"(* x (- (* 3 x) 14))".parse().unwrap()), "(d x (- (pow x 3) (* 7 (pow x 2))))" => - "(* x (- (* 3 x) 14))" + "(* x (- (* 3 x) 14))", + @existence false } egg::test_fn! { - integ_one, rules(), "(i 1 x)" => "x" + integ_one, rules(), "(i 1 x)" => "x", + @existence false } egg::test_fn! { diff --git a/tests/math_no_analysis.rs b/tests/math_no_analysis.rs new file mode 100644 index 00000000..4aa84fc3 --- /dev/null +++ b/tests/math_no_analysis.rs @@ -0,0 +1,189 @@ +//! Since existence proofs don't support analysis, +//! we test egg without analysis here. + +use egg::{rewrite as rw, *}; +use ordered_float::NotNan; + +pub type EGraph = egg::EGraph; +pub type Rewrite = egg::Rewrite; + +pub type Constant = NotNan; + +define_language! { + pub enum Math { + "d" = Diff([Id; 2]), + "i" = Integral([Id; 2]), + + "+" = Add([Id; 2]), + "-" = Sub([Id; 2]), + "*" = Mul([Id; 2]), + "/" = Div([Id; 2]), + "pow" = Pow([Id; 2]), + "ln" = Ln(Id), + "sqrt" = Sqrt(Id), + + "sin" = Sin(Id), + "cos" = Cos(Id), + + Constant(Constant), + Symbol(Symbol), + } +} + +// You could use egg::AstSize, but this is useful for debugging, since +// it will really try to get rid of the Diff operator +pub struct MathCostFn; +impl egg::CostFunction for MathCostFn { + type Cost = usize; + fn cost(&mut self, enode: &Math, mut costs: C) -> Self::Cost + where + C: FnMut(Id) -> Self::Cost, + { + let op_cost = match enode { + Math::Diff(..) => 100, + Math::Integral(..) => 100, + _ => 1, + }; + enode.fold(op_cost, |sum, i| sum + costs(i)) + } +} + +#[rustfmt::skip] +pub fn rules() -> Vec { vec![ + rw!("add-1-1"; "(+ 1 1)" => "2"), + rw!("add-0-r"; "(+ ?a 0)" => "?a"), + rw!("add-0-l"; "(+ 0 ?a)" => "?a"), + rw!("add-2-2"; "(+ 2 2)" => "4"), + rw!("add-3-1"; "(+ 3 1)" => "4"), + rw!("sub-0-r"; "(- ?a 0)" => "?a"), + rw!("sub-0-1"; "(- 0 1)" => "-1"), + rw!("sub-1-0"; "(- 1 0)" => "1"), + rw!("sub-1-1"; "(- 1 1)" => "0"), + rw!("add-2-neg1"; "(+ 2 -1)" => "1"), + rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), + rw!("comm-mul"; "(* ?a ?b)" => "(* ?b ?a)"), + rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), + rw!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"), + + rw!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"), + rw!("div-canon"; "(/ ?a ?b)" => "(* ?a (pow ?b -1))"), + // rw!("canon-sub"; "(+ ?a (* -1 ?b))" => "(- ?a ?b)"), + // rw!("canon-div"; "(* ?a (pow ?b -1))" => "(/ ?a ?b)" if is_not_zero("?b")), + + rw!("zero-add"; "(+ ?a 0)" => "?a"), + rw!("zero-mul"; "(* ?a 0)" => "0"), + rw!("one-mul"; "(* ?a 1)" => "?a"), + + rw!("add-zero"; "?a" => "(+ ?a 0)"), + rw!("mul-one"; "?a" => "(* ?a 1)"), + + rw!("cancel-sub"; "(- ?a ?a)" => "0"), + rw!("cancel-div"; "(/ ?a ?a)" => "1"), + + rw!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), + rw!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), + + rw!("pow-mul"; "(* (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (+ ?b ?c))"), + rw!("pow0"; "(pow ?x 0)" => "1"), + rw!("pow1"; "(pow ?x 1)" => "?x"), + rw!("pow2"; "(pow ?x 2)" => "(* ?x ?x)"), + rw!("pow-recip"; "(pow ?x -1)" => "(/ 1 ?x)"), + rw!("recip-mul-div"; "(* ?x (/ 1 ?x))" => "1"), + + rw!("d-variable"; "(d ?x ?x)" => "1"), + rw!("d-constant"; "(d ?x ?c)" => "0"), + + rw!("d-add"; "(d ?x (+ ?a ?b))" => "(+ (d ?x ?a) (d ?x ?b))"), + rw!("d-mul"; "(d ?x (* ?a ?b))" => "(+ (* ?a (d ?x ?b)) (* ?b (d ?x ?a)))"), + + rw!("d-sin"; "(d ?x (sin ?x))" => "(cos ?x)"), + rw!("d-cos"; "(d ?x (cos ?x))" => "(* -1 (sin ?x))"), + + rw!("d-ln"; "(d ?x (ln ?x))" => "(/ 1 ?x)"), + + rw!("d-power"; + "(d ?x (pow ?f ?g))" => + "(* (pow ?f ?g) + (+ (* (d ?x ?f) + (/ ?g ?f)) + (* (d ?x ?g) + (ln ?f))))" + ), + + rw!("i-one"; "(i 1 ?x)" => "?x"), + rw!("i-power-const"; "(i (pow ?x ?c) ?x)" => + "(/ (pow ?x (+ ?c 1)) (+ ?c 1))"), + rw!("i-cos"; "(i (cos ?x) ?x)" => "(sin ?x)"), + rw!("i-sin"; "(i (sin ?x) ?x)" => "(* -1 (cos ?x))"), + rw!("i-sum"; "(i (+ ?f ?g) ?x)" => "(+ (i ?f ?x) (i ?g ?x))"), + rw!("i-dif"; "(i (- ?f ?g) ?x)" => "(- (i ?f ?x) (i ?g ?x))"), + rw!("i-parts"; "(i (* ?a ?b) ?x)" => + "(- (* ?a (i ?b ?x)) (i (* (d ?x ?a) (i ?b ?x)) ?x))"), +]} + +egg::test_fn! { + existence_associate_adds, [ + rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), + rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), + ], + runner = Runner::default() + .with_iter_limit(7) + .with_scheduler(SimpleScheduler), + "(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 (+ 6 7))))))" + => + "(+ 7 (+ 6 (+ 5 (+ 4 (+ 3 (+ 2 1))))))" + @check |r: Runner| assert_eq!(r.egraph.number_of_classes(), 127), + @existence true +} + +egg::test_fn! { + #[should_panic(expected = "Could not prove goal 0")] + existence_fail, rules(), + "(+ x y)" => "(/ x y)", + @existence true +} + +egg::test_fn! {existence_simplify_add, rules(), "(+ x (+ x (+ x x)))" => "(* 4 x)", @existence true } +egg::test_fn! {existence_powers, rules(), "(* (pow 2 x) (pow 2 y))" => "(pow 2 (+ x y))", @existence true} + +egg::test_fn! { + existence_simplify_const, rules(), + "(+ 1 (- a (* (- 2 1) a)))" => "1", + @existence true +} + +egg::test_fn! { + existence_simplify_factor, rules(), + "(* (+ x 3) (+ x 1))" + => + "(+ (+ (* x x) (* 4 x)) 3)" + @existence true +} + +egg::test_fn! {existence_diff_same, rules(), "(d x x)" => "1", @existence true} +egg::test_fn! {existence_diff_different, rules(), "(d x y)" => "0", @existence true} +egg::test_fn! {existence_diff_simple2, rules(), "(d x (+ 1 (* y x)))" => "y", @existence true} +egg::test_fn! {existence_diff_ln, rules(), "(d x (ln x))" => "(/ 1 x)", @existence true} + +egg::test_fn! { + existence_diff_power_simple, rules(), + "(d x (pow x 3))" => "(* 3 (pow x 2))", + @existence true +} + +egg::test_fn! { + existence_integ_one, rules(), "(i 1 x)" => "x", + @existence true +} + +egg::test_fn! { + existence_integ_sin, rules(), "(i (cos x) x)" => "(sin x)", @existence true +} + +egg::test_fn! { + existence_integ_x, rules(), "(i (pow x 1) x)" => "(/ (pow x 2) 2)", @existence true +} + +egg::test_fn! { + existence_integ_part1, rules(), "(i (* x (cos x)) x)" => "(+ (* x (sin x)) (cos x))", @existence true +}