Skip to content

Commit

Permalink
fix regarding the index type I in revDerivProj(Update)
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 28, 2023
1 parent 8870eaa commit 5d2b838
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 37 deletions.
72 changes: 44 additions & 28 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ set_option linter.unusedVariables false
namespace SciLean

variable
(K : Type _) [IsROrC K]
(K I : Type _) [IsROrC K]
{X : Type _} [SemiInnerProductSpace K X]
{Y : Type _} [SemiInnerProductSpace K Y]
{Z : Type _} [SemiInnerProductSpace K Z]
{W : Type _} [SemiInnerProductSpace K W]
{ι : Type _} [EnumType ι]
{κ : Type _} [EnumType κ]
{E I : Type _} {EI : I → Type _}
{E : Type _} {EI : I → Type _}
[StructType E I EI] [EnumType I]
[SemiInnerProductSpace K E] [∀ i, SemiInnerProductSpace K (EI i)]
[SemiInnerProductSpaceStruct K E I EI]
Expand All @@ -41,7 +41,6 @@ def revDerivUpdate
let ydf := revDeriv K f x
(ydf.1, fun dy dx => dx + ydf.2 dy)

variable (I)
noncomputable
def revDerivProj [DecidableEq I]
(f : X → E) (x : X) : E×((i : I)→EI i→X) :=
Expand All @@ -54,7 +53,6 @@ def revDerivProjUpdate [DecidableEq I]
(f : X → E) (x : X) : E×((i : I)→EI i→X→X) :=
let ydf' := revDerivProj K I f x
(ydf'.1, fun i de dx => dx + ydf'.2 i de)
variable {I}

--------------------------------------------------------------------------------
-- simplification rules for individual components ------------------------------
Expand Down Expand Up @@ -143,6 +141,7 @@ by
variable{X}

variable(EI)
variable {I}
theorem proj_rule (i : I)
: revDeriv K (fun (x : (i:I) → EI i) => x i)
=
Expand All @@ -151,6 +150,7 @@ theorem proj_rule (i : I)
by
unfold revDeriv
funext _; ftrans; ftrans
variable (I)
variable {EI}

theorem comp_rule
Expand Down Expand Up @@ -191,24 +191,26 @@ by
unfold revDeriv
funext _; ftrans; ftrans; rfl

variable {I}
theorem pi_rule
(f : X → (i : I) → EI i) (hf : ∀ i, HasAdjDiff K (f · i))
: (revDeriv K fun (x : X) (i : I) => f x i)
=
fun x =>
let xdf := revDerivProjUpdate K ((i:I)×Unit) f x
let xdf := revDerivProjUpdate K I f x
(fun i => xdf.1 i,
fun dy => Id.run do
let mut dx : X := 0
for i in fullRange I do
dx := xdf.2 ⟨i,()⟩ (dy i) dx
dx := xdf.2 i (dy i) dx
dx) :=
by
have _ := fun i => (hf i).1
have _ := fun i => (hf i).2
unfold revDeriv
funext _; ftrans; ftrans
sorry_proof
variable (I)

end revDeriv

Expand Down Expand Up @@ -236,6 +238,7 @@ by
variable {X}

variable (EI)
variable {I}
theorem proj_rule (i : I)
: revDerivUpdate K (fun (x : (i:I) → EI i) => x i)
=
Expand All @@ -246,6 +249,7 @@ by
simp [revDeriv.proj_rule]
funext _; ftrans; ftrans;
simp; funext dxi dx j; simp; sorry_proof
variable (I)
variable {EI}

theorem comp_rule
Expand Down Expand Up @@ -281,22 +285,24 @@ by
unfold revDerivUpdate
simp [revDeriv.let_rule _ _ _ hf hg, revDerivUpdate,add_assoc]

variable {I}
theorem pi_rule
(f : X → (i : I) → EI i) (hf : ∀ i, HasAdjDiff K (f · i))
: (revDerivUpdate K fun (x : X) (i : I) => f x i)
=
fun x =>
let xdf := revDerivProjUpdate K ((i:I)×Unit) f x
let xdf := revDerivProjUpdate K I f x
(fun i => xdf.1 i,
fun dy dx => Id.run do
let mut dx := dx
for i in fullRange I do
dx := xdf.2 ⟨i,()⟩ (dy i) dx
dx := xdf.2 i (dy i) dx
dx) :=
by
unfold revDerivUpdate
simp [revDeriv.pi_rule _ _ hf, revDerivUpdate]
sorry_proof
variable (I)

end revDerivUpdate

Expand Down Expand Up @@ -378,12 +384,12 @@ theorem pi_rule
: (revDerivProj K Unit fun (x : X) (i : ι) => f x i)
=
fun x =>
let ydf := revDerivProjUpdate K (ι×Unit) f x
let ydf := revDerivProjUpdate K ι f x
(fun i => ydf.1 i,
fun _ df => Id.run do
let mut dx : X := 0
for i in fullRange ι do
dx := ydf.2 (i,()) (df i) dx
dx := ydf.2 i (df i) dx
dx) :=
by
sorry_proof
Expand Down Expand Up @@ -458,7 +464,7 @@ theorem comp_rule
ydg'.2 (zdf'.2 i de) dx) :=
by
funext x
simp[revDerivProjUpdate,revDerivProj.comp_rule _ _ _ hf hg]
simp[revDerivProjUpdate,revDerivProj.comp_rule _ _ _ _ hf hg]
rfl


Expand All @@ -476,20 +482,20 @@ theorem let_rule
ydg'.2 dxy.2 dxy.1) :=
by
unfold revDerivProjUpdate
simp [revDerivProj.let_rule _ _ _ hf hg,add_assoc,add_comm,revDerivUpdate]
simp [revDerivProj.let_rule _ _ _ _ hf hg,add_assoc,add_comm,revDerivUpdate]


theorem pi_rule
(f : X → ι → Y) (hf : ∀ i, HasAdjDiff K (f · i))
: (revDerivProjUpdate K Unit fun (x : X) (i : ι) => f x i)
=
fun x =>
let ydf := revDerivProjUpdate K (ι×Unit) f x
let ydf := revDerivProjUpdate K ι f x
(fun i => ydf.1 i,
fun _ df dx => Id.run do
let mut dx : X := dx
for i in fullRange ι do
dx := ydf.2 (i,()) (df i) dx
dx := ydf.2 i (df i) dx
dx) :=
by
conv => lhs; unfold revDerivProjUpdate
Expand Down Expand Up @@ -727,33 +733,38 @@ def ftransExt : FTransExt where

idRule e X := do
let .some K := e.getArg? 0 | return none
let proof ← mkAppOptM ``id_rule #[K,none, X,none,none,none,none,none,none]
let .some I := e.getArg? 1 | return none
let proof ← mkAppOptM ``id_rule #[K,I,none, X,none,none,none,none,none]
tryTheorems
#[ { proof := proof, origin := .decl ``id_rule, rfl := false} ]
discharger e

constRule e X y := do
let .some K := e.getArg? 0 | return none
let .some I := e.getArg? 1 | return none
tryTheorems
#[ { proof := ← mkAppM ``const_rule #[K, X, y], origin := .decl ``const_rule, rfl := false} ]
#[ { proof := ← mkAppM ``const_rule #[K, I, X, y], origin := .decl ``const_rule, rfl := false} ]
discharger e

projRule e X i := do
let .some K := e.getArg? 0 | return none
let .some I := e.getArg? 1 | return none
tryTheorems
#[ { proof := ← mkAppM ``proj_rule #[K, X, i], origin := .decl ``proj_rule, rfl := false} ]
#[ { proof := ← mkAppM ``proj_rule #[K, I, X, i], origin := .decl ``proj_rule, rfl := false} ]
discharger e

compRule e f g := do
let .some K := e.getArg? 0 | return none
let .some I := e.getArg? 1 | return none
tryTheorems
#[ { proof := ← mkAppM ``comp_rule #[K, f, g], origin := .decl ``comp_rule, rfl := false} ]
#[ { proof := ← mkAppM ``comp_rule #[K, I, f, g], origin := .decl ``comp_rule, rfl := false} ]
discharger e

letRule e f g := do
let .some K := e.getArg? 0 | return none
let .some I := e.getArg? 1 | return none
tryTheorems
#[ { proof := ← mkAppM ``let_rule #[K, f, g], origin := .decl ``let_rule, rfl := false} ]
#[ { proof := ← mkAppM ``let_rule #[K, I, f, g], origin := .decl ``let_rule, rfl := false} ]
discharger e

piRule e f := do
Expand Down Expand Up @@ -818,33 +829,38 @@ def ftransExt : FTransExt where

idRule e X := do
let .some K := e.getArg? 0 | return none
let proof ← mkAppOptM ``id_rule #[K,none, X,none,none,none,none,none,none]
let .some I := e.getArg? 1 | return none
let proof ← mkAppOptM ``id_rule #[K,I,none, X,none,none,none,none,none]
tryTheorems
#[ { proof := proof, origin := .decl ``id_rule, rfl := false} ]
discharger e

constRule e X y := do
let .some K := e.getArg? 0 | return none
let .some I := e.getArg? 1 | return none
tryTheorems
#[ { proof := ← mkAppM ``const_rule #[K, X, y], origin := .decl ``const_rule, rfl := false} ]
#[ { proof := ← mkAppM ``const_rule #[K, I, X, y], origin := .decl ``const_rule, rfl := false} ]
discharger e

projRule e X i := do
let .some K := e.getArg? 0 | return none
let .some I := e.getArg? 1 | return none
tryTheorems
#[ { proof := ← mkAppM ``proj_rule #[K, X, i], origin := .decl ``proj_rule, rfl := false} ]
#[ { proof := ← mkAppM ``proj_rule #[K, I, X, i], origin := .decl ``proj_rule, rfl := false} ]
discharger e

compRule e f g := do
let .some K := e.getArg? 0 | return none
let .some I := e.getArg? 1 | return none
tryTheorems
#[ { proof := ← mkAppM ``comp_rule #[K, f, g], origin := .decl ``comp_rule, rfl := false} ]
#[ { proof := ← mkAppM ``comp_rule #[K, I, f, g], origin := .decl ``comp_rule, rfl := false} ]
discharger e

letRule e f g := do
let .some K := e.getArg? 0 | return none
let .some I := e.getArg? 1 | return none
tryTheorems
#[ { proof := ← mkAppM ``let_rule #[K, f, g], origin := .decl ``let_rule, rfl := false} ]
#[ { proof := ← mkAppM ``let_rule #[K, I, f, g], origin := .decl ``let_rule, rfl := false} ]
discharger e

piRule e f := do
Expand Down Expand Up @@ -1516,12 +1532,12 @@ theorem SciLean.EnumType.sum.arg_f.revDeriv_rule {ι : Type} [EnumType ι]
: revDeriv K (fun x => ∑ i, f x i)
=
fun x =>
let ydf := revDerivProjUpdate K (ι×Unit) f x
let ydf := revDerivProjUpdate K ι f x
(∑ i, ydf.1 i,
fun dy => Id.run do
let mut dx := 0
for i in fullRange ι do
dx := ydf.2 (i,()) dy dx
dx := ydf.2 i dy dx
dx) :=
by
have _ := fun i => (hf i).1
Expand All @@ -1538,12 +1554,12 @@ theorem SciLean.EnumType.sum.arg_f.revDerivUpdate_rule {ι : Type} [EnumType ι]
: revDerivUpdate K (fun x => ∑ i, f x i)
=
fun x =>
let ydf := revDerivProjUpdate K (ι×Unit) f x
let ydf := revDerivProjUpdate K ι f x
(∑ i, ydf.1 i,
fun dy dx => Id.run do
let mut dx := dx
for i in fullRange ι do
dx := ydf.2 (i,()) dy dx
dx := ydf.2 i dy dx
dx) :=
by
simp[revDerivUpdate]
Expand Down
28 changes: 19 additions & 9 deletions SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,15 @@ instance {Cont Idx Elem} [ArrayType Cont Idx Elem] [StructType Elem I ElemI] : S
right_inv := sorry
structProj_structModify := sorry
structProj_structModify' := sorry


instance {Cont Idx Elem} [ArrayType Cont Idx Elem] : StructType Cont Idx (fun _ => Elem) where
structProj := sorry
structMake := sorry
structModify := sorry
left_inv := sorry
right_inv := sorry
structProj_structModify := sorry
structProj_structModify' := sorry

@[ftrans]
theorem GetElem.getElem.arg_xs.revDeriv_rule
Expand All @@ -207,9 +215,9 @@ theorem GetElem.getElem.arg_xs.revDeriv_rule
: revDeriv K (fun x => getElem (f x) idx dom)
=
fun x =>
let ydf := revDerivProj K f x
let ydf := revDerivProj K Idx f x
(getElem ydf.1 idx dom,
fun delem => ydf.2 (idx,()) delem) :=
fun delem => ydf.2 idx delem) :=
by
have ⟨_,_⟩ := hf
unfold revDeriv; ftrans; ftrans
Expand All @@ -225,9 +233,9 @@ theorem GetElem.getElem.arg_xs.revDerivUpdate_rule
: revDerivUpdate K (fun x => getElem (f x) idx dom)
=
fun x =>
let ydf := revDerivProjUpdate K f x
let ydf := revDerivProjUpdate K Idx f x
(getElem ydf.1 idx dom,
fun delem dx => ydf.2 (idx,()) delem dx) :=
fun delem dx => ydf.2 idx delem dx) :=
by
unfold revDerivUpdate; ftrans; ftrans; simp[revDerivProjUpdate]

Expand All @@ -238,23 +246,25 @@ theorem GetElem.getElem.arg_xs.revDerivProj_rule
[SemiInnerProductSpaceStruct K Elem I ElemI]
(f : X → Cont) (idx : Idx) (dom)
(hf : HasAdjDiff K f)
: revDerivProj K (fun x => getElem (f x) idx dom)
: revDerivProj K I (fun x => getElem (f x) idx dom)
=
fun x =>
let ydf := revDerivProj K f x
let ydf := revDerivProj K (Idx×I) f x
(getElem ydf.1 idx dom,
fun i delem => ydf.2 (idx,i) delem) :=
by
sorry_proof

@[ftrans]
theorem GetElem.getElem.arg_xs.revDerivProjUpdate_rule
{I ElemI} [StructType Elem I ElemI] [EnumType I] [∀ i, SemiInnerProductSpace K (ElemI i)]
[SemiInnerProductSpaceStruct K Elem I ElemI]
(f : X → Cont) (idx : Idx) (dom)
(hf : HasAdjDiff K f)
: revDerivProjUpdate K (fun x => getElem (f x) idx dom)
: revDerivProjUpdate K I (fun x => getElem (f x) idx dom)
=
fun x =>
let ydf := revDerivProjUpdate K f x
let ydf := revDerivProjUpdate K (Idx×I) f x
(getElem ydf.1 idx dom,
fun i delem dx => ydf.2 (idx,i) delem dx) :=
by
Expand Down
Loading

0 comments on commit 5d2b838

Please sign in to comment.