From 38437176b9fea3b71f9ddc65aa29b863e896f1e9 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Wed, 29 Nov 2023 18:12:35 -0500 Subject: [PATCH] fixed type class diamonds when infering `Vec K (Sum.rec EI FJ i)` --- SciLean/Data/ArrayType/Basic.lean | 32 ++-- SciLean/Data/StructType/Algebra.lean | 218 +++++++++++++++++++++++---- SciLean/Data/StructType/Basic.lean | 60 ++++++-- SciLean/Tactic/FTrans/Simp.lean | 9 ++ 4 files changed, 265 insertions(+), 54 deletions(-) diff --git a/SciLean/Data/ArrayType/Basic.lean b/SciLean/Data/ArrayType/Basic.lean index 053bdffa..d27a87c6 100644 --- a/SciLean/Data/ArrayType/Basic.lean +++ b/SciLean/Data/ArrayType/Basic.lean @@ -130,7 +130,7 @@ def map [ArrayType Cont Idx Elem] [EnumType Idx] (f : Elem → Elem) (arr : Cont theorem getElem_map [ArrayType Cont Idx Elem] [EnumType Idx] (f : Elem → Elem) (arr : Cont) (i : Idx) : (map f arr)[i] = f arr[i] := sorry_proof -instance [ArrayType Cont Idx Elem] [ToString Elem] [EnumType Idx] : ToString (Cont) := ⟨λ x => Id.run do +instance (priority:=low) [ArrayType Cont Idx Elem] [ToString Elem] [EnumType Idx] : ToString (Cont) := ⟨λ x => Id.run do let mut fst := true let mut s := "⊞[" for i in fullRange Idx do @@ -164,24 +164,24 @@ section Operations variable [ArrayType Cont Idx Elem] [EnumType Idx] - instance [Add Elem] : Add Cont := ⟨λ f g => mapIdx (λ x fx => fx + g[x]) f⟩ - instance [Sub Elem] : Sub Cont := ⟨λ f g => mapIdx (λ x fx => fx - g[x]) f⟩ - instance [Mul Elem] : Mul Cont := ⟨λ f g => mapIdx (λ x fx => fx * g[x]) f⟩ - instance [Div Elem] : Div Cont := ⟨λ f g => mapIdx (λ x fx => fx / g[x]) f⟩ + instance (priority:=low) [Add Elem] : Add Cont := ⟨λ f g => mapIdx (λ x fx => fx + g[x]) f⟩ + instance (priority:=low) [Sub Elem] : Sub Cont := ⟨λ f g => mapIdx (λ x fx => fx - g[x]) f⟩ + instance (priority:=low) [Mul Elem] : Mul Cont := ⟨λ f g => mapIdx (λ x fx => fx * g[x]) f⟩ + instance (priority:=low) [Div Elem] : Div Cont := ⟨λ f g => mapIdx (λ x fx => fx / g[x]) f⟩ - -- instance {R} [HMul R Elem Elem] : HMul R Cont Cont := ⟨λ r f => map (λ fx => r*(fx : Elem)) f⟩ - instance {R} [SMul R Elem] : SMul R Cont := ⟨λ r f => map (λ fx => r•(fx : Elem)) f⟩ + -- instance (priority:=low) {R} [HMul R Elem Elem] : HMul R Cont Cont := ⟨λ r f => map (λ fx => r*(fx : Elem)) f⟩ + instance (priority:=low) {R} [SMul R Elem] : SMul R Cont := ⟨λ r f => map (λ fx => r•(fx : Elem)) f⟩ - instance [Neg Elem] : Neg Cont := ⟨λ f => map (λ fx => -(fx : Elem)) f⟩ - instance [Inv Elem] : Inv Cont := ⟨λ f => map (λ fx => (fx : Elem)⁻¹) f⟩ + instance (priority:=low) [Neg Elem] : Neg Cont := ⟨λ f => map (λ fx => -(fx : Elem)) f⟩ + instance (priority:=low) [Inv Elem] : Inv Cont := ⟨λ f => map (λ fx => (fx : Elem)⁻¹) f⟩ - instance [One Elem] : One Cont := ⟨introElem λ _ : Idx => 1⟩ - instance [Zero Elem] : Zero Cont := ⟨introElem λ _ : Idx => 0⟩ + instance (priority:=low) [One Elem] : One Cont := ⟨introElem λ _ : Idx => 1⟩ + instance (priority:=low) [Zero Elem] : Zero Cont := ⟨introElem λ _ : Idx => 0⟩ - instance [LT Elem] : LT Cont := ⟨λ f g => ∀ x, f[x] < g[x]⟩ - instance [LE Elem] : LE Cont := ⟨λ f g => ∀ x, f[x] ≤ g[x]⟩ + instance (priority:=low) [LT Elem] : LT Cont := ⟨λ f g => ∀ x, f[x] < g[x]⟩ + instance (priority:=low) [LE Elem] : LE Cont := ⟨λ f g => ∀ x, f[x] ≤ g[x]⟩ - instance [DecidableEq Elem] : DecidableEq Cont := + instance (priority:=low) [DecidableEq Elem] : DecidableEq Cont := λ f g => Id.run do let mut eq : Bool := true for x in fullRange Idx do @@ -190,7 +190,7 @@ section Operations break if eq then isTrue sorry_proof else isFalse sorry_proof - instance [LT Elem] [∀ x y : Elem, Decidable (x < y)] (f g : Cont) : Decidable (f < g) := Id.run do + instance (priority:=low) [LT Elem] [∀ x y : Elem, Decidable (x < y)] (f g : Cont) : Decidable (f < g) := Id.run do let mut lt : Bool := true for x in fullRange Idx do if ¬(f[x] < g[x]) then @@ -198,7 +198,7 @@ section Operations break if lt then isTrue sorry_proof else isFalse sorry_proof - instance [LE Elem] [∀ x y : Elem, Decidable (x ≤ y)] (f g : Cont) : Decidable (f ≤ g) := Id.run do + instance (priority:=low) [LE Elem] [∀ x y : Elem, Decidable (x ≤ y)] (f g : Cont) : Decidable (f ≤ g) := Id.run do let mut le : Bool := true for x in fullRange Idx do if ¬(f[x] ≤ g[x]) then diff --git a/SciLean/Data/StructType/Algebra.lean b/SciLean/Data/StructType/Algebra.lean index 12cc2be8..8a99e55d 100644 --- a/SciLean/Data/StructType/Algebra.lean +++ b/SciLean/Data/StructType/Algebra.lean @@ -20,12 +20,140 @@ variable [StructType F J FJ] -open StructType in -class VecStruct (K X I XI) [StructType X I XI] [IsROrC K] [Vec K X] [∀ i, Vec K (XI i)] : Prop where +-------------------------------------------------------------------------------- +-- Algebra instances for Sum.rec ------------------------------------------ +-------------------------------------------------------------------------------- +-- There are some issues with defEq + +@[reducible] +instance [∀ i, Zero (EI i)] [∀ j, Zero (FJ j)] (i : I ⊕ J) : Zero (Sum.rec EI FJ i) := + match i with + | .inl _ => by infer_instance + | .inr _ => by infer_instance + +@[reducible] +instance [∀ i, Add (EI i)] [∀ j, Add (FJ j)] (i : I ⊕ J) : Add (Sum.rec EI FJ i) := + match i with + | .inl _ => by infer_instance + | .inr _ => by infer_instance + +@[reducible] +instance [∀ i, SMul K (EI i)] [∀ j, SMul K (FJ j)] (i : I ⊕ J) : SMul K (Sum.rec EI FJ i) := + match i with + | .inl _ => by infer_instance + | .inr _ => by infer_instance + +@[reducible] +instance [∀ i, Neg (EI i)] [∀ j, Neg (FJ j)] (i : I ⊕ J) : Neg (Sum.rec EI FJ i) := + match i with + | .inl _ => by infer_instance + | .inr _ => by infer_instance + +@[reducible] +instance [∀ i, Sub (EI i)] [∀ j, Sub (FJ j)] (i : I ⊕ J) : Sub (Sum.rec EI FJ i) := + match i with + | .inl _ => by infer_instance + | .inr _ => by infer_instance + +@[reducible] +instance [∀ i, TopologicalSpace (EI i)] [∀ j, TopologicalSpace (FJ j)] (i : I ⊕ J) : TopologicalSpace (Sum.rec EI FJ i) := + match i with + | .inl _ => by infer_instance + | .inr _ => by infer_instance + +@[reducible] +instance [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)] (i : I ⊕ J) : Vec K (Sum.rec EI FJ i) := Vec.mkSorryProofs +-- all the proofs should be solvable `by induction i <;> infer_instance` + +@[reducible] +instance [∀ i, Inner K (EI i)] [∀ j, Inner K (FJ j)] (i : I ⊕ J) : Inner K (Sum.rec EI FJ i) := + match i with + | .inl _ => by infer_instance + | .inr _ => by infer_instance + +@[reducible] +instance [∀ i, TestFunctions (EI i)] [∀ j, TestFunctions (FJ j)] (i : I ⊕ J) : TestFunctions (Sum.rec EI FJ i) := + match i with + | .inl _ => by infer_instance + | .inr _ => by infer_instance + +@[reducible] +instance [∀ i, SemiInnerProductSpace K (EI i)] [∀ j, SemiInnerProductSpace K (FJ j)] (i : I ⊕ J) + : SemiInnerProductSpace K (Sum.rec EI FJ i) := SemiInnerProductSpace.mkSorryProofs + +@[reducible] +instance [∀ i, SemiHilbert K (EI i)] [∀ j, SemiHilbert K (FJ j)] (i : I ⊕ J) + : SemiHilbert K (Sum.rec EI FJ i) where + test_functions_true := by induction i <;> apply SemiHilbert.test_functions_true + +-- instance [∀ i, FinVec ι K (EI i)] [∀ j, FinVec ι K (FJ j)] (i : I ⊕ J) +-- : FinVec ι K (Sum.rec EI FJ i) := +-- match i with +-- | .inl _ => by infer_instance +-- | .inr _ => by infer_instance + +-------------------------------------------------------------------------------- +-- Algebraic struct classes ---------------------------------------------------- +-------------------------------------------------------------------------------- + +class ZeroStruct (X I XI) [StructType X I XI] [Zero X] [∀ i, Zero (XI i)] : Prop where + structProj_zero : ∀ (i : I), structProj (0 : X) i = 0 + +class AddStruct (X I XI) [StructType X I XI] [Add X] [∀ i, Add (XI i)] : Prop where structProj_add : ∀ (i : I) (x x' : X), structProj (x + x') i = structProj x i + structProj x' i + +class SMulStruct (K X I XI) [StructType X I XI] [SMul K X] [∀ i, SMul K (XI i)] : Prop where structProj_smul : ∀ (i : I) (k : K) (x : X), structProj (k • x) i = k • structProj x i - structProj_continuous : Continuous (fun (x : X) (i : I) => structProj x i) - structMake_continuous : Continuous (fun (f : (i : I) → XI i) => structMake (X:=X) f) + +class VecStruct (K X I XI) [StructType X I XI] [IsROrC K] [Vec K X] [∀ i, Vec K (XI i)] + extends ZeroStruct X I XI, AddStruct X I XI, SMulStruct K X I XI : Prop + where + structProj_continuous : Continuous (fun (x : X) (i : I) => structProj x i) + structMake_continuous : Continuous (fun (f : (i : I) → XI i) => structMake (X:=X) f) + +-------------------------------------------------------------------------------- +-- ZeroStruct instances --------------------------------------------------------- +-------------------------------------------------------------------------------- + +instance (priority:=low) instZeroStructDefault + {X} [Zero X] : ZeroStruct X Unit (fun _ => X) where + structProj_zero := by simp[structProj] + +instance instZeroStructProd + [Zero E] [Zero F] [∀ i, Zero (EI i)] [∀ j, Zero (FJ j)] + [ZeroStruct E I EI] [ZeroStruct F J FJ] + : ZeroStruct (E×F) (I⊕J) (Sum.rec EI FJ) where + structProj_zero := by simp[structProj, ZeroStruct.structProj_zero] + + +-------------------------------------------------------------------------------- +-- AddStruct instances --------------------------------------------------------- +-------------------------------------------------------------------------------- + +instance (priority:=low) instAddStructDefault + {X} [Add X] : AddStruct X Unit (fun _ => X) where + structProj_add := by simp[structProj] + +instance instAddStructProd + [Add E] [Add F] [∀ i, Add (EI i)] [∀ j, Add (FJ j)] + [AddStruct E I EI] [AddStruct F J FJ] + : AddStruct (E×F) (I⊕J) (Sum.rec EI FJ) where + structProj_add := by simp[structProj, AddStruct.structProj_add] + + +-------------------------------------------------------------------------------- +-- SMulStruct instances --------------------------------------------------------- +-------------------------------------------------------------------------------- + +instance (priority:=low) instSMulStructDefault + {X} [SMul K X] : SMulStruct K X Unit (fun _ => X) where + structProj_smul := by simp[structProj] + +instance instSMulStructProd + [SMul K E] [SMul K F] [∀ i, SMul K (EI i)] [∀ j, SMul K (FJ j)] + [SMulStruct K E I EI] [SMulStruct K F J FJ] + : SMulStruct K (E×F) (I⊕J) (Sum.rec EI FJ) where + structProj_smul := by simp[structProj, SMulStruct.structProj_smul] -------------------------------------------------------------------------------- @@ -34,23 +162,16 @@ class VecStruct (K X I XI) [StructType X I XI] [IsROrC K] [Vec K X] [∀ i, Vec instance (priority:=low) instVecStructDefault {X} [Vec K X] : VecStruct K X Unit (fun _ => X) where + structProj_zero := by simp[structProj] structProj_add := by simp[structProj] structProj_smul := by simp[structProj] structProj_continuous := sorry_proof structMake_continuous := sorry_proof -@[reducible] -instance [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)] (i : I ⊕ J) : Vec K (Prod.TypeFun EI FJ i) := - match i with - | .inl _ => by infer_instance - | .inr _ => by infer_instance - instance instVecStructProd - [Vec K E] [Vec K F] [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)] + [Vec K E] [Vec K F] [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)] [VecStruct K E I EI] [VecStruct K F J FJ] - : VecStruct K (E×F) (I⊕J) (Prod.TypeFun EI FJ) where - structProj_add := by simp[structProj, VecStruct.structProj_add] - structProj_smul := by simp[structProj, VecStruct.structProj_smul] + : VecStruct K (E×F) (I⊕J) (Sum.rec EI FJ) where structProj_continuous := sorry_proof structMake_continuous := sorry_proof @@ -187,29 +308,16 @@ instance (priority:=low) {X} [SemiInnerProductSpace K X] : SemiInnerProductSpace testFun_structProj := sorry_proof -instance [∀ i, SemiInnerProductSpace K (EI i)] [∀ j, SemiInnerProductSpace K (FJ j)] (i : I ⊕ J) - : SemiInnerProductSpace K (Prod.TypeFun EI FJ i) := - match i with - | .inl _ => by infer_instance - | .inr _ => by infer_instance - - instance [SemiInnerProductSpace K E] [SemiInnerProductSpace K F] [∀ i, SemiInnerProductSpace K (EI i)] [∀ j, SemiInnerProductSpace K (FJ j)] [EnumType I] [EnumType J] [SemiInnerProductSpaceStruct K E I EI] [SemiInnerProductSpaceStruct K F J FJ] - : SemiInnerProductSpaceStruct K (E×F) (I⊕J) (Prod.TypeFun EI FJ) := sorry_proof + : SemiInnerProductSpaceStruct K (E×F) (I⊕J) (Sum.rec EI FJ) := sorry_proof -- inner_structProj := sorry_proof -- testFun_structProj := sorry_proof -instance [∀ i, FinVec ι K (EI i)] [∀ j, FinVec ι K (FJ j)] (i : I ⊕ J) - : FinVec ι K (Prod.TypeFun EI FJ i) := - match i with - | .inl _ => by infer_instance - | .inr _ => by infer_instance - @[simp, ftrans_simp] theorem inner_oneHot_eq_inner_structProj [StructType X I XI] [EnumType I] [∀ i, SemiInnerProductSpace K (XI i)] [SemiInnerProductSpace K X] [SemiInnerProductSpaceStruct K X I XI] (i : I) (xi : XI i) (x : X) : ⟪x, oneHot i xi⟫[K] = ⟪structProj x i, xi⟫[K] := sorry_proof @@ -220,3 +328,57 @@ theorem inner_oneHot_eq_inner_proj' [StructType X I XI] [EnumType I] [∀ i, Sem +-------------------------------------------------------------------------------- +-- Prod simp lemmas +-- TODO: move somewhere else +-------------------------------------------------------------------------------- + +section OneHotSimp + +variable + [Zero E] [∀ i, Zero (EI i)] [ZeroStruct E I EI] + [Zero F] [∀ j, Zero (FJ j)] [ZeroStruct F J FJ] + [DecidableEq I] [DecidableEq J] + +@[simp, ftrans_simp] +theorem oneHot_inl_fst (i : I) (xi : EI i) + : (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi).1 + = + oneHot i xi := +by + simp[oneHot, structMake]; + congr; funext; congr + funext h; subst h; rfl + +@[simp, ftrans_simp] +theorem oneHot_inl_snd (i : I) (xi : EI i) + : (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi).2 + = + (0 : F) := +by + simp[oneHot, structMake] + apply structExt (I:=J) + simp[ZeroStruct.structProj_zero] + +@[simp, ftrans_simp] +theorem oneHot_inr_fst (j : J) (yj : FJ j) + : (oneHot (X:=E×F) (I:=I⊕J) (.inr j) yj).1 + = + (0 : E) := +by + simp[oneHot, structMake] + apply structExt (I:=I) + simp[ZeroStruct.structProj_zero] + +@[simp, ftrans_simp] +theorem oneHot_inr_snd (j : J) (yj : FJ j) + : (oneHot (X:=E×F) (I:=I⊕J) (.inr j) yj).2 + = + oneHot j yj := +by + simp[oneHot, structMake]; + congr; funext; congr + funext h; subst h; rfl + +end OneHotSimp + diff --git a/SciLean/Data/StructType/Basic.lean b/SciLean/Data/StructType/Basic.lean index fc4ee44b..ae119505 100644 --- a/SciLean/Data/StructType/Basic.lean +++ b/SciLean/Data/StructType/Basic.lean @@ -79,6 +79,11 @@ instance (priority:=low) instStructTypeDefault : StructType α Unit (fun _ => α structProj_structModify := by simp structProj_structModify' := by simp +@[simp, ftrans_simp] +theorem oneHot_unit {X} [Zero X] (x : X) + : oneHot (X:=X) (I:=Unit) () x = x := by rfl + + -- Pi -------------------------------------------------------------------------- -------------------------------------------------------------------------------- @@ -177,13 +182,9 @@ instance instStrucTypeArrow -- Prod ------------------------------------------------------------------------ -------------------------------------------------------------------------------- -abbrev _root_.Prod.TypeFun {I J: Type _} (EI : I → Type _) (FJ : J → Type _) (i : Sum I J) : Type _ := - match i with - | .inl a => EI a - | .inr b => FJ b - -instance instStrucTypeProd [StructType E I EI] [StructType F J FJ] - : StructType (E×F) (Sum I J) (Prod.TypeFun EI FJ) where +instance instStrucTypeProd + [StructType E I EI] [StructType F J FJ] + : StructType (E×F) (Sum I J) (Sum.rec EI FJ) where structProj := fun (x,y) i => match i with | .inl a => StructType.structProj x a @@ -199,7 +200,46 @@ instance instStrucTypeProd [StructType E I EI] [StructType F J FJ] structProj_structModify := by simp structProj_structModify' := by intro i j f x h; induction j <;> induction i <;> (simp at h; simp (disch:=assumption)) - - +-- @[simp, ftrans_simp] +-- theorem structMake_sum_match [StructType E I EI] [StructType F J FJ] (f : (i : I) → EI i) (g : (j : J) → FJ j) +-- : structMake (X:=E×F) (I:=I⊕J) (fun | .inl i => f i | .inr j => g j) +-- = +-- (structMake (X:=E) f, structMake (X:=F) g) := +-- by +-- simp[structMake] + +-- @[simp low, ftrans_simp low] +-- theorem structModify_inl [StructType E I EI] [StructType F J FJ] (i : I) (f : EI i → EI i) (xy : E×F) +-- : structModify (I:=I⊕J) (.inl i) f xy +-- = +-- {xy with fst := structModify i f xy.1} := +-- by +-- conv => +-- lhs +-- simp[structModify] + +-- @[simp, ftrans_simp] +-- theorem structModify_inl' [StructType E I EI] [StructType F J FJ] (i : I) (f : EI i → EI i) (x : E) (y : F) +-- : structModify (I:=I⊕J) (.inl i) f (x, y) +-- = +-- (structModify i f x, y) := +-- by +-- conv => +-- lhs +-- simp[structModify] - +-- @[simp low, ftrans_simp low] +-- theorem structModify_inr [StructType E I EI] [StructType F J FJ] (j : J) (f : FJ j → FJ j) (xy : E×F) +-- : structModify (I:=I⊕J) (.inr j) f xy +-- = +-- (xy.1, structModify j f xy.2) := +-- by +-- simp[structModify] + +-- @[simp, ftrans_simp] +-- theorem structModify_inr' [StructType E I EI] [StructType F J FJ] (j : J) (f : FJ j → FJ j) (x : E) (y : F) +-- : structModify (I:=I⊕J) (.inr j) f (x, y) +-- = +-- (x, structModify j f y) := +-- by +-- simp[structModify] diff --git a/SciLean/Tactic/FTrans/Simp.lean b/SciLean/Tactic/FTrans/Simp.lean index 60ed1800..4e982c7c 100644 --- a/SciLean/Tactic/FTrans/Simp.lean +++ b/SciLean/Tactic/FTrans/Simp.lean @@ -5,6 +5,15 @@ import Mathlib.Algebra.SMulWithZero namespace SciLean +-- basic algebraic operations attribute [ftrans_simp] Prod.mk_add_mk Prod.mk_mul_mk Prod.smul_mk Prod.mk_sub_mk Prod.neg_mk Prod.vadd_mk add_zero zero_add sub_zero zero_sub sub_self neg_zero mul_zero zero_mul zero_smul smul_zero smul_eq_mul smul_neg eq_self iff_self mul_one one_mul one_smul Prod.fst_zero Prod.snd_zero +-- simp theorems for `Equiv` attribute [ftrans_simp] Equiv.invFun_as_coe Equiv.symm_symm + +-- simp theorems for `if _ then _ else _` +attribute [ftrans_simp] dite_eq_ite eq_self ite_true ite_false dite_true dite_false + +-- simp theorems for `Sum` +attribute [ftrans_simp] Sum.inr.injEq Sum.inl.injEq +