Skip to content

Commit

Permalink
experiments with dense layer and revDeriv
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 2, 2023
1 parent ab1eb62 commit 593205c
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions SciLean/Modules/ML/DenseLayer.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,93 @@ variable {κ}
#generate_revCDeriv' dense weights bias x
prop_by unfold dense; fprop
trans_by unfold dense; autodiff



#check Nat
#check Function.repeatIdx

-- @[simp, ftrans_simp]
-- theorem _root_.Function.repeatIdx_modify {ι} [EnumType ι]
-- (g : ι → α → α)
-- : Function.repeatIdx (fun i f => Function.modify f i (g i))
-- =
-- fun f i => g i (f i) := sorry_proof


-- @[simp, ftrans_simp]
-- theorem _root_.Function.repeatIdx_add {ι} [EnumType ι] [Zero α] [Add α]
-- (f : ι → κ → α)
-- : Function.repeatIdx (fun i x j => x j + f i j)
-- =
-- fun (x : κ → α) j => x j + ∑ i, f i j := sorry_proof

-- @[simp, ftrans_simp]
-- theorem _root_.Function.repeatIdx_repeatIdx {ι κ} [EnumType ι] [EnumType κ]
-- (f : ι → κ → α → α)
-- : Function.repeatIdx (fun i x => (Function.repeatIdx fun j x => f i j x) x)
-- =
-- Function.repeatIdx (fun (ij : ι×κ) x => f ij.1 ij.2 x) := sorry_proof


section lazy
variable (weights : κ → ι → K) (bias : κ → K) (x : ι → K)

attribute [ftrans_simp] Pi.zero_apply

set_option trace.Meta.Tactic.simp.rewrite true in
#check
(revDeriv K fun (x : ι → K) => denseLazy κ weights bias x)
rewrite_by
unfold denseLazy
ftrans; ftrans

set_option pp.funBinderTypes true in
#check
(revDeriv K fun (weights : κ → ι → K) => denseLazy κ weights bias x)
rewrite_by
unfold denseLazy
ftrans; ftrans; simp
end lazy


section dense
variable (weights : DataArrayN K (κ×ι)) (bias : K ^ κ) (x : K ^ ι)

attribute [ftrans_simp] Pi.zero_apply


set_option trace.Meta.Tactic.simp.rewrite true in
#check
(revDeriv K fun (x : K ^ ι) => dense κ weights bias x)
rewrite_by
unfold dense; unfold denseLazy; dsimp
ftrans; ftrans


set_option pp.funBinderTypes true in
#check
(revDeriv K fun (weights : DataArrayN K (κ×ι)) => dense κ weights bias x)
rewrite_by
unfold dense; unfold denseLazy
ftrans; ftrans; simp


set_option pp.funBinderTypes true in
#check
(revDeriv K fun (wx : DataArrayN K (κ×ι) × K ^ ι) => dense κ wx.1 bias wx.2)
rewrite_by
unfold dense; unfold denseLazy
ftrans; ftrans; simp


#check
(revDeriv K fun (wx : DataArrayN K (κ×ι) × K^κ × K^ι) => dense κ wx.1 wx.2.1 wx.2.2)
rewrite_by
unfold dense; unfold denseLazy
ftrans


end dense


0 comments on commit 593205c

Please sign in to comment.