Skip to content

Commit

Permalink
deriv notation that calls ftrans
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Jan 3, 2024
1 parent 73a2181 commit 77ce441
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 43 deletions.
21 changes: 19 additions & 2 deletions SciLean/Core/Notation/FwdCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,43 @@ import SciLean.Core.FunctionTransformations.FwdCDeriv
namespace SciLean.NotationOverField


scoped syntax "∂> " term:66 : term
scoped syntax "∂> " term+ : term
scoped syntax "∂> " diffBinder ", " term:66 : term
scoped syntax "∂> " "(" diffBinder ")" ", " term:66 : term

scoped syntax "∂>! " term+ : term
scoped syntax "∂>! " diffBinder ", " term:66 : term
scoped syntax "∂>! " "(" diffBinder ")" ", " term:66 : term

open Lean Elab Term Meta in
elab_rules : term
| `(∂> $f $xs*) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(fwdCDeriv $K $f $xs*)) none

| `(∂> $f) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(fwdCDeriv $K $f)) none


macro_rules
| `(∂> $f $xs*) => `((∂> $f) $xs*)
| `(∂> $x:ident, $b) => `(∂> (fun $x => $b))
| `(∂> $x:ident := $val:term, $b) => `(∂> (fun $x => $b) $val)
| `(∂> $x:ident : $type:term, $b) => `(∂> fun $x : $type => $b)
| `(∂> $x:ident := $val:term ; $dir:term, $b) => `(∂> (fun $x => $b) $val $dir)
| `(∂> ($b:diffBinder), $f) => `(∂> $b, $f)


macro_rules
| `(∂>! $f $xs*) => `((∂> $f $xs*) rewrite_by ftrans; ftrans; ftrans)
| `(∂>! $f) => `((∂> $f) rewrite_by ftrans; ftrans; ftrans)
| `(∂>! $x:ident, $b) => `(∂>! (fun $x => $b))
| `(∂>! $x:ident := $val:term, $b) => `(∂>! (fun $x => $b) $val)
| `(∂>! $x:ident : $type:term, $b) => `(∂>! fun $x : $type => $b)
| `(∂>! $x:ident := $val:term ; $dir:term, $b) => `(∂>! (fun $x => $b) $val $dir)
| `(∂>! ($b:diffBinder), $f) => `(∂>! $b, $f)


@[app_unexpander fwdCDeriv] def unexpandFwdCDeriv : Lean.PrettyPrinter.Unexpander

| `($(_) $_ $f:term $x $dx) =>
Expand Down
40 changes: 27 additions & 13 deletions SciLean/Core/Notation/Gradient.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,39 @@ import SciLean.Core.Notation.CDeriv

namespace SciLean.NotationOverField

scoped syntax (name:=gradNotation1) "∇ " term:66 : term
scoped syntax (name:=gradNotation1) "∇ " term+ : term
scoped syntax "∇ " diffBinder ", " term:66 : term
scoped syntax "∇ " "(" diffBinder ")" ", " term:66 : term
scoped syntax "∇! " term:66 : term

scoped syntax "∇! " term+ : term
scoped syntax "∇! " diffBinder ", " term:66 : term
scoped syntax "∇! " "(" diffBinder ")" ", " term:66 : term


open Lean Elab Term Meta in
elab_rules (kind:=gradNotation1) : term
| `(∇ $f $x $xs*) => do
let K := mkIdent (← currentFieldName.get)
let KExpr ← elabTerm (← `($K)) none
let X ← inferType (← elabTerm x none)
let Y ← mkFreshTypeMVar
let XY ← mkArrow X Y
-- Y might also be infered by the function `f`
let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false
let .some (_,Y) := (← inferType fExpr).arrow?
| return ← throwUnsupportedSyntax
if (← isDefEq KExpr Y) then
elabTerm (← `(scalarGradient $K $f $x $xs*)) none false
else
elabTerm (← `(gradient $K $f $x $xs*)) none false

| `(∇ $f) => do
let K := mkIdent (← currentFieldName.get)
let X ← mkFreshTypeMVar
let Y ← mkFreshTypeMVar
let XY ← mkArrow X Y
let KExpr ← elabTerm (← `($K)) none
let fExpr ← elabTerm f none
let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false
if let .some (_,Y) := (← inferType fExpr).arrow? then
if (← isDefEq KExpr Y) then
elabTerm (← `(scalarGradient $K $f)) none false
Expand All @@ -27,24 +46,19 @@ elab_rules (kind:=gradNotation1) : term
else
throwUnsupportedSyntax

-- open Lean Elab Term Meta in
-- elab_rules (kind:=gradNotation1) : term
-- | `(∇ $x:ident := $val:term; $codir:term, $b) => do
-- let K := mkIdent (← currentFieldName.get)
-- elabTerm (← `(gradient $K (fun $x => $b) $val $codir)) none false

macro_rules
| `(∇ $x:ident, $f) => `(∇ fun $x => $f)
| `(∇ $x:ident : $type:term, $f) => `(∇ fun $x : $type => $f)
| `(∇ $x:ident := $val:term, $f) => `((∇ fun $x => $f) $val)
| `(∇ $x:ident := $val:term, $f) => `(∇ (fun $x => $f) $val)
| `(∇ ($b:diffBinder), $f) => `(∇ $b, $f)
| `(∇! $f) => `((∇ $f) rewrite_by autodiff)

macro_rules
| `(∇! $f) => `((∇ $f) rewrite_by ftrans; ftrans; ftrans)
| `(∇! $x:ident, $f) => `(∇! fun $x => $f)
| `(∇! $x:ident : $type:term, $f) => `(∇! fun $x : $type => $f)
| `(∇! $x:ident := $val:term, $f) => `((∇! fun $x => $f) $val)
| `(∇! $x:ident := $val:term, $f) => `(∇! (fun $x => $f) $val)
| `(∇! ($b:diffBinder), $f) => `(∇! $b, $f)


@[app_unexpander gradient] def unexpandGradient : Lean.PrettyPrinter.Unexpander

| `($(_) $_ $f:term $x $dy $z $zs*) =>
Expand Down
20 changes: 17 additions & 3 deletions SciLean/Core/Notation/RevCDeriv.lean
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
import SciLean.Core.Notation.CDeriv
import SciLean.Core.FunctionTransformations.RevCDeriv


--------------------------------------------------------------------------------
-- Notation -------------------------------------------------------------------
--------------------------------------------------------------------------------

namespace SciLean.NotationOverField


scoped syntax "<∂ " term:66 : term
scoped syntax "<∂ " term+ : term
scoped syntax "<∂ " diffBinder ", " term:66 : term
scoped syntax "<∂ " "(" diffBinder ")" ", " term:66 : term

scoped syntax "<∂! " term+ : term
scoped syntax "<∂! " diffBinder ", " term:66 : term
scoped syntax "<∂! " "(" diffBinder ")" ", " term:66 : term

open Lean Elab Term Meta in
elab_rules : term
| `(<∂ $f $xs*) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(revCDeriv $K $f $xs*)) none
| `(<∂ $f) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(revCDeriv $K $f)) none
Expand All @@ -22,12 +29,19 @@ elab_rules : term
elabTerm (← `(revCDerivEval $K (fun $x => $b) $val $codir)) none

macro_rules
| `(<∂ $f $xs*) => `((<∂ $f) $xs*)
| `(<∂ $x:ident, $b) => `(<∂ (fun $x => $b))
| `(<∂ $x:ident := $val:term, $b) => `(<∂ (fun $x => $b) $val)
| `(<∂ $x:ident : $type:term, $b) => `(<∂ fun $x : $type => $b)
| `(<∂ ($b:diffBinder), $f) => `(<∂ $b, $f)

macro_rules
| `(<∂! $f $xs*) => `((<∂ $f $xs*) rewrite_by ftrans; ftrans; ftrans)
| `(<∂! $f) => `((<∂ $f) rewrite_by ftrans; ftrans; ftrans)
| `(<∂! $x:ident, $b) => `(<∂! (fun $x => $b))
| `(<∂! $x:ident := $val:term, $b) => `(<∂! (fun $x => $b) $val)
| `(<∂! $x:ident : $type:term, $b) => `(<∂! fun $x : $type => $b)
| `(<∂! ($b:diffBinder), $f) => `(<∂! $b, $f)


@[app_unexpander revCDeriv] def unexpandRevCDeriv : Lean.PrettyPrinter.Unexpander

Expand Down
25 changes: 0 additions & 25 deletions test/cderiv_notation.lean

This file was deleted.

82 changes: 82 additions & 0 deletions test/deriv_notation.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import SciLean.Core.Notation.CDeriv
import SciLean.Core.Notation.Gradient
import SciLean.Core.Notation.FwdCDeriv
import SciLean.Core.Notation.RevCDeriv
import SciLean.Core.FloatAsReal


open SciLean

variable {K} [RealScalar K]

set_default_scalar K

#check ∂ (fun x : K => x*x)
#check ∂ (fun x => x*x) 1
#check ∂ (x:=(1:K)), x*x
#check ∂ (x:=1), x*x
#check ∂ (x:=0.1), x*x
#check ∂ (x:=((1:K),(2:K))), (x + x)

#check ∂! (fun x : K => x^2)
#check ∂! (fun x : K×K => x + x)
#check ∂! (fun x => x*x) 1
#check ∂! (x:=((1:K),(2:K))), (x + x)
#check ∂! (x:=1), x*x

variable {X} [Vec K X] (f : X → X)

#check ∂ (x:=0), f x

set_default_scalar Float

#eval ∂! (fun x => x^2) 1
#eval ∂! (fun x => x*x) 1
#eval ∂! (x:=1), x*x
#eval ∂! (fun x => x + x) (1.0,2.0) (1.0,0.0)
#eval ∂! (x:=(1.0,2.0);(1.0,0.0)), (x + x)



--------------------------------------------------------------------------------

set_default_scalar K

#check ∇ x : (K×K), x.1
#check ∇! x : (K×K), x.1
#check ∇! x : (K×K), x.2

variable (y : K × K)

#check ∇ (x:=y), (x + x)
#check ∇ (fun x => x + x) y
#check ∇ (fun x => x + x) ((1.0,2.0) : K×K)
#check (∇! x : (K×K), ⟪x,(1,0)⟫)


set_default_scalar Float

#eval ∇! (fun x => x^2) 1
#eval ∇! (fun x => x*x) 1
#eval ∇! (x:=1), x*x
#eval ∇! (fun x : Float×Float => (x + x).2) (1.0,2.0)
#eval ∇! (x:=((1.0 : Float),(2.0:Float))), (x + x).1



--------------------------------------------------------------------------------

set_default_scalar K

#check ∂>! x : K×K, (x.1 + x.2*x.1)
#check ∂>! x:=(1:K);2, (x + x*x)



--------------------------------------------------------------------------------

set_default_scalar K

#check <∂! x : K×K, (x.1 + x.2*x.1)
#check <∂! x:=(1:K), (x + x*x)

0 comments on commit 77ce441

Please sign in to comment.