Skip to content

Commit

Permalink
ArrayType is StructType
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 1, 2023
1 parent aea3ae2 commit fdb7a1c
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 30 deletions.
19 changes: 19 additions & 0 deletions SciLean/Data/ArrayType/Algebra.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import SciLean.Core.Objects.FinVec
import SciLean.Data.ArrayType.Basic
import SciLean.Data.StructType.Algebra

namespace SciLean
namespace GenericArrayType
Expand Down Expand Up @@ -143,3 +144,21 @@ instance (priority := low) [ArrayType Cont Idx K] : BasisDuality Cont where
-- to_dual := sorry_proof
-- from_dual := sorry_proof



instance [ArrayType Cont Idx Elem] [Zero Elem] : ZeroStruct Cont Idx (fun _ => Elem) where
structProj_zero := by intro i; simp[OfNat.ofNat,Zero.zero,ArrayType.introElem_structMake]

instance [ArrayType Cont Idx Elem] [Add Elem] : AddStruct Cont Idx (fun _ => Elem) where
structProj_add := by intro i; simp[HAdd.hAdd, Add.add,ArrayType.introElem_structMake, ← ArrayType.getElem_structProj]

instance {K} [ArrayType Cont Idx Elem] [SMul K Elem] : SMulStruct K Cont Idx (fun _ => Elem) where
structProj_smul := by intro i k x; simp[HSMul.hSMul, SMul.smul,ArrayType.introElem_structMake, ← ArrayType.getElem_structProj]

instance {K} [IsROrC K] [ArrayType Cont Idx Elem] [Vec K Elem] : VecStruct K Cont Idx (fun _ => Elem) where
structProj_continuous := sorry_proof
structMake_continuous := sorry_proof

instance {K} [IsROrC K] [ArrayType Cont Idx Elem] [SemiInnerProductSpace K Elem] : SemiInnerProductSpaceStruct K Cont Idx (fun _ => Elem) where
inner_structProj := sorry_proof
testFun_structProj := sorry_proof
56 changes: 44 additions & 12 deletions SciLean/Data/ArrayType/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import SciLean.Util.SorryProof
import SciLean.Data.Index
import SciLean.Data.ListN
import SciLean.Data.StructType.Basic

namespace SciLean

Expand Down Expand Up @@ -39,6 +40,7 @@ class ReserveElem (Cont : USize → Type u) (Elem : outParam (Type w)) where
export ReserveElem (reserveElem)
attribute [irreducible] reserveElem


/-- This class says that `Cont` behaves like an array with `Elem` values indexed by `Idx`
Examples for `Idx = Fin n` and `Elem = ℝ` are: `ArrayN ℝ n` or `ℝ^{n}`
Expand All @@ -57,15 +59,13 @@ Alternative notation:
class ArrayType (Cont : Type u) (Idx : Type v |> outParam) (Elem : Type w |> outParam)
extends GetElem Cont Idx Elem (λ _ _ => True),
SetElem Cont Idx Elem,
IntroElem Cont Idx Elem
IntroElem Cont Idx Elem,
StructType Cont Idx (fun _ => Elem)
where
ext : ∀ f g : Cont, (∀ x : Idx, f[x] = g[x]) → f = g
getElem_setElem_eq : ∀ (x : Idx) (y : Elem) (f : Cont), (setElem f x y)[x] = y
getElem_setElem_neq : ∀ (i j : Idx) (val : Elem) (arr : Cont), i ≠ j → (setElem arr i val)[j] = arr[j]
getElem_introElem : ∀ f i, (introElem f)[i] = f i
getElem_structProj : ∀ (x : Cont) (i : Idx), x[i] = structProj x i
setElem_structModify : ∀ (x : Cont) (i : Idx) (xi : Elem), setElem x i xi = structModify i (fun _ => xi) x
introElem_structMake : ∀ (f : Idx → Elem), introElem f = structMake f

attribute [ext] ArrayType.ext
attribute [simp] ArrayType.getElem_setElem_eq ArrayType.getElem_introElem
attribute [default_instance] ArrayType.toGetElem ArrayType.toSetElem ArrayType.toIntroElem

class LinearArrayType (Cont : USize → Type u) (Elem : Type w |> outParam)
Expand All @@ -90,7 +90,37 @@ instance {T} {Y : outParam Type} [inst : LinearArrayType T Y] (n) : ArrayType (T

namespace ArrayType

variable {Cont : Type} {Idx : Type |> outParam} {Elem : Type |> outParam}
variable
{Cont : Type} {Idx : Type |> outParam} {Elem : Type |> outParam}
[ArrayType Cont Idx Elem]


@[ext]
theorem ext (x y : Cont) : (∀ i, x[i] = y[i]) → x = y :=
by
intros h
apply structExt (I:=Idx)
simp[getElem_structProj] at h
exact h

@[simp]
theorem getElem_setElem_eq (i : Idx) (xi : Elem) (x : Cont)
: (setElem x i xi)[i] = xi :=
by
simp[setElem_structModify, getElem_structProj]

@[simp]
theorem getElem_setElem_neq (i j : Idx) (xi : Elem) (x : Cont)
: (i≠j) → (setElem x i xi)[j] = x[j] :=
by
intro h
simp (discharger:=assumption) [setElem_structModify, getElem_structProj]

@[simp]
theorem getElem_introElem (f : Idx → Elem) (i : Idx)
: (introElem (Cont:=Cont) f)[i] = f i :=
by
simp[introElem_structMake, getElem_structProj]

@[simp]
theorem introElem_getElem [ArrayType Cont Idx Elem] (cont : Cont)
Expand All @@ -100,15 +130,17 @@ theorem introElem_getElem [ArrayType Cont Idx Elem] (cont : Cont)
-- Maybe turn this into a class and this is a default implementation
def modifyElem [GetElem Cont Idx Elem λ _ _ => True] [SetElem Cont Idx Elem]
(arr : Cont) (i : Idx) (f : Elem → Elem) : Cont :=
setElem arr i (f (arr[i]))
structModify i f arr

set_option trace.Meta.Tactic.simp.discharge true in
set_option trace.Meta.Tactic.simp.unify true in
@[simp]
theorem getElem_modifyElem_eq [ArrayType Cont Idx Elem] (cont : Cont) (idx : Idx) (f : Elem → Elem)
: (modifyElem cont idx f)[idx] = f cont[idx] := by simp[modifyElem]; done
: (modifyElem cont idx f)[idx] = f cont[idx] := by simp[getElem_structProj,modifyElem]; done

@[simp]
theorem getElem_modifyElem_neq [inst : ArrayType Cont Idx Elem] (arr : Cont) (i j : Idx) (f : Elem → Elem)
: i ≠ j → (modifyElem arr i f)[j] = arr[j] := by simp[modifyElem]; apply ArrayType.getElem_setElem_neq; done
: i ≠ j → (modifyElem arr i f)[j] = arr[j] := by intro h; simp [h,modifyElem, getElem_structProj,modifyElem]; done


-- Maybe turn this into a class and this is a default implementation
Expand Down Expand Up @@ -233,7 +265,7 @@ section Operations
end Operations

@[simp]
theorem sum_introElem [EnumType Idx] [ArrayType Cont Idx Elem] [AddCommMonoid Elem] {ι} [EnumType ι] (f : ι → Idx → Elem)
theorem sum_introElem [EnumType Idx] [ArrayType Cont Idx Elem] [AddCommMonoid Elem] {ι} [EnumType ι] (f : ι → Idx → Elem)
: ∑ j, introElem (Cont:=Cont) (fun i => f j i)
=
introElem fun i => ∑ j, f j i
Expand Down
16 changes: 12 additions & 4 deletions SciLean/Data/DataArray/DataArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,19 @@ instance : SetElem (DataArrayN α ι) ι α where
instance : IntroElem (DataArrayN α ι) ι α where
introElem f := ⟨DataArray.intro f, sorry_proof⟩

instance : StructType (DataArrayN α ι) ι (fun _ => α) where
structProj x i := x[i]
structMake f := introElem f
structModify i f x := setElem x i (f x[i])
left_inv := sorry_proof
right_inv := sorry_proof
structProj_structModify := sorry_proof
structProj_structModify' := sorry_proof

instance : ArrayType (DataArrayN α ι) ι α where
ext := sorry_proof
getElem_setElem_eq := sorry_proof
getElem_setElem_neq := sorry_proof
getElem_introElem := sorry_proof
getElem_structProj := by intros; rfl
setElem_structModify := by intros; rfl
introElem_structMake := by intros; rfl

instance : ArrayTypeNotation (DataArrayN α ι) ι α := ⟨⟩

Expand Down
48 changes: 36 additions & 12 deletions SciLean/Data/DataArray/VecN.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,19 @@ namespace Vec2
instance : IntroElem (Vec2 α) (Idx 2) α where
introElem := intro

instance : StructType (Vec2 α) (Idx 2) (fun _ => α) where
structProj x i := x[i]
structMake f := introElem f
structModify i f x := setElem x i (f x[i])
left_inv := sorry_proof
right_inv := sorry_proof
structProj_structModify := sorry_proof
structProj_structModify' := sorry_proof

instance : ArrayType (Vec2 α) (Idx 2) α where
ext := sorry_proof
getElem_setElem_eq := sorry_proof
getElem_setElem_neq := sorry_proof
getElem_introElem := sorry_proof
getElem_structProj := by intros; rfl
setElem_structModify := by intros; rfl
introElem_structMake := by intros; rfl

instance [ba : PlainDataType α] : PlainDataType (Vec2 α) where
btype :=
Expand Down Expand Up @@ -141,11 +149,19 @@ namespace Vec3
instance : IntroElem (Vec3 α) (Idx 3) α where
introElem := intro

instance : StructType (Vec3 α) (Idx 3) (fun _ => α) where
structProj x i := x[i]
structMake f := introElem f
structModify i f x := setElem x i (f x[i])
left_inv := sorry_proof
right_inv := sorry_proof
structProj_structModify := sorry_proof
structProj_structModify' := sorry_proof

instance : ArrayType (Vec3 α) (Idx 3) α where
ext := sorry_proof
getElem_setElem_eq := sorry_proof
getElem_setElem_neq := sorry_proof
getElem_introElem := sorry_proof
getElem_structProj := by intros; rfl
setElem_structModify := by intros; rfl
introElem_structMake := by intros; rfl

instance [ba : PlainDataType α] : PlainDataType (Vec3 α) where
btype :=
Expand Down Expand Up @@ -250,11 +266,19 @@ namespace Vec4
instance : IntroElem (Vec4 α) (Idx 4) α where
introElem := intro

instance : StructType (Vec4 α) (Idx 4) (fun _ => α) where
structProj x i := x[i]
structMake f := introElem f
structModify i f x := setElem x i (f x[i])
left_inv := sorry_proof
right_inv := sorry_proof
structProj_structModify := sorry_proof
structProj_structModify' := sorry_proof

instance : ArrayType (Vec4 α) (Idx 4) α where
ext := sorry_proof
getElem_setElem_eq := sorry_proof
getElem_setElem_neq := sorry_proof
getElem_introElem := sorry_proof
getElem_structProj := by intros; rfl
setElem_structModify := by intros; rfl
introElem_structMake := by intros; rfl

instance [ba : PlainDataType α] : PlainDataType (Vec4 α) where
btype :=
Expand Down
4 changes: 2 additions & 2 deletions SciLean/Tactic/FTrans/Init.lean
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ To register function transformation call:
where <name> is name of the function transformation and <info> is corresponding `FTrans.Info`.
"

if (← getBoolOption `linter.ftransSsaRhs true) then
if (← getBoolOption `linter.ftransSsaRhs) then
let rhs' ← rhs.toSSA #[]
if ¬(rhs.eqv rhs') then
logWarning s!"right hand side is not in single static assigment form, expected form:\n{←ppExpr rhs'}"
Expand All @@ -296,7 +296,7 @@ where <name> is name of the function transformation and <info> is corresponding
|>.append data.declSuffix
|>.append (transName.getString.append "_rule")

if (← getBoolOption `linter.ftransDeclName true) &&
if (← getBoolOption `linter.ftransDeclName) &&
¬(suggestedRuleName.toString.isPrefixOf ruleName.toString) then
logWarning s!"suggested name for this rule is {suggestedRuleName}"

Expand Down

0 comments on commit fdb7a1c

Please sign in to comment.