Skip to content

Commit

Permalink
Only track fvars needed by translation.
Browse files Browse the repository at this point in the history
  • Loading branch information
abdoo8080 committed May 25, 2024
1 parent 73e3fdf commit 076045f
Show file tree
Hide file tree
Showing 12 changed files with 49 additions and 32 deletions.
37 changes: 25 additions & 12 deletions Smt/Translate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ structure TranslationM.State where
/-- Constants that the translated result depends on. We propagate these upwards during translation
in order to build a dependency graph. The value is reset at the `translateExpr` entry point. -/
depConstants : NameSet := .empty
/-- Memoizes `applyTranslators?` calls together with what they add to `depConstants`. -/
cache : HashMap Expr (Option (Term × NameSet)) := .empty
/-- Free variables that the translated result depends on. We propagate these upwards during translation
in order to build a dependency graph. The value is reset at the `translateExpr` entry point. -/
depFVars : FVarIdSet := .empty
/-- Memoizes `applyTranslators?` calls together with what they add to `depConstants` and `depFVars`. -/
cache : HashMap Expr (Option (Term × NameSet × FVarIdSet)) := .empty

abbrev TranslationM := StateT TranslationM.State MetaM

Expand Down Expand Up @@ -55,18 +58,23 @@ opaque getTranslators : MetaM (List (Translator × Name))
/-- Return a cached translation of `e` if found, otherwise run `k e` and cache the result. -/
def withCache (k : Translator) (e : Expr) : TranslationM (Option Term) := do
match (← get).cache.find? e with
| some (some (tm, deps)) =>
modify fun st => { st with depConstants := st.depConstants.union deps }
| some (some (tm, depConsts, depFVars)) =>
modify fun st => { st with
depConstants := st.depConstants.union depConsts
depFVars := st.depFVars.union depFVars
}
return some tm
| some none =>
return none
| none =>
let depConstantsBefore := (← get).depConstants
let depFVarsBefore := (← get).depFVars
modify fun st => { st with depConstants := .empty }
let ret? ← k e
modify fun st => { st with
depConstants := st.depConstants.union depConstantsBefore
cache := st.cache.insert e <| ret?.map ((·, st.depConstants))
depFVars := st.depFVars.union depFVarsBefore
cache := st.cache.insert e <| ret?.map ((·, st.depConstants, st.depFVars))
}
return ret?

Expand All @@ -81,7 +89,7 @@ partial def applyTranslators! (e : Expr) : TranslationM Term := do
expression and if one succeeds, its result is returned. Otherwise, `e` is split into subexpressions
which are then recursively translated and put together into an SMT-LIB term. The traversal proceeds
in a top-down, depth-first order. -/
partial def applyTranslators? : Translator := withCache fun e => do
partial def applyTranslators? : Translator := fun e => do
let ts ← getTranslators
go ts e
where
Expand All @@ -97,40 +105,45 @@ partial def applyTranslators? : Translator := withCache fun e => do
match e with
| fvar fv =>
let ld ← fv.getDecl
modify fun st => { st with depFVars := st.depFVars.insert fv }
return symbolT ld.userName.toString
| const nm _ =>
modify fun st => { st with depConstants := st.depConstants.insert nm }
return symbolT nm.toString
| app f e => return appT (← applyTranslators! f) (← applyTranslators! e)
| lam .. => throwError "cannot translate {e}, SMT-LIB does not support lambdas"
| forallE n t b bi =>
let tmB ← Meta.withLocalDecl n bi t (fun x => applyTranslators! <| b.instantiate #[x])
let tmB ← Meta.withLocalDecl n bi t (translateBody b)
if !b.hasLooseBVars /- not a dependent arrow -/ then
return arrowT (← applyTranslators! t) tmB
else
return forallT n.toString (← applyTranslators! t) tmB
| letE n t v b _ =>
let tmB ← Meta.withLetDecl n t v (fun x => applyTranslators! <| b.instantiate #[x])
let tmB ← Meta.withLetDecl n t v (translateBody b)
return letT n.toString (← applyTranslators! v) tmB
| mdata _ e => go ts e
| e => throwError "cannot translate {e}"
translateBody (b : Expr) (x : Expr) : TranslationM Term := do
let tmB ← applyTranslators! (b.instantiate #[x])
modify fun s => { s with depFVars := s.depFVars.erase x.fvarId! }
return tmB

end

def traceTranslation (e : Expr) (e' : Except ε (Term × NameSet)) : TranslationM MessageData :=
def traceTranslation (e : Expr) (e' : Except ε (Term × NameSet × FVarIdSet)) : TranslationM MessageData :=
return m!"{e} ↦ " ++ match e' with
| .ok (e', _) => m!"{e'}"
| .error _ => m!"{bombEmoji}"

/-- Processes `e` by running it through all the registered `Translator`s.
Returns the resulting SMT-LIB term and set of dependencies. -/
def translateExpr (e : Expr) : TranslationM (Term × NameSet) :=
def translateExpr (e : Expr) : TranslationM (Term × NameSet × FVarIdSet) :=
withTraceNode `smt.translate (traceTranslation e ·) do
modify fun st => { st with depConstants := .empty }
modify fun st => { st with depConstants := .empty, depFVars := .empty }
trace[smt.translate.expr] "before: {e}"
let tm ← applyTranslators! e
trace[smt.translate.expr] "translated: {tm}"
return (tm, (← get).depConstants)
return (tm, (← get).depConstants, (← get).depFVars)

def translateExpr' (e : Expr) : TranslationM Term :=
Prod.fst <$> translateExpr e
Expand Down
14 changes: 9 additions & 5 deletions Smt/Translate/Nat.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@ open Translator Term
@[smt_translate] def translateForalls : Translator
| e@(forallE n t@(const ``Nat _) b bi) => do
if e.isArrow then return none
Meta.withLocalDecl n bi t fun x => do
let tmB ← applyTranslators! (b.instantiate #[x])
let tmGeqZero := Term.mkApp2 (symbolT ">=") (symbolT n.toString) (literalT "0")
let tmProp := Term.mkApp2 (symbolT "=>") tmGeqZero tmB
return forallT n.toString (symbolT "Int") tmProp
let tmB ← Meta.withLocalDecl n bi t (translateBody b)
let tmGeqZero := Term.mkApp2 (symbolT ">=") (symbolT n.toString) (literalT "0")
let tmProp := Term.mkApp2 (symbolT "=>") tmGeqZero tmB
return forallT n.toString (symbolT "Int") tmProp
| _ => return none
where
translateBody (b : Expr) (x : Expr) : TranslationM Term := do
let tmB ← applyTranslators! (b.instantiate #[x])
modify fun s => { s with depFVars := s.depFVars.erase x.fvarId! }
return tmB

end Smt.Translate.Nat
1 change: 1 addition & 0 deletions Smt/Translate/Prop.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ open Translator Term
Meta.withLocalDecl n bi t fun x => do
let tmT ← applyTranslators! t
let tmB ← applyTranslators! (b.instantiate #[x])
modify fun s => { s with depFVars := s.depFVars.erase x.fvarId! }
return existsT n.toString tmT tmB
| _ => return none

Expand Down
8 changes: 3 additions & 5 deletions Smt/Translate/Query.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,11 @@ def addDependency (e e' : Expr) : QueryBuilderM Unit :=
/-- Translate an expression and compute its (non-SMT-builtin) dependencies.
When `fvarDeps = false`, we filter out dependencies on fvars. -/
def translateAndFindDeps (e : Expr) (fvarDeps := true) : QueryBuilderM (Term × Array Expr) := do
let (tm, deps) ← Translator.translateExpr e
let unknownConsts := deps.toArray.filterMap fun nm =>
let (tm, depConsts, depFVars) ← Translator.translateExpr e
let unknownConsts := depConsts.toArray.filterMap fun nm =>
if Util.smtConsts.contains nm.toString then none else some (mkConst nm)
if fvarDeps then
let st : CollectFVars.State := {}
let st := collectFVars st e
let fvs := st.fvarIds.map mkFVar
let fvs := depFVars.toArray.map mkFVar
return (tm, fvs ++ unknownConsts)
else
return (tm, unknownConsts)
Expand Down
2 changes: 1 addition & 1 deletion Test/BitVec/XorComm.expected
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ Test/BitVec/XorComm.lean:3:8: warning: declaration uses 'sorry'
goal: x ^^^ y = y ^^^ x

query:
(declare-const y (_ BitVec 8))
(declare-const x (_ BitVec 8))
(declare-const y (_ BitVec 8))
(assert (distinct (bvxor x y) (bvxor y x)))
(check-sat)
Test/BitVec/XorComm.lean:7:8: warning: declaration uses 'sorry'
2 changes: 1 addition & 1 deletion Test/Bool/Cong.expected
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
goal: (p == q) = true → (f p == f q) = true

query:
(declare-const p Bool)
(declare-const q Bool)
(declare-const p Bool)
(declare-fun f (Bool) Bool)
(assert (not (=> (= (= p q) true) (= (= (f p) (f q)) true))))
(check-sat)
2 changes: 1 addition & 1 deletion Test/Bool/Trans.expected
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ goal: (p == q) = true → (q == r) = true → (p == r) = true

query:
(declare-const q Bool)
(declare-const p Bool)
(declare-const r Bool)
(declare-const p Bool)
(assert (not (=> (= (= p q) true) (=> (= (= q r) true) (= (= p r) true)))))
(check-sat)
2 changes: 1 addition & 1 deletion Test/Int/Binders.expected
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ goal: mismatchNamesAdd a b = mismatchNamesAdd b a

query:
(define-fun mismatchNamesAdd ((a Int) (b Int)) Int (+ a b))
(declare-const b Int)
(declare-const a Int)
(declare-const b Int)
(assert (distinct (mismatchNamesAdd a b) (mismatchNamesAdd b a)))
(check-sat)
Test/Int/Binders.lean:25:0: warning: declaration uses 'sorry'
2 changes: 1 addition & 1 deletion Test/Nat/Cong.expected
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ goal: x = y → f x = f y
query:
(define-sort Nat () Int)
(declare-fun f (Nat) Nat)
(assert (forall ((_uniq.1099 Nat)) (=> (>= _uniq.1099 0) (>= (f _uniq.1099) 0))))
(assert (forall ((_uniq.1765 Nat)) (=> (>= _uniq.1765 0) (>= (f _uniq.1765) 0))))
(declare-const x Nat)
(assert (>= x 0))
(declare-const y Nat)
Expand Down
8 changes: 4 additions & 4 deletions Test/Nat/Max.expected
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ goal: x ≤ max' x y ∧ y ≤ max' x y
query:
(define-sort Nat () Int)
(declare-fun |Nat.max'| (Nat Nat) Nat)
(assert (forall ((_uniq.1345 Nat)) (=> (>= _uniq.1345 0) (forall ((_uniq.1346 Nat)) (=> (>= _uniq.1346 0) (>= (|Nat.max'| _uniq.1345 _uniq.1346) 0))))))
(assert (forall ((_uniq.2574 Nat)) (=> (>= _uniq.2574 0) (forall ((_uniq.2575 Nat)) (=> (>= _uniq.2575 0) (>= (|Nat.max'| _uniq.2574 _uniq.2575) 0))))))
(declare-const y Nat)
(assert (>= y 0))
(declare-const x Nat)
Expand All @@ -15,10 +15,10 @@ goal: x ≤ max' x y ∧ y ≤ max' x y
query:
(define-sort Nat () Int)
(define-fun |Nat.max'| ((x Nat) (y Nat)) Nat (ite (<= x y) y x))
(assert (forall ((_uniq.3382 Nat)) (=> (>= _uniq.3382 0) (forall ((_uniq.3383 Nat)) (=> (>= _uniq.3383 0) (>= (|Nat.max'| _uniq.3382 _uniq.3383) 0))))))
(declare-const y Nat)
(assert (>= y 0))
(assert (forall ((_uniq.6062 Nat)) (=> (>= _uniq.6062 0) (forall ((_uniq.6063 Nat)) (=> (>= _uniq.6063 0) (>= (|Nat.max'| _uniq.6062 _uniq.6063) 0))))))
(declare-const x Nat)
(assert (>= x 0))
(declare-const y Nat)
(assert (>= y 0))
(assert (not (and (<= x (|Nat.max'| x y)) (<= y (|Nat.max'| x y)))))
(check-sat)
2 changes: 1 addition & 1 deletion Test/Nat/Sum'.expected
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ query:
(declare-const n Nat)
(assert (>= n 0))
(define-fun-rec sum ((n Nat)) Nat (ite (= n 0) 0 (+ n (sum (ite (<= 1 n) (- n 1) 0)))))
(assert (forall ((_uniq.22096 Nat)) (=> (>= _uniq.22096 0) (>= (sum _uniq.22096) 0))))
(assert (forall ((_uniq.23999 Nat)) (=> (>= _uniq.23999 0) (>= (sum _uniq.23999) 0))))
(assert (= (sum n) (div (* n (+ n 1)) 2)))
(assert (distinct (sum (+ n 1)) (div (* (+ n 1) (+ (+ n 1) 1)) 2)))
(check-sat)
1 change: 1 addition & 0 deletions Test/linarith.expected
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ Test/linarith.lean:117:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:129:60: warning: unused variable `h3` [linter.unusedVariables]
Test/linarith.lean:136:9: warning: unused variable `a` [linter.unusedVariables]
Test/linarith.lean:136:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:157:2: error: [arithPolyNorm]: could not prove x - y = -x + y
Test/linarith.lean:179:34: warning: unused variable `z` [linter.unusedVariables]
Test/linarith.lean:180:5: warning: unused variable `h5` [linter.unusedVariables]

0 comments on commit 076045f

Please sign in to comment.