Skip to content

Commit

Permalink
revDeriv rules for get/set/introElem
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 1, 2023
1 parent fdb7a1c commit 8ae5880
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 24 deletions.
10 changes: 9 additions & 1 deletion SciLean/Data/ArrayType/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ variable
{Cont : Type} {Idx : Type |> outParam} {Elem : Type |> outParam}
[ArrayType Cont Idx Elem]


@[ext]
theorem ext (x y : Cont) : (∀ i, x[i] = y[i]) → x = y :=
by
Expand Down Expand Up @@ -191,6 +190,15 @@ def _root_.ListN.toArrayType {n Elem} (Cont : Type) [ArrayType Cont (SciLean.Idx
(l : ListN Elem n) : Cont :=
introElem fun i => l.toArray[i.1.toNat]'sorry_proof

instance {Cont Idx Elem} [ArrayType Cont Idx Elem] [StructType Elem I ElemI] : StructType Cont (Idx×I) (fun (_,i) => ElemI i) where
structProj := fun x (i,j) => structProj x[i] j
structMake := fun f => introElem fun i => structMake fun j => f (i,j)
structModify := fun (i,j) f x => modifyElem x i (fun xi => structModify j f xi)
left_inv := by intro x; simp
right_inv := by intro x; simp
structProj_structModify := by intro x; simp
structProj_structModify' := by intro (i,j) (i',j') _ _ h; sorry_proof


section Operations

Expand Down
190 changes: 167 additions & 23 deletions SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -189,25 +189,6 @@ by
unfold revCDeriv; ftrans
sorry_proof


instance {Cont Idx Elem} [ArrayType Cont Idx Elem] [StructType Elem I ElemI] : StructType Cont (Idx×I) (fun (_,i) => ElemI i) where
structProj := sorry
structMake := sorry
structModify := sorry
left_inv := sorry
right_inv := sorry
structProj_structModify := sorry
structProj_structModify' := sorry

instance {Cont Idx Elem} [ArrayType Cont Idx Elem] : StructType Cont Idx (fun _ => Elem) where
structProj := sorry
structMake := sorry
structModify := sorry
left_inv := sorry
right_inv := sorry
structProj_structModify := sorry
structProj_structModify' := sorry

@[ftrans]
theorem GetElem.getElem.arg_xs.revDeriv_rule
(f : X → Cont) (idx : Idx) (dom)
Expand Down Expand Up @@ -441,6 +422,27 @@ by
unfold revCDeriv; ftrans; ftrans; simp


@[ftrans]
theorem SetElem.setElem.arg_contelem.revDeriv_rule
(cont : X → Cont) (idx : Idx) (elem : X → Elem)
(hcont : HasAdjDiff K cont) (helem : HasAdjDiff K elem)
: revDeriv K (fun x => setElem (cont x) idx (elem x))
=
fun x =>
let cdc := revDeriv K cont x
let ede := revDerivUpdate K elem x
(setElem cdc.1 idx ede.1,
fun dcont' =>
let delem' := dcont'[idx]
let dcont' := setElem dcont' idx 0
let dx := cdc.2 dcont'
ede.2 delem' dx) :=
by
have ⟨_,_⟩ := hcont
have ⟨_,_⟩ := helem
unfold revDeriv; ftrans; ftrans; simp[revDerivUpdate,revDeriv]


@[ftrans]
theorem SetElem.setElem.arg_contelem.revDerivUpdate_rule
(cont : X → Cont) (idx : Idx) (elem : X → Elem)
Expand All @@ -453,15 +455,97 @@ theorem SetElem.setElem.arg_contelem.revDerivUpdate_rule
(setElem cdc.1 idx ede.1,
fun dcont' dx =>
let delem' := dcont'[idx]
ede.2 delem' (cdc.2 (setElem dcont' idx 0) dx)
) :=
let dcont' := setElem dcont' idx 0
let dx := cdc.2 dcont' dx
ede.2 delem' dx) :=
by
have ⟨_,_⟩ := hcont
have ⟨_,_⟩ := helem
unfold revDerivUpdate; ftrans; ftrans; simp[add_assoc]
sorry_proof
unfold revDerivUpdate; ftrans; ftrans; simp[add_assoc,revDerivUpdate]


@[ftrans]
theorem SetElem.setElem.arg_contelem.revDerivProj_rule
(cont : X → Cont) (idx : Idx) (elem : X → Elem)
(hcont : HasAdjDiff K cont) (helem : HasAdjDiff K elem)
: revDerivProj K Idx (fun x => setElem (cont x) idx (elem x))
=
fun x =>
let cdc := revDerivProj K Idx cont x
let ede := revDeriv K elem x
(setElem cdc.1 idx ede.1,
fun i dei =>
if i = idx then
ede.2 dei
else
cdc.2 i dei) :=
by
unfold revDerivProj; ftrans; ftrans; simp[revDerivUpdate,revDeriv]
funext x; simp; funext i dei
if h : i = idx then
subst h
simp[ArrayType.getElem_structProj, ArrayType.setElem_structModify]
sorry_proof
else
simp[h,ArrayType.getElem_structProj, ArrayType.setElem_structModify]
sorry_proof


@[ftrans]
theorem SetElem.setElem.arg_contelem.revDerivProj_rule'
{I ElemI} [StructType Elem I ElemI] [EnumType I] [∀ i, SemiInnerProductSpace K (ElemI i)]
[SemiInnerProductSpaceStruct K Elem I ElemI]
(cont : X → Cont) (idx : Idx) (elem : X → Elem)
(hcont : HasAdjDiff K cont) (helem : HasAdjDiff K elem)
: revDerivProj K (Idx×I) (fun x => setElem (cont x) idx (elem x))
=
fun x =>
let cdc := revDerivProj K (Idx×I) cont x
let ede := revDerivProj K I elem x
(setElem cdc.1 idx ede.1,
fun (i,j) deij =>
if i = idx then
ede.2 j deij
else
cdc.2 (i,j) deij) :=
by
unfold revDerivProj; ftrans; ftrans; simp[revDerivUpdate,revDeriv]
funext x; simp; funext (i,j) deij
if h : i = idx then
subst h
simp[ArrayType.getElem_structProj, ArrayType.setElem_structModify]
sorry_proof
else
simp[h,ArrayType.getElem_structProj, ArrayType.setElem_structModify]
sorry_proof


@[ftrans]
theorem SetElem.setElem.arg_contelem.revDerivProjUpdate_rule'
{I ElemI} [StructType Elem I ElemI] [EnumType I] [∀ i, SemiInnerProductSpace K (ElemI i)]
[SemiInnerProductSpaceStruct K Elem I ElemI]
(cont : X → Cont) (idx : Idx) (elem : X → Elem)
(hcont : HasAdjDiff K cont) (helem : HasAdjDiff K elem)
: revDerivProjUpdate K (Idx×I) (fun x => setElem (cont x) idx (elem x))
=
fun x =>
let cdc := revDerivProjUpdate K (Idx×I) cont x
let ede := revDerivProjUpdate K I elem x
(setElem cdc.1 idx ede.1,
fun (i,j) deij dx =>
if i = idx then
ede.2 j deij dx
else
cdc.2 (i,j) deij dx) :=
by
unfold revDerivProjUpdate; ftrans; ftrans; simp[revDerivProjUpdate]
funext x; simp; funext (i,j) deij
if h : i = idx then
subst h
simp[ArrayType.getElem_structProj, ArrayType.setElem_structModify]
else
simp[h,ArrayType.getElem_structProj, ArrayType.setElem_structModify]

end OnSemiInnerProductSpace


Expand Down Expand Up @@ -613,6 +697,66 @@ by
have ⟨_,_⟩ := hf
unfold revCDeriv; ftrans; ftrans; simp


@[ftrans]
theorem IntroElem.introElem.arg_f.revDeriv_rule
(f : X → Idx → Elem)
(hf : HasAdjDiff K f)
: revDeriv K (fun x => introElem (Cont:=Cont) (f x))
=
fun x =>
let fdf := revDeriv K f x
(introElem fdf.1,
fun dc => fdf.2 (fun i => dc[i])) :=
by
have ⟨_,_⟩ := hf
unfold revDeriv; ftrans; ftrans; simp

@[ftrans]
theorem IntroElem.introElem.arg_f.revDerivUpdate_rule
(f : X → Idx → Elem)
(hf : HasAdjDiff K f)
: revDerivUpdate K (fun x => introElem (Cont:=Cont) (f x))
=
fun x =>
let fdf := revDerivUpdate K f x
(introElem fdf.1,
fun dc dx => fdf.2 (fun i => dc[i]) dx) :=
by
unfold revDerivUpdate; ftrans

@[ftrans]
theorem IntroElem.introElem.arg_f.revDerivProj_rule
{I ElemI} [StructType Elem I ElemI] [EnumType I] [∀ i, SemiInnerProductSpace K (ElemI i)]
[SemiInnerProductSpaceStruct K Elem I ElemI]
(f : X → Idx → Elem)
(hf : HasAdjDiff K f)
: revDerivProj K (Idx×I) (fun x => introElem (Cont:=Cont) (f x))
=
fun x =>
let fdf := revDerivProj K (Idx×I) f x
(introElem fdf.1,
fun ij de => fdf.2 ij de) :=
by
unfold revDerivProj; ftrans; ftrans; simp
funext x; simp; funext i de
apply congr_arg; sorry_proof

@[ftrans]
theorem IntroElem.introElem.arg_f.revDerivProjUpdate_rule
{I ElemI} [StructType Elem I ElemI] [EnumType I] [∀ i, SemiInnerProductSpace K (ElemI i)]
[SemiInnerProductSpaceStruct K Elem I ElemI]
(f : X → Idx → Elem)
(hf : HasAdjDiff K f)
: revDerivProjUpdate K (Idx×I) (fun x => introElem (Cont:=Cont) (f x))
=
fun x =>
let fdf := revDerivProjUpdate K (Idx×I) f x
(introElem fdf.1,
fun ij de dx => fdf.2 ij de dx) :=
by
unfold revDerivProjUpdate; ftrans

end OnSemiInnerProductSpace


Expand Down

0 comments on commit 8ae5880

Please sign in to comment.