Skip to content

Commit

Permalink
allow for patterns in derivative notation
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Mar 12, 2024
1 parent 7538559 commit ea8cedc
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 21 deletions.
16 changes: 8 additions & 8 deletions SciLean/Core/Notation/CDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace SciLean.Notation

syntax diffBinderType := " : " term
syntax diffBinderValue := ":=" term (";" term)?
syntax diffBinder := ident (diffBinderType <|> diffBinderValue)?
syntax diffBinder := term (diffBinderType <|> diffBinderValue)?

syntax "∂ " term:66 : term
syntax "∂ " diffBinder ", " term:66 : term
Expand Down Expand Up @@ -54,19 +54,19 @@ elab_rules : term


macro_rules
| `(∂ $x:ident, $b) => `(∂ (fun $x => $b))
| `(∂ $x:ident := $val:term, $b) => `(∂ (fun $x => $b) $val)
| `(∂ $x:ident : $type:term, $b) => `(∂ fun $x : $type => $b)
| `(∂ $x:term, $b) => `(∂ (fun $x => $b))
| `(∂ $x:term := $val:term, $b) => `(∂ (fun $x => $b) $val)
| `(∂ $x:term : $type:term, $b) => `(∂ fun $x : $type => $b)
| `(∂ ($b:diffBinder), $f) => `(∂ $b, $f)

macro_rules
-- in some cases it is still necessary to call fun_trans multiple times
-- | `(∂! $f $xs*) => `((∂ $f $xs*) rewrite_by fun_trans; fun_trans; fun_trans)
| `(∂! $f) => `((∂ $f) rewrite_by (try unfold scalarCDeriv); autodiff; autodiff)
| `(∂! $x:ident, $b) => `(∂! (fun $x => $b))
| `(∂! $x:ident := $val:term, $b) => `(∂! (fun $x => $b) $val)
| `(∂! $x:ident := $val:term;$dir:term, $b) => `(((∂ $x:ident:=$val;$dir, $b) rewrite_by (try unfold scalarCDeriv);fun_trans))
| `(∂! $x:ident : $type:term, $b) => `(∂! fun $x : $type => $b)
| `(∂! $x:term, $b) => `(∂! (fun $x => $b))
| `(∂! $x:term := $val:term, $b) => `(∂! (fun $x => $b) $val)
| `(∂! $x:term := $val:term;$dir:term, $b) => `(((∂ $x:term:=$val;$dir, $b) rewrite_by (try unfold scalarCDeriv);fun_trans))
| `(∂! $x:term : $type:term, $b) => `(∂! fun $x : $type => $b)
| `(∂! ($b:diffBinder), $f) => `(∂! $b, $f)


Expand Down
16 changes: 8 additions & 8 deletions SciLean/Core/Notation/FwdDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,20 @@ elab_rules : term


macro_rules
| `(∂> $x:ident, $b) => `(∂> (fun $x => $b))
| `(∂> $x:ident := $val:term, $b) => `(∂> (fun $x => $b) $val)
| `(∂> $x:ident : $type:term, $b) => `(∂> fun $x : $type => $b)
| `(∂> $x:ident := $val:term ; $dir:term, $b) => `(∂> (fun $x => $b) $val $dir)
| `(∂> $x:term, $b) => `(∂> (fun $x => $b))
| `(∂> $x:term := $val:term, $b) => `(∂> (fun $x => $b) $val)
| `(∂> $x:term : $type:term, $b) => `(∂> fun $x : $type => $b)
| `(∂> $x:term := $val:term ; $dir:term, $b) => `(∂> (fun $x => $b) $val $dir)
| `(∂> ($b:diffBinder), $f) => `(∂> $b, $f)


macro_rules
| `(∂>! $f $xs*) => `((∂> $f $xs*) rewrite_by autodiff; autodiff; autodiff)
| `(∂>! $f) => `((∂> $f) rewrite_by autodiff; autodiff; autodiff)
| `(∂>! $x:ident, $b) => `(∂>! (fun $x => $b))
| `(∂>! $x:ident := $val:term, $b) => `(∂>! (fun $x => $b) $val)
| `(∂>! $x:ident : $type:term, $b) => `(∂>! fun $x : $type => $b)
| `(∂>! $x:ident := $val:term ; $dir:term, $b) => `(∂>! (fun $x => $b) $val $dir)
| `(∂>! $x:term, $b) => `(∂>! (fun $x => $b))
| `(∂>! $x:term := $val:term, $b) => `(∂>! (fun $x => $b) $val)
| `(∂>! $x:term : $type:term, $b) => `(∂>! fun $x : $type => $b)
| `(∂>! $x:term := $val:term ; $dir:term, $b) => `(∂>! (fun $x => $b) $val $dir)
| `(∂>! ($b:diffBinder), $f) => `(∂>! $b, $f)


Expand Down
11 changes: 6 additions & 5 deletions SciLean/Core/Rand/Distributions/WalkOnSpheres.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def harmonicRec_fwdDeriv (n : ℕ)
(g : Vec3 → Y) (g' : Vec3 → Vec3 → Y×Y) : Vec3 → Vec3 → Y×Y :=
(∂> x, harmonicRec n φ g x)
rewrite_by
assuming (hφ' : ∂> φ = φ') (hφ : CDifferentiable Float φ)
(hg' : ∂> g = g') (hg : CDifferentiable Float g)
assuming (hφ' : (∂> φ) = φ') (hφ : CDifferentiable Float φ)
(hg' : (∂> g) = g') (hg : CDifferentiable Float g)
induction n n' du h
. simp[harmonicRec]; autodiff
. simp[harmonicRec];
Expand Down Expand Up @@ -111,15 +111,16 @@ theorem harmonicRec'_CDifferentiable (n : ℕ) :
CDifferentiable Float (fun (w : (Vec3 ⟿FD Float)×(Vec3 ⟿FD Y)×Vec3) => harmonicRec' n w.1 w.2.1 w.2.2) := by
induction n <;> (simp[harmonicRec']; fun_prop)

variable (n : Nat)


noncomputable
def harmonicRec'_fwdDeriv (n : ℕ) :=
(∂> (w : (Vec3 ⟿FD Float)×(Vec3 ⟿FD Y)×Vec3), harmonicRec' n w.1 w.2.1 w.2.2)
(∂> (φ,g,x), harmonicRec' (Y:=Y) n φ g x)
rewrite_by
induction n n' du h
. simp only [harmonicRec']; autodiff
. simp only [harmonicRec',smul_push]
autodiff
. simp only [harmonicRec',smul_push]; autodiff


def harmonicRec'_fwdDeriv_rand (n : ℕ)
Expand Down

0 comments on commit ea8cedc

Please sign in to comment.