Skip to content

Commit

Permalink
improve rhs form of some revDeriv rules
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 30, 2023
1 parent f516f51 commit aea3ae2
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 29 deletions.
124 changes: 96 additions & 28 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ set_option linter.unusedVariables false

namespace SciLean

-- set_option linter.ftransSsaRhs true

variable
(K I : Type _) [IsROrC K]
{X : Type _} [SemiInnerProductSpace K X]
Expand Down Expand Up @@ -915,7 +917,10 @@ theorem Prod.mk.arg_fstsnd.revDeriv_rule
fun x =>
let ydg := revDeriv K g x
let zdf := revDerivUpdate K f x
((ydg.1,zdf.1), fun dyz => zdf.2 dyz.2 (ydg.2 dyz.1)) :=
((ydg.1,zdf.1),
fun dyz =>
let dx := ydg.2 dyz.1
zdf.2 dyz.2 dx) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
Expand All @@ -930,7 +935,10 @@ theorem Prod.mk.arg_fstsnd.revDerivUpdate_rule
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K f x
((ydg.1,zdf.1), fun dyz dx => zdf.2 dyz.2 (ydg.2 dyz.1 dx)) :=
((ydg.1,zdf.1),
fun dyz dx =>
let dx := ydg.2 dyz.1 dx
zdf.2 dyz.2 dx) :=
by
unfold revDerivUpdate; ftrans; simp[add_assoc, revDerivUpdate]

Expand Down Expand Up @@ -1107,7 +1115,10 @@ theorem HAdd.hAdd.arg_a0a1.revDeriv_rule
fun x =>
let ydf := revDeriv K f x
let ydg := revDerivUpdate K g x
(ydf.1 + ydg.1, fun dy => ydg.2 dy (ydf.2 dy)) :=
(ydf.1 + ydg.1,
fun dy =>
let dx := ydf.2 dy
ydg.2 dy dx) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
Expand All @@ -1121,7 +1132,10 @@ theorem HAdd.hAdd.arg_a0a1.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let ydg := revDerivUpdate K g x
(ydf.1 + ydg.1, fun dy dx => ydg.2 dy (ydf.2 dy dx)) :=
(ydf.1 + ydg.1,
fun dy dx =>
let dx := ydf.2 dy dx
ydg.2 dy dx) :=
by
unfold revDerivUpdate
ftrans; funext x; simp[add_assoc,revDerivUpdate]
Expand All @@ -1135,7 +1149,9 @@ theorem HAdd.hAdd.arg_a0a1.revDerivProj_rule
let ydf := revDerivProj K Yi f x
let ydg := revDerivProjUpdate K Yi g x
(ydf.1 + ydg.1,
fun i dy => (ydg.2 i dy (ydf.2 i dy))) :=
fun i dy =>
let dx := ydf.2 i dy
(ydg.2 i dy dx)) :=
by
unfold revDerivProjUpdate; unfold revDerivProj
ftrans; simp[revDerivUpdate]
Expand All @@ -1148,7 +1164,10 @@ theorem HAdd.hAdd.arg_a0a1.revDerivProjUpdate_rule
fun x =>
let ydf := revDerivProjUpdate K Yi f x
let ydg := revDerivProjUpdate K Yi g x
(ydf.1 + ydg.1, fun i dy dx => ydg.2 i dy (ydf.2 i dy dx)) :=
(ydf.1 + ydg.1,
fun i dy dx =>
let dx := ydf.2 i dy dx
ydg.2 i dy dx) :=
by
unfold revDerivProjUpdate
ftrans; simp[revDerivProjUpdate, add_assoc]
Expand All @@ -1165,7 +1184,11 @@ theorem HSub.hSub.arg_a0a1.revDeriv_rule
fun x =>
let ydf := revDeriv K f x
let ydg := revDerivUpdate K g x
(ydf.1 - ydg.1, fun dy => ydg.2 (-dy) (ydf.2 dy)) :=
(ydf.1 - ydg.1,
fun dy =>
let dx := ydf.2 dy
let dy' := -dy
ydg.2 dy' dx) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
Expand All @@ -1180,7 +1203,11 @@ theorem HSub.hSub.arg_a0a1.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let ydg := revDerivUpdate K g x
(ydf.1 - ydg.1, fun dy dx => ydg.2 (-dy) (ydf.2 dy dx)) :=
(ydf.1 - ydg.1,
fun dy dx =>
let dx := ydf.2 dy dx
let dy' := -dy
ydg.2 dy' dx) :=
by
unfold revDerivUpdate
ftrans; funext x; simp[add_assoc,revDerivUpdate]
Expand All @@ -1194,7 +1221,10 @@ theorem HSub.hSub.arg_a0a1.revDerivProj_rule
let ydf := revDerivProj K Yi f x
let ydg := revDerivProjUpdate K Yi g x
(ydf.1 - ydg.1,
fun i dy => (ydg.2 i (-dy) (ydf.2 i dy))) :=
fun i dy =>
let dx := ydf.2 i dy
let dy' := -dy
(ydg.2 i dy' dx)) :=
by
unfold revDerivProjUpdate; unfold revDerivProj
ftrans; simp[revDerivUpdate, neg_pull,revDeriv]
Expand All @@ -1208,7 +1238,11 @@ theorem HSub.hSub.arg_a0a1.revDerivProjUpdate_rule
fun x =>
let ydf := revDerivProjUpdate K Yi f x
let ydg := revDerivProjUpdate K Yi g x
(ydf.1 - ydg.1, fun i dy dx => ydg.2 i (-dy) (ydf.2 i dy dx)) :=
(ydf.1 - ydg.1,
fun i dy dx =>
let dx := ydf.2 i dy dx
let dy' := -dy
ydg.2 i dy' dx) :=
by
unfold revDerivProjUpdate
ftrans; simp[revDerivProjUpdate, neg_pull, revDerivProj, revDeriv,add_assoc]
Expand All @@ -1223,7 +1257,10 @@ theorem Neg.neg.arg_a0.revDeriv_rule
: (revDeriv K fun x => - f x) x
=
let ydf := revDeriv K f x
(-ydf.1, fun dy => - ydf.2 dy) :=
(-ydf.1,
fun dy =>
let dx := ydf.2 dy
(-dx)) :=
by
unfold revDeriv; simp; ftrans; ftrans

Expand All @@ -1234,7 +1271,10 @@ theorem Neg.neg.arg_a0.revDerivUpdate_rule
=
fun x =>
let ydf := revDerivUpdate K f x
(-ydf.1, fun dy dx => ydf.2 (-dy) dx) :=
(-ydf.1,
fun dy dx =>
let dy' := -dy
ydf.2 dy' dx) :=
by
unfold revDerivUpdate; funext x; ftrans; simp[neg_pull,revDeriv]

Expand All @@ -1245,9 +1285,12 @@ theorem Neg.neg.arg_a0.revDerivProj_rule
=
fun x =>
let ydf := revDerivProj K Yi f x
(-ydf.1, fun i dy => ydf.2 i (-dy)) :=
(-ydf.1,
fun i dy =>
let dy' := -dy
ydf.2 i dy') :=
by
unfold revDerivProj; ftrans; simp[neg_pull,revDeriv]
unfold revDerivProj; ftrans; simp[neg_push,revDeriv]

@[ftrans]
theorem Neg.neg.arg_a0.revDerivProjUpdate_rule
Expand All @@ -1256,7 +1299,10 @@ theorem Neg.neg.arg_a0.revDerivProjUpdate_rule
=
fun x =>
let ydf := revDerivProjUpdate K Yi f x
(-ydf.1, fun i dy dx => ydf.2 i (-dy) dx) :=
(-ydf.1,
fun i dy dx =>
let dy' := -dy
ydf.2 i dy' dx) :=
by
unfold revDerivProjUpdate; ftrans

Expand All @@ -1278,7 +1324,8 @@ theorem HMul.hMul.arg_a0a1.revDeriv_rule
fun dx' =>
let dx₁ := (conj zdg.1 * dx')
let dx₂ := (conj ydf.1 * dx')
ydf.2 dx₁ (zdg.2 dx₂)) :=
let dx := zdg.2 dx₂
ydf.2 dx₁ dx) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
Expand All @@ -1298,7 +1345,8 @@ theorem HMul.hMul.arg_a0a1.revDerivUpdate_rule
fun dx' dx =>
let dx₁ := (conj zdg.1 * dx')
let dx₂ := (conj ydf.1 * dx')
ydf.2 dx₁ (zdg.2 dx₂ dx)) :=
let dx := zdg.2 dx₂ dx
ydf.2 dx₁ dx) :=
by
unfold revDerivUpdate; simp; ftrans; ftrans;
simp [smul_push,add_assoc,revDerivUpdate]
Expand All @@ -1315,8 +1363,9 @@ theorem HMul.hMul.arg_a0a1.revDerivProj_rule
(ydf.1 * zdg.1,
fun _ dy =>
let dy₁ := (conj zdg.1)*dy
let dy₂ := (conj ydf.1)* dy
ydf.2 dy₁ (zdg.2 dy₂)) :=
let dy₂ := (conj ydf.1)*dy
let dx := zdg.2 dy₂
ydf.2 dy₁ dx) :=
by
unfold revDerivProj
ftrans; simp[oneHot, structMake]
Expand All @@ -1334,7 +1383,8 @@ theorem HMul.hMul.arg_a0a1.revDerivProjUpdate_rule
fun _ dy dx =>
let dy₁ := (conj zdg.1)*dy
let dy₂ := (conj ydf.1)*dy
ydf.2 dy₁ (zdg.2 dy₂ dx)) :=
let dx := zdg.2 dy₂ dx
ydf.2 dy₁ dx) :=
by
unfold revDerivProjUpdate
ftrans; simp[revDerivUpdate,add_assoc]
Expand All @@ -1355,8 +1405,11 @@ theorem HSMul.hSMul.arg_a0a1.revDeriv_rule
let ydf := revDerivUpdate K f x
let zdg := revDeriv K g x
(ydf.1 • zdg.1,
fun dx' =>
ydf.2 (inner zdg.1 dx') (conj ydf.1 • zdg.2 dx')) :=
fun dy' =>
let dk := inner zdg.1 dy'
let dx := zdg.2 dy'
let dx := conj ydf.1 • dx
ydf.2 dk dx) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
Expand All @@ -1373,7 +1426,12 @@ theorem HSMul.hSMul.arg_a0a1.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivUpdate K g x
(ydf.1 • zdg.1, fun dy dx => ydf.2 (inner zdg.1 dy) (zdg.2 (conj ydf.1•dy) dx)) :=
(ydf.1 • zdg.1,
fun dy dx =>
let dk := inner zdg.1 dy
let dy' := conj ydf.1 • dy
let dx := zdg.2 dy' dx
ydf.2 dk dx) :=
by
unfold revDerivUpdate;
funext x; ftrans; simp[mul_assoc,add_assoc,revDerivUpdate,revDeriv,smul_push]
Expand All @@ -1389,10 +1447,14 @@ theorem HSMul.hSMul.arg_a0a1.revDerivProj_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivProj K Yi g x
(ydf.1 • zdg.1, fun i (dy : YI i) => ydf.2 (inner (structProj zdg.1 i) dy) (zdg.2 i (conj ydf.1•dy))) :=
(ydf.1 • zdg.1,
fun i (dy : YI i) =>
let dk := inner (structProj zdg.1 i) dy
let dx := zdg.2 i dy
let dx := conj ydf.1•dx
ydf.2 dk dx) :=
by
unfold revDerivProj
ftrans; simp[revDerivUpdate,smul_push,revDeriv]
unfold revDerivProj; ftrans

@[ftrans]
theorem HSMul.hSMul.arg_a0a1.revDerivProjUpdate_rule
Expand All @@ -1405,10 +1467,16 @@ theorem HSMul.hSMul.arg_a0a1.revDerivProjUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivProjUpdate K Yi g x
(ydf.1 • zdg.1, fun i (dy : YI i) dx => ydf.2 (inner (structProj zdg.1 i) dy) (zdg.2 i (conj ydf.1•dy) dx)) :=
(ydf.1 • zdg.1,
fun i (dy : YI i) dx =>
let dk := inner (structProj zdg.1 i) dy
let dy' := conj ydf.1•dy
let dx := zdg.2 i dy' dx
ydf.2 dk dx) :=
by
unfold revDerivProjUpdate
ftrans; simp[revDerivUpdate,add_assoc]
ftrans; simp[revDerivUpdate,add_assoc,smul_pull]
simp only [smul_pull,revDerivProj,revDeriv]



Expand Down
1 change: 0 additions & 1 deletion test/basic_gradients.lean
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,5 @@ example (w : K ^ (Idx' (-5) 5 × Idx' (-5) 5))
⊞ i => ∑ (j : (Idx' (-5) 5 × Idx' (-5) 5)), w[(j.2,j.1)] * dy[(-j.2.1 +ᵥ i.fst, -j.1.1 +ᵥ i.snd)] :=
by
conv => lhs; unfold gradient; ftrans
sorry_proof


0 comments on commit aea3ae2

Please sign in to comment.