Skip to content

Commit

Permalink
linter to chech ssa form of rhs of ftrans rules
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 30, 2023
1 parent 5f28e16 commit 90fc50d
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 43 deletions.
104 changes: 63 additions & 41 deletions SciLean/Lean/ToSSA.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,50 @@ namespace Lean.Expr
/-- Turns expression `e` into single-static-assigment w.r.t. to free variables `fvars` and all bound variables
Returns expression, newly introduced let bindings and local context where these bindings live
--TODO: add option to do common subexpression elimination i.e. check if let binding with particular value already exists
-/
partial def toSSA.impl (e : Expr) (fvars : Array Expr) : MetaM (Expr × Array Expr × LocalContext) :=
if ¬(e.hasAnyFVar (fun id => fvars.contains (.fvar id))) then
return (e,#[],←getLCtx)
else
match e with
| .app .. => do
let fn := e.getAppFn
let args := e.getAppArgs
goApp fn args fvars 0 #[]
| .lam n t b bi =>
withLocalDecl n bi t fun x => do
partial def toSSA.impl (e : Expr) (fvars : Array Expr) : MetaM (Expr × Array Expr × LocalContext) := do
match e with
| .app .. => do
let fn := e.getAppFn
let args := e.getAppArgs
let infos := (← getFunInfoNArgs fn args.size).paramInfo
goApp fn args infos fvars 0 #[]
| .lam n t b bi =>
withLocalDecl n bi t fun x => do
let b := b.instantiate1 x
let lctx ← getLCtx
let (b', lets, lctx') ← impl b (fvars.push x)
withLCtx lctx' (← getLocalInstances) do
return (← mkLambdaFVars (#[x]++lets) b', #[], lctx)
| .letE n t v b _ => do
let (v', lets, lctx') ← impl v fvars
withLCtx lctx' (← getLocalInstances) do
withLetDecl n t v' fun x => do
let b := b.instantiate1 x
let lctx ← getLCtx
let (b', lets, lctx') ← impl b (fvars.push x)
withLCtx lctx' (← getLocalInstances) do
return (← mkLambdaFVars (#[x]++lets) b', #[], lctx)
| .mdata _ e => impl e fvars
| _ => return (e,#[],←getLCtx)
let (b', lets', lctx'') ← impl b (fvars ++ lets.push x)
withLCtx lctx'' (← getLocalInstances) do
return (b', lets.push x ++ lets', ← getLCtx)
| .mdata _ e => impl e fvars
| _ => return (e,#[],←getLCtx)
where
goApp (fn : Expr) (args : Array Expr) (fvars : Array Expr) (i : Nat) (lets : Array Expr) : MetaM (Expr × Array Expr × LocalContext) := do
if h : i < args.size then
goApp (fn : Expr) (args : Array Expr) (infos : Array ParamInfo) (fvars : Array Expr) (i : Nat) (lets : Array Expr) : MetaM (Expr × Array Expr × LocalContext) := do
if h : i < args.size then do

if h' : i < infos.size then
let info := infos[i]!
if info.isImplicit || info.isInstImplicit then
return ← goApp fn args infos fvars (i+1) lets

let arg := args[i]
let (arg', lets', lctx') ← toSSA.impl arg fvars
withLCtx lctx' (← getLocalInstances) do
if ¬(arg'.hasAnyFVar (fun id => fvars.contains (.fvar id))) && lets'.size = 0 then
goApp fn args fvars (i+1) lets
if arg'.consumeMData.isApp then
withLetDecl Name.anonymous (← inferType arg') arg' fun argVar => do
goApp fn (args.set ⟨i,h⟩ argVar) infos (fvars.push argVar) (i+1) (lets ++ lets'.push argVar)
else
if arg'.isApp then
withLetDecl Name.anonymous (← inferType arg') arg' fun argVar => do
goApp fn (args.set ⟨i,h⟩ argVar) (fvars.push argVar) (i+1) (lets ++ lets'.push argVar)
else
goApp fn (args.set ⟨i,h⟩ arg') fvars (i+i) (lets++lets')
goApp fn (args.set ⟨i,h⟩ arg') infos fvars (i+1) (lets++lets')
else
return (mkAppN fn args, lets, ← getLCtx)

Expand All @@ -59,25 +70,36 @@ def toSSA (e : Expr) (fvars : Array Expr) : MetaM Expr := do
return e''


open Qq
#eval show MetaM Unit from do
-- open Qq Elab Term
-- #eval show TermElabM Unit from do

-- withLocalDeclDQ `x q(Nat) fun x => do

-- let e := q( fun y => $x*y + $x*$x)

-- let e' ← toSSA e #[x]
-- IO.println (← ppExpr e)
-- IO.println ""
-- IO.println (← ppExpr e')
-- IO.println ""


withLocalDeclDQ `x q(Nat) fun x => do
-- withLocalDeclDQ `x q(Nat) fun x => do

let e := q( fun y => $x*y + $x*$x)
-- let e := q( fun y : Nat => (($x*$x*y + $x^2) + $x*y + $x, fun z : Nat => z*y*$x + $x))

let e' ← toSSA e #[x]
IO.println (← ppExpr e)
IO.println ""
IO.println (← ppExpr e')
IO.println ""
-- let e' ← toSSA e #[x]
-- IO.println (← ppExpr e)
-- IO.println ""
-- IO.println (← ppExpr e')


withLocalDeclDQ `x q(Nat) fun x => do
-- withLocalDeclDQ `x q(Nat) fun x => do
-- withLocalDeclDQ `f q(Nat→Nat) fun f => do

let e := q( fun y : Nat => (($x*$x*y + $x^2) + $x*y + $x, fun z : Nat => z*y*$x + $x))
-- let e := q(fun y : Nat => ($f y, fun dy => (($f $x) * $f y * $f dy)))

let e' ← toSSA e #[x]
IO.println (← ppExpr e)
IO.println ""
IO.println (← ppExpr e')
-- let e' ← toSSA e #[x]
-- IO.println (← ppExpr e)
-- IO.println ""
-- IO.println (← ppExpr e')
11 changes: 9 additions & 2 deletions SciLean/Tactic/FTrans/Init.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Mathlib.Data.FunLike.Basic
import SciLean.Util.SorryProof
import SciLean.Lean.MergeMapDeclarationExtension
import SciLean.Lean.Meta.Basic

import SciLean.Lean.ToSSA
import SciLean.Tactic.StructuralInverse
import SciLean.Tactic.AnalyzeConstLambda

Expand All @@ -33,6 +33,7 @@ initialize registerTraceClass `Meta.Tactic.ftrans.discharge
initialize registerTraceClass `Meta.Tactic.ftrans.unify

initialize registerOption `linter.ftransDeclName { defValue := true, descr := "suggests declaration name for ftrans rule" }
initialize registerOption `linter.ftransSsaRhs { defValue := false, descr := "check if right hand side of ftrans rule is in single static asigment form" }
-- initialize registerTraceClass `Meta.Tactic.ftrans.lambda_special_cases

register_simp_attr ftrans_simp
Expand Down Expand Up @@ -269,7 +270,7 @@ initialize funTransRuleAttr : TagAttribute ←
MetaM.run' do
forallTelescope rule λ _ eq => do

let .some (_,lhs,_) := eq.app3? ``Eq
let .some (_,lhs,rhs) := eq.app3? ``Eq
| throwError s!"`{← ppExpr eq}` is not a rewrite rule!"

let .some (transName, _, f) ← getFTrans? lhs
Expand All @@ -282,6 +283,12 @@ To register function transformation call:
```
where <name> is name of the function transformation and <info> is corresponding `FTrans.Info`.
"

if (← getBoolOption `linter.ftransSsaRhs true) then
let rhs' ← rhs.toSSA #[]
if ¬(rhs.eqv rhs') then
logWarning s!"right hand side is not in single static assigment form, expected form:\n{←ppExpr rhs'}"

let data ← analyzeConstLambda f

let suggestedRuleName :=
Expand Down

0 comments on commit 90fc50d

Please sign in to comment.