From 741eb98a1e321e763f1947eb604e7aadc90a1d25 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Fri, 1 Dec 2023 08:56:12 -0500 Subject: [PATCH] simpler simp theorems for oneHot on Prod type --- .../FunctionTransformations/RevDeriv.lean | 10 +---- SciLean/Data/StructType/Algebra.lean | 42 ++++++------------- SciLean/Data/StructType/Basic.lean | 1 + 3 files changed, 14 insertions(+), 39 deletions(-) diff --git a/SciLean/Core/FunctionTransformations/RevDeriv.lean b/SciLean/Core/FunctionTransformations/RevDeriv.lean index fe6a2507..e9d648c1 100644 --- a/SciLean/Core/FunctionTransformations/RevDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevDeriv.lean @@ -1027,11 +1027,7 @@ theorem Prod.fst.arg_self.revDerivProj_rule by unfold revDerivProj funext x; ftrans; simp[revDerivProj] - funext e dxy - simp[structMake, oneHot] - apply congr_arg - congr; funext i; congr; funext h; subst h; rfl - + @[ftrans] theorem Prod.fst.arg_self.revDerivProjUpdate_rule (f : W → X'×Y) (hf : HasAdjDiff K f) @@ -1085,10 +1081,6 @@ theorem Prod.snd.arg_self.revDerivProj_rule by unfold revDerivProj funext x; ftrans; simp[revDerivProj] - funext e dxy - simp[structMake, oneHot] - apply congr_arg - congr; funext i; congr; funext h; subst h; rfl @[ftrans] theorem Prod.snd.arg_self.revDerivProjUpdate_rule diff --git a/SciLean/Data/StructType/Algebra.lean b/SciLean/Data/StructType/Algebra.lean index 8a99e55d..13c1b8ee 100644 --- a/SciLean/Data/StructType/Algebra.lean +++ b/SciLean/Data/StructType/Algebra.lean @@ -341,44 +341,26 @@ variable [DecidableEq I] [DecidableEq J] @[simp, ftrans_simp] -theorem oneHot_inl_fst (i : I) (xi : EI i) - : (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi).1 +theorem oneHot_inl (i : I) (xi : EI i) + : (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi) = - oneHot i xi := -by - simp[oneHot, structMake]; - congr; funext; congr - funext h; subst h; rfl - -@[simp, ftrans_simp] -theorem oneHot_inl_snd (i : I) (xi : EI i) - : (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi).2 - = - (0 : F) := + (oneHot i xi, 0) := by simp[oneHot, structMake] - apply structExt (I:=J) - simp[ZeroStruct.structProj_zero] + constructor + . congr; funext; congr; funext h; subst h; rfl + . apply structExt (I:=J); simp [ZeroStruct.structProj_zero] @[simp, ftrans_simp] -theorem oneHot_inr_fst (j : J) (yj : FJ j) - : (oneHot (X:=E×F) (I:=I⊕J) (.inr j) yj).1 +theorem oneHot_inr (j : J) (xj : FJ j) + : (oneHot (X:=E×F) (I:=I⊕J) (.inr j) xj) = - (0 : E) := + (0, oneHot j xj) := by simp[oneHot, structMake] - apply structExt (I:=I) - simp[ZeroStruct.structProj_zero] - -@[simp, ftrans_simp] -theorem oneHot_inr_snd (j : J) (yj : FJ j) - : (oneHot (X:=E×F) (I:=I⊕J) (.inr j) yj).2 - = - oneHot j yj := -by - simp[oneHot, structMake]; - congr; funext; congr - funext h; subst h; rfl + constructor + . apply structExt (I:=I); simp [ZeroStruct.structProj_zero] + . congr; funext; congr; funext h; subst h; rfl end OneHotSimp diff --git a/SciLean/Data/StructType/Basic.lean b/SciLean/Data/StructType/Basic.lean index 3467dc95..f92ca477 100644 --- a/SciLean/Data/StructType/Basic.lean +++ b/SciLean/Data/StructType/Basic.lean @@ -20,6 +20,7 @@ class StructType (X : Sort _) (I : (Sort _)) (XI : outParam <| I → Sort _) whe attribute [simp, ftrans_simp] StructType.structProj_structModify StructType.structProj_structModify' export StructType (structProj structMake structModify) +attribute [simp, ftrans_simp] structProj structMake structModify def oneHot {X I XI} [StructType X I XI] [DecidableEq I] [∀ i, Zero (XI i)] (i : I) (xi : XI i) : X := structMake fun i' =>