Skip to content

Commit

Permalink
some refactoring of distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Mar 30, 2024
1 parent ea9d3b0 commit 4af1c95
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 55 deletions.
34 changes: 31 additions & 3 deletions SciLean/Core/Distribution/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ simproc_decl Distribution.mk_extAction_simproc (Distribution.extAction (Distribu
-- seqRight_eq := by intros; rfl
-- pure_seq := by intros; rfl

def vecDirac (x : X) (y : Y) : 𝒟'(X,Y) := ⟨fun φ ⊸ φ x • y⟩
abbrev dirac (x : X) : 𝒟' X := vecDirac x 1
def dirac (x : X) : 𝒟' X := ⟨fun φ ⊸ φ x⟩

open Notation
noncomputable
Expand All @@ -170,7 +169,7 @@ def Distribution.bind' (x' : 𝒟'(X,U)) (f : X → 𝒟'(Y,V)) (L : U → V →
----------------------------------------------------------------------------------------------------

@[simp, ftrans_simp]
theorem action_vecDirac (x : X) (y : Y) (φ : 𝒟 X) : ⟪(vecDirac x y), φ⟫ = φ x • y := by simp[dirac,vecDirac]
theorem action_dirac (x : X) (φ : 𝒟 X) : ⟪dirac x, φ⟫ = φ x := by simp[dirac]

@[simp, ftrans_simp]
theorem action_bind (x : 𝒟'(X,Z)) (f : X → 𝒟' Y) (φ : 𝒟 Y) :
Expand Down Expand Up @@ -411,6 +410,35 @@ abbrev Distribution.postRestrict (T : 𝒟'(X,𝒟'(Y,Z))) (A : X → Set Y) :
sorry_proof⟩⟩


@[simp, ftrans_simp]
theorem postComp_id (u : 𝒟'(X,Y)) :
(u.postComp (fun y => y)) = u := sorry_proof

@[simp, ftrans_simp]
theorem postComp_comp (x : 𝒟'(X,U)) (g : U → V) (f : V → W) :
(x.postComp g).postComp f
=
x.postComp (fun u => f (g u)) := sorry_proof

@[simp, ftrans_simp]
theorem postComp_assoc (x : 𝒟'(X,U)) (y : U → 𝒟'(Y,V)) (f : V → W) (φ : Y → R) :
(x.postComp y).postComp (fun T => T.postComp f)
=
(x.postComp (fun u => (y u).postComp f)) := sorry_proof

@[action_push]
theorem postComp_extAction (x : 𝒟'(X,U)) (y : U → V) (φ : X → R) :
(x.postComp y).extAction φ
=
y (x.extAction φ) := sorry_proof

@[action_push]
theorem postComp_restrict_extAction (x : 𝒟'(X,U)) (y : U → V) (A : Set X) (φ : X → R) :
((x.postComp y).restrict A).extAction φ
=
y ((x.restrict A).extAction φ) := sorry_proof


@[simp, ftrans_simp, action_push]
theorem Distribution.zero_postExtAction (φ : Y → R) : (0 : 𝒟'(X,𝒟'(Y,Z))).postExtAction φ = 0 := by sorry_proof

Expand Down
23 changes: 14 additions & 9 deletions SciLean/Core/Distribution/Eval.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ variable
{X} [TopologicalSpace X] [space : TCOr (Vec R X) (DiscreteTopology X)]
{Y} [Vec R Y]
{Z} [Vec R Z]
{U} [Vec R U]
{V} [Vec R V]
{W} [Vec R W]

set_default_scalar R

Expand All @@ -21,22 +24,24 @@ theorem action_extAction (T : 𝒟' X) (φ : 𝒟 X) :
T.action φ = T.extAction φ := sorry_proof

@[action_push]
theorem extAction_vecDirac (x : X) (y : Y) (φ : X → R) :
(vecDirac x y).extAction φ
theorem extAction_vecDirac (x : X) (φ : X → R) :
(dirac x).extAction φ
=
φ x • y := sorry_proof
φ x := sorry_proof

@[action_push]
theorem extAction_restrict_vecDirac (x : X) (y : Y) (A : Set X) (φ : X → R) :
((vecDirac x y).restrict A).extAction φ
theorem extAction_restrict_vecDirac (x : X) (A : Set X) (φ : X → R) :
((dirac x).restrict A).extAction φ
=
if x ∈ A then φ x • y else 0 := sorry_proof
if x ∈ A then φ x else 0 := sorry_proof

-- x.postComp (fun u => (y u).extAction φ) := by sorry_proof

@[action_push]
theorem postExtAction_vecDirac (x : X) (y : 𝒟'(Y,Z)) (φ : Y → R) :
(vecDirac x y).postExtAction φ
theorem postExtAction_postComp (x : 𝒟'(X,U)) (y : U → 𝒟'(Y,Z)) (φ : Y → R) :
(x.postComp y).postExtAction φ
=
vecDirac x (y.extAction φ) := sorry_proof
x.postComp (fun u => (y u).extAction φ) := by sorry_proof

variable [MeasureSpace X]

Expand Down
92 changes: 72 additions & 20 deletions SciLean/Core/Distribution/ParametricDistribDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ open Distribution
variable
{R} [RealScalar R]
{W} [Vec R W]
{X} [Vec R X]
{X} [Vec R X] [MeasureSpace X]
{Y} [Vec R Y] [Module ℝ Y]
{Z} [Vec R Z] [Module ℝ Z]
{U} [Vec R U] -- [Module ℝ U]


set_default_scalar R


noncomputable
def vecDiracDeriv (x dx : X) (y dy : Y) : 𝒟'(X,Y) := ⟨fun φ ⊸ φ x • dy + cderiv R φ x dx • y
def diracDeriv (x dx : X) : 𝒟' X := ⟨fun φ ⊸ cderiv R φ x dx⟩

@[fun_prop]
def DistribDifferentiableAt (f : X → 𝒟'(Y,Z)) (x : X) :=
Expand All @@ -43,15 +44,25 @@ def DistribDifferentiable (f : X → 𝒟'(Y,Z)) :=
∀ x, DistribDifferentiableAt f x


-- TODO:
-- probably change the definition of `parDistribDeriv` to:
-- ⟨⟨fun φ =>
-- if h : DistribDifferentiableAt f x then
-- ∂ (x':=x;dx), ⟪f x', φ⟫
-- else
-- 0 , sorry_proof⟩⟩
-- I believe in that case the function is indeed linear in φ

open Classical in
@[fun_trans]
noncomputable
def parDistribDeriv (f : X → 𝒟'(Y,Z)) (x dx : X) : 𝒟'(Y,Z) :=
⟨⟨fun φ =>
if _ : DistribDifferentiableAt f x then
∂ (x':=x;dx), ⟪f x', φ⟫
else
0, sorry_proof⟩⟩
⟨⟨fun φ => ∂ (x':=x;dx), ⟪f x', φ⟫, sorry_proof⟩⟩


@[simp, ftrans_simp]
theorem action_parDistribDeriv (f : X → 𝒟'(Y,Z)) (x dx : X) (φ : 𝒟 Y) :
⟪parDistribDeriv f x dx, φ⟫ = ∂ (x':=x;dx), ⟪f x', φ⟫ := rfl


----------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -79,32 +90,28 @@ theorem parDistribDeriv.const_rule (T : 𝒟'(X,Y)) :
----------------------------------------------------------------------------------------------------

@[fun_prop]
theorem vecDirac.arg_xy.DistribDiffrentiable_rule
(x : W → X) (y : W → Y) (hx : CDifferentiable R x) (hy : CDifferentiable R y) :
DistribDifferentiable (R:=R) (fun w => vecDirac (x w) (y w)) := by
theorem dirac.arg_xy.DistribDiffrentiable_rule
(x : W → X) (hx : CDifferentiable R x) :
DistribDifferentiable (R:=R) (fun w => dirac (x w)) := by
intro x
unfold DistribDifferentiableAt
intro φ hφ
simp [action_vecDirac, dirac]
simp [action_dirac, dirac]
fun_prop


@[fun_trans]
theorem vecDirac.arg_x.parDistribDeriv_rule
(x : W → X) (y : W → Y) (hx : CDifferentiable R x) (hy : CDifferentiable R y) :
parDistribDeriv (R:=R) (fun w => vecDirac (x w) (y w))
theorem dirac.arg_x.parDistribDeriv_rule
(x : W → X) (hx : CDifferentiable R x) :
parDistribDeriv (R:=R) (fun w => dirac (x w))
=
fun w dw =>
let xdx := fwdDeriv R x w dw
let ydy := fwdDeriv R y w dw
vecDiracDeriv xdx.1 xdx.2 ydy.1 ydy.2 := by --= (dpure (R:=R) ydy.1 ydy.2) := by
diracDeriv xdx.1 xdx.2 := by --= (dpure (R:=R) ydy.1 ydy.2) := by
funext w dw; ext φ
unfold parDistribDeriv vecDirac vecDiracDeriv
unfold parDistribDeriv dirac diracDeriv
simp [pure, fwdDeriv, DistribDifferentiableAt]
fun_trans
. intro φ' hφ' h
have : CDifferentiableAt R (fun w : W => (φ' w) (x w) • (y w)) w := by fun_prop
contradiction


----------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -176,6 +183,49 @@ theorem Bind.bind.arg_fx.parDistribDiff_rule



----------------------------------------------------------------------------------------------------
-- Move these around -------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

@[fun_prop]
theorem Distribution.restrict.arg_T.IsSmoothLinearMap_rule (T : W → 𝒟'(X,Y)) (A : Set X)
(hT : IsSmoothLinearMap R T) :
IsSmoothLinearMap R (fun w => (T w).restrict A) := sorry_proof

@[fun_prop]
theorem Distribution.restrict.arg_T.IsSmoothLinearMap_rule_simple (A : Set X) :
IsSmoothLinearMap R (fun (T : 𝒟'(X,Y)) => T.restrict A) := sorry_proof

@[fun_prop]
theorem Function.toDistribution.arg_f.CDifferentiable_rule (f : W → X → Y)
(hf : ∀ x, CDifferentiable R (f · x)) :
CDifferentiable R (fun w => (fun x => f w x).toDistribution (R:=R)) := sorry_proof

@[fun_trans]
theorem Function.toDistribution.arg_f.cderiv_rule (f : W → X → Y)
(hf : ∀ x, CDifferentiable R (f · x)) :
cderiv R (fun w => (fun x => f w x).toDistribution (R:=R))
=
fun w dw =>
(fun x =>
let dy := cderiv R (f · x) w dw
dy).toDistribution := sorry_proof

@[fun_trans]
theorem toDistribution.linear_parDistribDeriv_rule (f : W → X → Y) (L : Y → Z)
(hL : IsSmoothLinearMap R L) :
parDistribDeriv (fun w => (fun x => L (f w x)).toDistribution)
=
fun w dw =>
parDistribDeriv Tf w dw |>.postComp L := by
funext w dw
unfold parDistribDeriv Distribution.postComp Function.toDistribution
ext φ
simp [ftrans_simp, Distribution.mk_extAction_simproc]
sorry_proof



----------------------------------------------------------------------------------------------------
-- Integral ----------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
Expand All @@ -201,6 +251,8 @@ theorem cintegral.arg_f.cderiv_distrib_rule' (f : W → X → R) (A : Set X):

-- (parDistribDeriv (fun w => (f w ·).toDistribution) w dw).extAction (fun x => if x ∈ A then 1 else 0) := sorry_proof



@[fun_trans]
theorem cintegral.arg_f.parDistribDeriv_rule (f : W → X → Y → R) :
parDistribDeriv (fun w => (fun x => ∫' y, f w x y).toDistribution)
Expand Down
67 changes: 63 additions & 4 deletions SciLean/Core/Distribution/ParametricDistribRevDeriv.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import SciLean.Core.Distribution.ParametricDistribDeriv
import SciLean.Core.Distribution.ParametricDistribFwdDeriv
import SciLean.Core.Distribution.Eval

namespace SciLean

Expand All @@ -20,28 +22,37 @@ variable

set_default_scalar R


@[fun_trans]
noncomputable
def parDistribRevDeriv (f : X → 𝒟'(Y,Z)) (x : X) : 𝒟'(Y,Z×(Z→X)) :=
⟨⟨fun φ =>
let dz := semiAdjoint R (fun dx => ⟪parDistribDeriv f x dx,φ⟫)
let z := ⟪f x, φ⟫
(z, sorry), sorry_proof⟩⟩

(z, dz), sorry_proof⟩⟩


namespace parDistribRevDeriv


theorem comp_rule
(f : Y → 𝒟'(Z,U)) (g : X → Y)
(hf : DistribDifferentiable f) (hg : CDifferentiable R g) :
(hf : DistribDifferentiable f) (hg : HasAdjDiff R g) :
parDistribRevDeriv (fun x => f (g x))
=
fun x =>
let ydg := revDeriv R g x
let udf := parDistribRevDeriv f ydg.1
udf.postComp (fun (u,df') => (u, fun du => ydg.2 (df' du))) := by sorry_proof
udf.postComp (fun (u,df') => (u, fun du => ydg.2 (df' du))) := by

unfold parDistribRevDeriv
funext x; ext φ
simp
fun_trans
simp [action_push,revDeriv,fwdDeriv]
have : ∀ x, HasSemiAdjoint R (∂ x':=x, ⟪f x', φ⟫) := sorry_proof -- todo add: `DistribHasAdjDiff`
fun_trans



theorem bind_rule
Expand All @@ -52,3 +63,51 @@ theorem bind_rule
let ydg := parDistribRevDeriv g x
let zdf := fun y => parDistribRevDeriv (f · y) x
ydg.bind' zdf (fun (_,dg) (z,df) => (z, fun dr => dg dr + df dr)) := sorry_proof


theorem bind_rule'
(f : X → Y → 𝒟'(Z,V)) (g : X → 𝒟'(Y,U)) (L : U → V → W) :
parDistribRevDeriv (fun x => (g x).bind' (f x) L)
=
fun x =>
let ydg := parDistribRevDeriv g x
let zdf := fun y => parDistribRevDeriv (f · y) x
ydg.bind' zdf (fun (u,dg) (v,df) =>
(L u v, fun dw =>
df (semiAdjoint R (L u ·) dw) +
dg (semiAdjoint R (L · v) dw))) := by

unfold parDistribRevDeriv bind'
funext x; ext φ
simp
sorry_proof
sorry_proof



----------------------------------------------------------------------------------------------------
-- Dirac -------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

noncomputable
def diracRevDeriv (x : X) : 𝒟'(X,R×(R→X)) :=
⟨⟨fun φ => revDeriv R φ x, sorry_proof⟩⟩


@[fun_trans]
theorem dirac.arg_xy.parDistribRevDeriv_rule
(x : W → X) (hx : HasAdjDiff R x) :
parDistribRevDeriv (fun w => dirac (x w))
=
fun w =>
let xdx := revDeriv R x w
diracRevDeriv xdx.1 |>.postComp (fun (r,dφ) => (r, fun dr => xdx.2 (dφ dr))) := by

funext w; apply Distribution.ext _ _; intro φ
have : HasAdjDiff R φ := sorry_proof -- this should be consequence of that `R` has dimension one
simp [diracRevDeriv,revDeriv, parDistribRevDeriv]
fun_trans



#check Distribution.postComp
6 changes: 0 additions & 6 deletions SciLean/Core/Distribution/SimpleExamples.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ theorem _root_.FiniteDimensional.finrank_unit : finrank R Unit = 0 := by sorry_p

variable [MeasureSpace R] -- [Module ℝ R]



def foo1 (t' : R) := (∂ (t:=t'), ∫' (x:R) in Ioo 0 1, if x ≤ t then (1:R) else 0)
rewrite_by
fun_trans only [scalarGradient, scalarCDeriv]
Expand All @@ -49,12 +47,8 @@ theorem foo1_spec (t : R) :
#eval foo1 (-1.0) -- 0.0
#eval foo1 2.0 -- 0.0

#check Set.add_empty

open Classical in

set_option pp.funBinderTypes true in

def foo2 (t' : R) := (∂ (t:=t'), ∫' (x:R) in Ioo 0 1, if x - t ≤ 0 then (1:R) else 0)
rewrite_by
fun_trans only [scalarGradient, scalarCDeriv]
Expand Down
Loading

0 comments on commit 4af1c95

Please sign in to comment.