Skip to content

Commit

Permalink
simpler simp theorems for oneHot on Prod type
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 1, 2023
1 parent a38289f commit 741eb98
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 39 deletions.
10 changes: 1 addition & 9 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1027,11 +1027,7 @@ theorem Prod.fst.arg_self.revDerivProj_rule
by
unfold revDerivProj
funext x; ftrans; simp[revDerivProj]
funext e dxy
simp[structMake, oneHot]
apply congr_arg
congr; funext i; congr; funext h; subst h; rfl


@[ftrans]
theorem Prod.fst.arg_self.revDerivProjUpdate_rule
(f : W → X'×Y) (hf : HasAdjDiff K f)
Expand Down Expand Up @@ -1085,10 +1081,6 @@ theorem Prod.snd.arg_self.revDerivProj_rule
by
unfold revDerivProj
funext x; ftrans; simp[revDerivProj]
funext e dxy
simp[structMake, oneHot]
apply congr_arg
congr; funext i; congr; funext h; subst h; rfl

@[ftrans]
theorem Prod.snd.arg_self.revDerivProjUpdate_rule
Expand Down
42 changes: 12 additions & 30 deletions SciLean/Data/StructType/Algebra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -341,44 +341,26 @@ variable
[DecidableEq I] [DecidableEq J]

@[simp, ftrans_simp]
theorem oneHot_inl_fst (i : I) (xi : EI i)
: (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi).1
theorem oneHot_inl (i : I) (xi : EI i)
: (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi)
=
oneHot i xi :=
by
simp[oneHot, structMake];
congr; funext; congr
funext h; subst h; rfl

@[simp, ftrans_simp]
theorem oneHot_inl_snd (i : I) (xi : EI i)
: (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi).2
=
(0 : F) :=
(oneHot i xi, 0) :=
by
simp[oneHot, structMake]
apply structExt (I:=J)
simp[ZeroStruct.structProj_zero]
constructor
. congr; funext; congr; funext h; subst h; rfl
. apply structExt (I:=J); simp [ZeroStruct.structProj_zero]

@[simp, ftrans_simp]
theorem oneHot_inr_fst (j : J) (yj : FJ j)
: (oneHot (X:=E×F) (I:=I⊕J) (.inr j) yj).1
theorem oneHot_inr (j : J) (xj : FJ j)
: (oneHot (X:=E×F) (I:=I⊕J) (.inr j) xj)
=
(0 : E) :=
(0, oneHot j xj) :=
by
simp[oneHot, structMake]
apply structExt (I:=I)
simp[ZeroStruct.structProj_zero]

@[simp, ftrans_simp]
theorem oneHot_inr_snd (j : J) (yj : FJ j)
: (oneHot (X:=E×F) (I:=I⊕J) (.inr j) yj).2
=
oneHot j yj :=
by
simp[oneHot, structMake];
congr; funext; congr
funext h; subst h; rfl
constructor
. apply structExt (I:=I); simp [ZeroStruct.structProj_zero]
. congr; funext; congr; funext h; subst h; rfl

end OneHotSimp

1 change: 1 addition & 0 deletions SciLean/Data/StructType/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class StructType (X : Sort _) (I : (Sort _)) (XI : outParam <| I → Sort _) whe

attribute [simp, ftrans_simp] StructType.structProj_structModify StructType.structProj_structModify'
export StructType (structProj structMake structModify)
attribute [simp, ftrans_simp] structProj structMake structModify

def oneHot {X I XI} [StructType X I XI] [DecidableEq I] [∀ i, Zero (XI i)] (i : I) (xi : XI i) : X :=
structMake fun i' =>
Expand Down

0 comments on commit 741eb98

Please sign in to comment.