Skip to content

Commit

Permalink
fixed type class diamonds when infering Vec K (Sum.rec EI FJ i)
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 29, 2023
1 parent 0e8edff commit 3843717
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 54 deletions.
32 changes: 16 additions & 16 deletions SciLean/Data/ArrayType/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def map [ArrayType Cont Idx Elem] [EnumType Idx] (f : Elem → Elem) (arr : Cont
theorem getElem_map [ArrayType Cont Idx Elem] [EnumType Idx] (f : Elem → Elem) (arr : Cont) (i : Idx)
: (map f arr)[i] = f arr[i] := sorry_proof

instance [ArrayType Cont Idx Elem] [ToString Elem] [EnumType Idx] : ToString (Cont) := ⟨λ x => Id.run do
instance (priority:=low) [ArrayType Cont Idx Elem] [ToString Elem] [EnumType Idx] : ToString (Cont) := ⟨λ x => Id.run do
let mut fst := true
let mut s := "⊞["
for i in fullRange Idx do
Expand Down Expand Up @@ -164,24 +164,24 @@ section Operations

variable [ArrayType Cont Idx Elem] [EnumType Idx]

instance [Add Elem] : Add Cont := ⟨λ f g => mapIdx (λ x fx => fx + g[x]) f⟩
instance [Sub Elem] : Sub Cont := ⟨λ f g => mapIdx (λ x fx => fx - g[x]) f⟩
instance [Mul Elem] : Mul Cont := ⟨λ f g => mapIdx (λ x fx => fx * g[x]) f⟩
instance [Div Elem] : Div Cont := ⟨λ f g => mapIdx (λ x fx => fx / g[x]) f⟩
instance (priority:=low) [Add Elem] : Add Cont := ⟨λ f g => mapIdx (λ x fx => fx + g[x]) f⟩
instance (priority:=low) [Sub Elem] : Sub Cont := ⟨λ f g => mapIdx (λ x fx => fx - g[x]) f⟩
instance (priority:=low) [Mul Elem] : Mul Cont := ⟨λ f g => mapIdx (λ x fx => fx * g[x]) f⟩
instance (priority:=low) [Div Elem] : Div Cont := ⟨λ f g => mapIdx (λ x fx => fx / g[x]) f⟩

-- instance {R} [HMul R Elem Elem] : HMul R Cont Cont := ⟨λ r f => map (λ fx => r*(fx : Elem)) f⟩
instance {R} [SMul R Elem] : SMul R Cont := ⟨λ r f => map (λ fx => r•(fx : Elem)) f⟩
-- instance (priority:=low) {R} [HMul R Elem Elem] : HMul R Cont Cont := ⟨λ r f => map (λ fx => r*(fx : Elem)) f⟩
instance (priority:=low) {R} [SMul R Elem] : SMul R Cont := ⟨λ r f => map (λ fx => r•(fx : Elem)) f⟩

instance [Neg Elem] : Neg Cont := ⟨λ f => map (λ fx => -(fx : Elem)) f⟩
instance [Inv Elem] : Inv Cont := ⟨λ f => map (λ fx => (fx : Elem)⁻¹) f⟩
instance (priority:=low) [Neg Elem] : Neg Cont := ⟨λ f => map (λ fx => -(fx : Elem)) f⟩
instance (priority:=low) [Inv Elem] : Inv Cont := ⟨λ f => map (λ fx => (fx : Elem)⁻¹) f⟩

instance [One Elem] : One Cont := ⟨introElem λ _ : Idx => 1
instance [Zero Elem] : Zero Cont := ⟨introElem λ _ : Idx => 0
instance (priority:=low) [One Elem] : One Cont := ⟨introElem λ _ : Idx => 1
instance (priority:=low) [Zero Elem] : Zero Cont := ⟨introElem λ _ : Idx => 0

instance [LT Elem] : LT Cont := ⟨λ f g => ∀ x, f[x] < g[x]⟩
instance [LE Elem] : LE Cont := ⟨λ f g => ∀ x, f[x] ≤ g[x]⟩
instance (priority:=low) [LT Elem] : LT Cont := ⟨λ f g => ∀ x, f[x] < g[x]⟩
instance (priority:=low) [LE Elem] : LE Cont := ⟨λ f g => ∀ x, f[x] ≤ g[x]⟩

instance [DecidableEq Elem] : DecidableEq Cont :=
instance (priority:=low) [DecidableEq Elem] : DecidableEq Cont :=
λ f g => Id.run do
let mut eq : Bool := true
for x in fullRange Idx do
Expand All @@ -190,15 +190,15 @@ section Operations
break
if eq then isTrue sorry_proof else isFalse sorry_proof

instance [LT Elem] [∀ x y : Elem, Decidable (x < y)] (f g : Cont) : Decidable (f < g) := Id.run do
instance (priority:=low) [LT Elem] [∀ x y : Elem, Decidable (x < y)] (f g : Cont) : Decidable (f < g) := Id.run do
let mut lt : Bool := true
for x in fullRange Idx do
if ¬(f[x] < g[x]) then
lt := false
break
if lt then isTrue sorry_proof else isFalse sorry_proof

instance [LE Elem] [∀ x y : Elem, Decidable (x ≤ y)] (f g : Cont) : Decidable (f ≤ g) := Id.run do
instance (priority:=low) [LE Elem] [∀ x y : Elem, Decidable (x ≤ y)] (f g : Cont) : Decidable (f ≤ g) := Id.run do
let mut le : Bool := true
for x in fullRange Idx do
if ¬(f[x] ≤ g[x]) then
Expand Down
218 changes: 190 additions & 28 deletions SciLean/Data/StructType/Algebra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,140 @@ variable
[StructType F J FJ]


open StructType in
class VecStruct (K X I XI) [StructType X I XI] [IsROrC K] [Vec K X] [∀ i, Vec K (XI i)] : Prop where
--------------------------------------------------------------------------------
-- Algebra instances for Sum.rec ------------------------------------------
--------------------------------------------------------------------------------
-- There are some issues with defEq

@[reducible]
instance [∀ i, Zero (EI i)] [∀ j, Zero (FJ j)] (i : I ⊕ J) : Zero (Sum.rec EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[reducible]
instance [∀ i, Add (EI i)] [∀ j, Add (FJ j)] (i : I ⊕ J) : Add (Sum.rec EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[reducible]
instance [∀ i, SMul K (EI i)] [∀ j, SMul K (FJ j)] (i : I ⊕ J) : SMul K (Sum.rec EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[reducible]
instance [∀ i, Neg (EI i)] [∀ j, Neg (FJ j)] (i : I ⊕ J) : Neg (Sum.rec EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[reducible]
instance [∀ i, Sub (EI i)] [∀ j, Sub (FJ j)] (i : I ⊕ J) : Sub (Sum.rec EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[reducible]
instance [∀ i, TopologicalSpace (EI i)] [∀ j, TopologicalSpace (FJ j)] (i : I ⊕ J) : TopologicalSpace (Sum.rec EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[reducible]
instance [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)] (i : I ⊕ J) : Vec K (Sum.rec EI FJ i) := Vec.mkSorryProofs
-- all the proofs should be solvable `by induction i <;> infer_instance`

@[reducible]
instance [∀ i, Inner K (EI i)] [∀ j, Inner K (FJ j)] (i : I ⊕ J) : Inner K (Sum.rec EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[reducible]
instance [∀ i, TestFunctions (EI i)] [∀ j, TestFunctions (FJ j)] (i : I ⊕ J) : TestFunctions (Sum.rec EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[reducible]
instance [∀ i, SemiInnerProductSpace K (EI i)] [∀ j, SemiInnerProductSpace K (FJ j)] (i : I ⊕ J)
: SemiInnerProductSpace K (Sum.rec EI FJ i) := SemiInnerProductSpace.mkSorryProofs

@[reducible]
instance [∀ i, SemiHilbert K (EI i)] [∀ j, SemiHilbert K (FJ j)] (i : I ⊕ J)
: SemiHilbert K (Sum.rec EI FJ i) where
test_functions_true := by induction i <;> apply SemiHilbert.test_functions_true

-- instance [∀ i, FinVec ι K (EI i)] [∀ j, FinVec ι K (FJ j)] (i : I ⊕ J)
-- : FinVec ι K (Sum.rec EI FJ i) :=
-- match i with
-- | .inl _ => by infer_instance
-- | .inr _ => by infer_instance

--------------------------------------------------------------------------------
-- Algebraic struct classes ----------------------------------------------------
--------------------------------------------------------------------------------

class ZeroStruct (X I XI) [StructType X I XI] [Zero X] [∀ i, Zero (XI i)] : Prop where
structProj_zero : ∀ (i : I), structProj (0 : X) i = 0

class AddStruct (X I XI) [StructType X I XI] [Add X] [∀ i, Add (XI i)] : Prop where
structProj_add : ∀ (i : I) (x x' : X), structProj (x + x') i = structProj x i + structProj x' i

class SMulStruct (K X I XI) [StructType X I XI] [SMul K X] [∀ i, SMul K (XI i)] : Prop where
structProj_smul : ∀ (i : I) (k : K) (x : X), structProj (k • x) i = k • structProj x i
structProj_continuous : Continuous (fun (x : X) (i : I) => structProj x i)
structMake_continuous : Continuous (fun (f : (i : I) → XI i) => structMake (X:=X) f)

class VecStruct (K X I XI) [StructType X I XI] [IsROrC K] [Vec K X] [∀ i, Vec K (XI i)]
extends ZeroStruct X I XI, AddStruct X I XI, SMulStruct K X I XI : Prop
where
structProj_continuous : Continuous (fun (x : X) (i : I) => structProj x i)
structMake_continuous : Continuous (fun (f : (i : I) → XI i) => structMake (X:=X) f)

--------------------------------------------------------------------------------
-- ZeroStruct instances ---------------------------------------------------------
--------------------------------------------------------------------------------

instance (priority:=low) instZeroStructDefault
{X} [Zero X] : ZeroStruct X Unit (fun _ => X) where
structProj_zero := by simp[structProj]

instance instZeroStructProd
[Zero E] [Zero F] [∀ i, Zero (EI i)] [∀ j, Zero (FJ j)]
[ZeroStruct E I EI] [ZeroStruct F J FJ]
: ZeroStruct (E×F) (I⊕J) (Sum.rec EI FJ) where
structProj_zero := by simp[structProj, ZeroStruct.structProj_zero]


--------------------------------------------------------------------------------
-- AddStruct instances ---------------------------------------------------------
--------------------------------------------------------------------------------

instance (priority:=low) instAddStructDefault
{X} [Add X] : AddStruct X Unit (fun _ => X) where
structProj_add := by simp[structProj]

instance instAddStructProd
[Add E] [Add F] [∀ i, Add (EI i)] [∀ j, Add (FJ j)]
[AddStruct E I EI] [AddStruct F J FJ]
: AddStruct (E×F) (I⊕J) (Sum.rec EI FJ) where
structProj_add := by simp[structProj, AddStruct.structProj_add]


--------------------------------------------------------------------------------
-- SMulStruct instances ---------------------------------------------------------
--------------------------------------------------------------------------------

instance (priority:=low) instSMulStructDefault
{X} [SMul K X] : SMulStruct K X Unit (fun _ => X) where
structProj_smul := by simp[structProj]

instance instSMulStructProd
[SMul K E] [SMul K F] [∀ i, SMul K (EI i)] [∀ j, SMul K (FJ j)]
[SMulStruct K E I EI] [SMulStruct K F J FJ]
: SMulStruct K (E×F) (I⊕J) (Sum.rec EI FJ) where
structProj_smul := by simp[structProj, SMulStruct.structProj_smul]


--------------------------------------------------------------------------------
Expand All @@ -34,23 +162,16 @@ class VecStruct (K X I XI) [StructType X I XI] [IsROrC K] [Vec K X] [∀ i, Vec

instance (priority:=low) instVecStructDefault
{X} [Vec K X] : VecStruct K X Unit (fun _ => X) where
structProj_zero := by simp[structProj]
structProj_add := by simp[structProj]
structProj_smul := by simp[structProj]
structProj_continuous := sorry_proof
structMake_continuous := sorry_proof

@[reducible]
instance [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)] (i : I ⊕ J) : Vec K (Prod.TypeFun EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

instance instVecStructProd
[Vec K E] [Vec K F] [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)]
[Vec K E] [Vec K F] [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)]
[VecStruct K E I EI] [VecStruct K F J FJ]
: VecStruct K (E×F) (I⊕J) (Prod.TypeFun EI FJ) where
structProj_add := by simp[structProj, VecStruct.structProj_add]
structProj_smul := by simp[structProj, VecStruct.structProj_smul]
: VecStruct K (E×F) (I⊕J) (Sum.rec EI FJ) where
structProj_continuous := sorry_proof
structMake_continuous := sorry_proof

Expand Down Expand Up @@ -187,29 +308,16 @@ instance (priority:=low) {X} [SemiInnerProductSpace K X] : SemiInnerProductSpace
testFun_structProj := sorry_proof


instance [∀ i, SemiInnerProductSpace K (EI i)] [∀ j, SemiInnerProductSpace K (FJ j)] (i : I ⊕ J)
: SemiInnerProductSpace K (Prod.TypeFun EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance


instance
[SemiInnerProductSpace K E] [SemiInnerProductSpace K F]
[∀ i, SemiInnerProductSpace K (EI i)] [∀ j, SemiInnerProductSpace K (FJ j)]
[EnumType I] [EnumType J]
[SemiInnerProductSpaceStruct K E I EI] [SemiInnerProductSpaceStruct K F J FJ]
: SemiInnerProductSpaceStruct K (E×F) (I⊕J) (Prod.TypeFun EI FJ) := sorry_proof
: SemiInnerProductSpaceStruct K (E×F) (I⊕J) (Sum.rec EI FJ) := sorry_proof
-- inner_structProj := sorry_proof
-- testFun_structProj := sorry_proof


instance [∀ i, FinVec ι K (EI i)] [∀ j, FinVec ι K (FJ j)] (i : I ⊕ J)
: FinVec ι K (Prod.TypeFun EI FJ i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

@[simp, ftrans_simp]
theorem inner_oneHot_eq_inner_structProj [StructType X I XI] [EnumType I] [∀ i, SemiInnerProductSpace K (XI i)] [SemiInnerProductSpace K X] [SemiInnerProductSpaceStruct K X I XI] (i : I) (xi : XI i) (x : X)
: ⟪x, oneHot i xi⟫[K] = ⟪structProj x i, xi⟫[K] := sorry_proof
Expand All @@ -220,3 +328,57 @@ theorem inner_oneHot_eq_inner_proj' [StructType X I XI] [EnumType I] [∀ i, Sem



--------------------------------------------------------------------------------
-- Prod simp lemmas
-- TODO: move somewhere else
--------------------------------------------------------------------------------

section OneHotSimp

variable
[Zero E] [∀ i, Zero (EI i)] [ZeroStruct E I EI]
[Zero F] [∀ j, Zero (FJ j)] [ZeroStruct F J FJ]
[DecidableEq I] [DecidableEq J]

@[simp, ftrans_simp]
theorem oneHot_inl_fst (i : I) (xi : EI i)
: (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi).1
=
oneHot i xi :=
by
simp[oneHot, structMake];
congr; funext; congr
funext h; subst h; rfl

@[simp, ftrans_simp]
theorem oneHot_inl_snd (i : I) (xi : EI i)
: (oneHot (X:=E×F) (I:=I⊕J) (.inl i) xi).2
=
(0 : F) :=
by
simp[oneHot, structMake]
apply structExt (I:=J)
simp[ZeroStruct.structProj_zero]

@[simp, ftrans_simp]
theorem oneHot_inr_fst (j : J) (yj : FJ j)
: (oneHot (X:=E×F) (I:=I⊕J) (.inr j) yj).1
=
(0 : E) :=
by
simp[oneHot, structMake]
apply structExt (I:=I)
simp[ZeroStruct.structProj_zero]

@[simp, ftrans_simp]
theorem oneHot_inr_snd (j : J) (yj : FJ j)
: (oneHot (X:=E×F) (I:=I⊕J) (.inr j) yj).2
=
oneHot j yj :=
by
simp[oneHot, structMake];
congr; funext; congr
funext h; subst h; rfl

end OneHotSimp

Loading

0 comments on commit 3843717

Please sign in to comment.