Skip to content

Commit

Permalink
Fix and optimize AC norm. (#98)
Browse files Browse the repository at this point in the history
* Fix and optimize AC norm.

* Fix error message.
  • Loading branch information
abdoo8080 authored May 23, 2024
1 parent 8e32551 commit b28f919
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 25 deletions.
39 changes: 14 additions & 25 deletions Smt/Reconstruct/Builtin/AC.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import Lean

theorem Eq.same_root (hac : a = c) (hbc : b = c) : a = b := hac ▸ hbc ▸ rfl

namespace Lean.Meta.AC

open Lean.Elab Tactic
Expand All @@ -11,36 +13,23 @@ def traceACRflTop (r : Except Exception Unit) : MetaM MessageData :=

/-- Similar to `rewriteUnnormalized`, but rewrite is only applied at the top level. -/
def rewriteUnnormalizedTop (mv : MVarId) : MetaM Unit := withTraceNode `smt.reconstruct.acRflTop traceACRflTop do
let some (_, l, r) := (← mv.getType).eq?
let some (α, l, r) := (← mv.getType).eq?
| throwError "[ac_rfl_top] expected a top level equality with AC operator on lhs and/or rhs, got {← mv.getType}"
let lvl ← Meta.getLevel α
let (nl, pl) ← normalize l
let (nr, pr) ← normalize r
if nl == r then
let some pl := pl | throwError "[ac_rfl_top] expected {l} to have an AC operator"
let hl ← Meta.mkFreshExprMVar (← mkEq l nl)
hl.mvarId!.assign pl
let rl ← mv.rewrite (← mv.getType) hl false { occs := .pos [1] }
let mv ← mv.replaceTargetEq rl.eNew rl.eqProof
mv.refl
mv.assign pl
else if l == nr then
let some pr := pr | throwError "[ac_rfl_top] expected {r} to have an AC operator"
let hr ← Meta.mkFreshExprMVar (← mkEq r nr)
hr.mvarId!.assign pr
let rr ← mv.rewrite (← mv.getType) hr false { occs := .pos [1] }
let mv ← mv.replaceTargetEq rr.eNew rr.eqProof
mv.refl
else
mv.assign (mkApp4 (.const ``Eq.symm [lvl]) α r l pr)
else if nl == nr then
let some pl := pl | throwError "[ac_rfl_top] expected {l} to have an AC operator"
let hl ← Meta.mkFreshExprMVar (← mkEq l nl)
hl.mvarId!.assign pl
let rl ← mv.rewrite (← mv.getType) hl false { occs := .pos [1] }
let mv ← mv.replaceTargetEq rl.eNew rl.eqProof
let some pr := pr | throwError "[ac_rfl_top] expected {r} to have an AC operator"
let hr ← Meta.mkFreshExprMVar (← mkEq r nr)
hr.mvarId!.assign pr
let rr ← mv.rewrite (← mv.getType) hr false { occs := .pos [1] }
let mv ← mv.replaceTargetEq rr.eNew rr.eqProof
mv.refl
mv.assign (mkApp6 (.const ``Eq.same_root [lvl]) α l nl r pl pr)
else
throwError "[ac_rfl_top] expected {l} and {r} to have the same AC operator"
where
normalize (e : Expr) : MetaM (Expr × Option Expr) := do
let bin op l r := e | return (e, none)
Expand All @@ -52,10 +41,10 @@ syntax (name := ac_rfl_top) "ac_rfl_top" : tactic

@[tactic ac_rfl_top] def evalacRflTop : Lean.Elab.Tactic.Tactic := fun _ => do
let goal ← getMainGoal
goal.withContext <| rewriteUnnormalizedTop goal
goal.withContext (rewriteUnnormalizedTop goal)

private instance : Std.Associative (α := Nat) (.+.) := ⟨Nat.add_assoc⟩
private instance : Std.Commutative (α := Nat) (.+.) := ⟨Nat.add_comm⟩
local instance : Std.Associative (α := Nat) (· + ·) := ⟨Nat.add_assoc⟩
local instance : Std.Commutative (α := Nat) (· + ·) := ⟨Nat.add_comm⟩

example (a b c d : Nat) : 2 * (a + b + c + d) = 2 * (d + (b + c) + a) := by
try ac_rfl_top
Expand All @@ -66,6 +55,6 @@ example (a b c d : Nat) : a + b + c + d + (2 * (a + b)) = d + (b + c) + a + (2 *
ac_rfl

example (a b c d : Nat) : a + b + c + d + (a + b) = d + (b + c) + a + (b + a) := by
ac_rfl
ac_rfl_top

end Lean.Meta.AC
1 change: 1 addition & 0 deletions Test/linarith.expected
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ Test/linarith.lean:117:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:129:60: warning: unused variable `h3` [linter.unusedVariables]
Test/linarith.lean:136:9: warning: unused variable `a` [linter.unusedVariables]
Test/linarith.lean:136:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:157:2: error: [arithPolyNorm]: could not prove x - y = -x + y
Test/linarith.lean:179:34: warning: unused variable `z` [linter.unusedVariables]
Test/linarith.lean:180:5: warning: unused variable `h5` [linter.unusedVariables]

0 comments on commit b28f919

Please sign in to comment.