From ea8cedc27ae7baf4d567e8ce0f62a5e3a1302a2b Mon Sep 17 00:00:00 2001 From: lecopivo Date: Tue, 12 Mar 2024 17:24:42 -0400 Subject: [PATCH] allow for patterns in derivative notation --- SciLean/Core/Notation/CDeriv.lean | 16 ++++++++-------- SciLean/Core/Notation/FwdDeriv.lean | 16 ++++++++-------- .../Core/Rand/Distributions/WalkOnSpheres.lean | 11 ++++++----- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/SciLean/Core/Notation/CDeriv.lean b/SciLean/Core/Notation/CDeriv.lean index f17add3d..6b479069 100644 --- a/SciLean/Core/Notation/CDeriv.lean +++ b/SciLean/Core/Notation/CDeriv.lean @@ -9,7 +9,7 @@ namespace SciLean.Notation syntax diffBinderType := " : " term syntax diffBinderValue := ":=" term (";" term)? -syntax diffBinder := ident (diffBinderType <|> diffBinderValue)? +syntax diffBinder := term (diffBinderType <|> diffBinderValue)? syntax "∂ " term:66 : term syntax "∂ " diffBinder ", " term:66 : term @@ -54,19 +54,19 @@ elab_rules : term macro_rules -| `(∂ $x:ident, $b) => `(∂ (fun $x => $b)) -| `(∂ $x:ident := $val:term, $b) => `(∂ (fun $x => $b) $val) -| `(∂ $x:ident : $type:term, $b) => `(∂ fun $x : $type => $b) +| `(∂ $x:term, $b) => `(∂ (fun $x => $b)) +| `(∂ $x:term := $val:term, $b) => `(∂ (fun $x => $b) $val) +| `(∂ $x:term : $type:term, $b) => `(∂ fun $x : $type => $b) | `(∂ ($b:diffBinder), $f) => `(∂ $b, $f) macro_rules -- in some cases it is still necessary to call fun_trans multiple times -- | `(∂! $f $xs*) => `((∂ $f $xs*) rewrite_by fun_trans; fun_trans; fun_trans) | `(∂! $f) => `((∂ $f) rewrite_by (try unfold scalarCDeriv); autodiff; autodiff) -| `(∂! $x:ident, $b) => `(∂! (fun $x => $b)) -| `(∂! $x:ident := $val:term, $b) => `(∂! (fun $x => $b) $val) -| `(∂! $x:ident := $val:term;$dir:term, $b) => `(((∂ $x:ident:=$val;$dir, $b) rewrite_by (try unfold scalarCDeriv);fun_trans)) -| `(∂! $x:ident : $type:term, $b) => `(∂! fun $x : $type => $b) +| `(∂! $x:term, $b) => `(∂! (fun $x => $b)) +| `(∂! $x:term := $val:term, $b) => `(∂! (fun $x => $b) $val) +| `(∂! $x:term := $val:term;$dir:term, $b) => `(((∂ $x:term:=$val;$dir, $b) rewrite_by (try unfold scalarCDeriv);fun_trans)) +| `(∂! $x:term : $type:term, $b) => `(∂! fun $x : $type => $b) | `(∂! ($b:diffBinder), $f) => `(∂! $b, $f) diff --git a/SciLean/Core/Notation/FwdDeriv.lean b/SciLean/Core/Notation/FwdDeriv.lean index 0625b357..8cf56d6d 100644 --- a/SciLean/Core/Notation/FwdDeriv.lean +++ b/SciLean/Core/Notation/FwdDeriv.lean @@ -27,20 +27,20 @@ elab_rules : term macro_rules -| `(∂> $x:ident, $b) => `(∂> (fun $x => $b)) -| `(∂> $x:ident := $val:term, $b) => `(∂> (fun $x => $b) $val) -| `(∂> $x:ident : $type:term, $b) => `(∂> fun $x : $type => $b) -| `(∂> $x:ident := $val:term ; $dir:term, $b) => `(∂> (fun $x => $b) $val $dir) +| `(∂> $x:term, $b) => `(∂> (fun $x => $b)) +| `(∂> $x:term := $val:term, $b) => `(∂> (fun $x => $b) $val) +| `(∂> $x:term : $type:term, $b) => `(∂> fun $x : $type => $b) +| `(∂> $x:term := $val:term ; $dir:term, $b) => `(∂> (fun $x => $b) $val $dir) | `(∂> ($b:diffBinder), $f) => `(∂> $b, $f) macro_rules | `(∂>! $f $xs*) => `((∂> $f $xs*) rewrite_by autodiff; autodiff; autodiff) | `(∂>! $f) => `((∂> $f) rewrite_by autodiff; autodiff; autodiff) -| `(∂>! $x:ident, $b) => `(∂>! (fun $x => $b)) -| `(∂>! $x:ident := $val:term, $b) => `(∂>! (fun $x => $b) $val) -| `(∂>! $x:ident : $type:term, $b) => `(∂>! fun $x : $type => $b) -| `(∂>! $x:ident := $val:term ; $dir:term, $b) => `(∂>! (fun $x => $b) $val $dir) +| `(∂>! $x:term, $b) => `(∂>! (fun $x => $b)) +| `(∂>! $x:term := $val:term, $b) => `(∂>! (fun $x => $b) $val) +| `(∂>! $x:term : $type:term, $b) => `(∂>! fun $x : $type => $b) +| `(∂>! $x:term := $val:term ; $dir:term, $b) => `(∂>! (fun $x => $b) $val $dir) | `(∂>! ($b:diffBinder), $f) => `(∂>! $b, $f) diff --git a/SciLean/Core/Rand/Distributions/WalkOnSpheres.lean b/SciLean/Core/Rand/Distributions/WalkOnSpheres.lean index 72c3d649..c5545f4d 100644 --- a/SciLean/Core/Rand/Distributions/WalkOnSpheres.lean +++ b/SciLean/Core/Rand/Distributions/WalkOnSpheres.lean @@ -69,8 +69,8 @@ def harmonicRec_fwdDeriv (n : ℕ) (g : Vec3 → Y) (g' : Vec3 → Vec3 → Y×Y) : Vec3 → Vec3 → Y×Y := (∂> x, harmonicRec n φ g x) rewrite_by - assuming (hφ' : ∂> φ = φ') (hφ : CDifferentiable Float φ) - (hg' : ∂> g = g') (hg : CDifferentiable Float g) + assuming (hφ' : (∂> φ) = φ') (hφ : CDifferentiable Float φ) + (hg' : (∂> g) = g') (hg : CDifferentiable Float g) induction n n' du h . simp[harmonicRec]; autodiff . simp[harmonicRec]; @@ -111,15 +111,16 @@ theorem harmonicRec'_CDifferentiable (n : ℕ) : CDifferentiable Float (fun (w : (Vec3 ⟿FD Float)×(Vec3 ⟿FD Y)×Vec3) => harmonicRec' n w.1 w.2.1 w.2.2) := by induction n <;> (simp[harmonicRec']; fun_prop) +variable (n : Nat) + noncomputable def harmonicRec'_fwdDeriv (n : ℕ) := - (∂> (w : (Vec3 ⟿FD Float)×(Vec3 ⟿FD Y)×Vec3), harmonicRec' n w.1 w.2.1 w.2.2) + (∂> (φ,g,x), harmonicRec' (Y:=Y) n φ g x) rewrite_by induction n n' du h . simp only [harmonicRec']; autodiff - . simp only [harmonicRec',smul_push] - autodiff + . simp only [harmonicRec',smul_push]; autodiff def harmonicRec'_fwdDeriv_rand (n : ℕ)