diff --git a/SciLean/Core/FloatAsReal.lean b/SciLean/Core/FloatAsReal.lean index 2b86981b..fa020ff3 100644 --- a/SciLean/Core/FloatAsReal.lean +++ b/SciLean/Core/FloatAsReal.lean @@ -25,7 +25,7 @@ instance : CommRing Float where nsmul_zero := sorry_proof nsmul_succ n x := sorry_proof sub_eq_add_neg a b := sorry_proof - natCast n := n.toUSize.toFloat + natCast n := n.toUInt64.toFloat natCast_zero := sorry_proof natCast_succ := sorry_proof intCast n := if n ≥ 0 then n.toNat.toUInt64.toFloat else -((-n).toNat.toUInt64).toFloat diff --git a/SciLean/Core/Meta/ExtendContext.lean b/SciLean/Core/Meta/ExtendContext.lean index 57921891..90f9d374 100644 --- a/SciLean/Core/Meta/ExtendContext.lean +++ b/SciLean/Core/Meta/ExtendContext.lean @@ -1,6 +1,6 @@ import SciLean.Core.Objects.FinVec import SciLean.Lean.Meta.Basic -import SciLean.Data.Index +import SciLean.Data.IndexType open Lean Meta Qq diff --git a/SciLean/Core/Objects/SemiInnerProductSpace.lean b/SciLean/Core/Objects/SemiInnerProductSpace.lean index 0511eb0d..82aca1e9 100644 --- a/SciLean/Core/Objects/SemiInnerProductSpace.lean +++ b/SciLean/Core/Objects/SemiInnerProductSpace.lean @@ -3,8 +3,6 @@ import SciLean.Core.Objects.Vec import SciLean.Core.Objects.Scalar import SciLean.Core.NotationOverField -import SciLean.Data.EnumType - namespace SciLean open LeanColls diff --git a/SciLean/Core/Rand/Distributions/WalkOnSpheres.lean b/SciLean/Core/Rand/Distributions/WalkOnSpheres.lean index 27be0662..3d73b020 100644 --- a/SciLean/Core/Rand/Distributions/WalkOnSpheres.lean +++ b/SciLean/Core/Rand/Distributions/WalkOnSpheres.lean @@ -69,8 +69,7 @@ def harmonicRec_fwdDeriv (n : ℕ) induction n n' du h . simp[harmonicRec]; autodiff . simp[harmonicRec]; - simp only [smul_push] - autodiff; autodiff + simp only [smul_push]; autodiff def harmonicRec.arg_x.fwdDeriv_randApprox (n : ℕ) diff --git a/SciLean/Core/Simp/Sum.lean b/SciLean/Core/Simp/Sum.lean deleted file mode 100644 index aa3663d7..00000000 --- a/SciLean/Core/Simp/Sum.lean +++ /dev/null @@ -1,56 +0,0 @@ -import SciLean.Data.EnumType - -namespace SciLean - -variable {ι κ} [EnumType ι] [EnumType κ] - --- @[simp] --- theorem sum_if {β : Type _} [AddCommMonoid β] (f : ι → β) (j : ι) --- : (∑ i, if i = j then f i else 0) --- = --- f j --- := sorry_proof - --- @[simp] --- theorem sum_if' {β : Type _} [AddCommMonoid β] (f : ι → β) (j : ι) --- : (∑ i, if j = i then f i else 0) --- = --- f j --- := sorry_proof - --- @[simp] --- theorem sum_lambda_swap {α β : Type _} [AddCommMonoid β] (f : ι → α → β) --- : ∑ i, (fun a => f i a) --- = --- fun a => ∑ i, f i a --- := sorry_proof - - --- @[simp] --- theorem sum2_if {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) --- : (∑ i, ∑ j, if ij = (i,j) then f i j else 0) --- = --- f ij.1 ij.2 --- := sorry_proof - --- @[simp] --- theorem sum2'_if {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) --- : (∑ j, ∑ i, if ij = (i,j) then f i j else 0) --- = --- f ij.1 ij.2 --- := sorry_proof - - --- @[simp] --- theorem sum2_if' {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) --- : (∑ i, ∑ j, if (i,j) = ij then f i j else 0) --- = --- f ij.1 ij.2 --- := sorry_proof - --- @[simp] --- theorem sum2'_if' {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) --- : (∑ j, ∑ i, if (i,j) = ij then f i j else 0) --- = --- f ij.1 ij.2 --- := sorry_proof diff --git a/SciLean/Data/ArrayType/Basic.lean b/SciLean/Data/ArrayType/Basic.lean index 3026a51d..d9b2eaa1 100644 --- a/SciLean/Data/ArrayType/Basic.lean +++ b/SciLean/Data/ArrayType/Basic.lean @@ -1,5 +1,4 @@ import SciLean.Util.SorryProof -import SciLean.Data.Index import SciLean.Data.ListN import SciLean.Data.StructType.Basic import SciLean.Data.Function diff --git a/SciLean/Data/ArrayType/Notation.lean b/SciLean/Data/ArrayType/Notation.lean index 7f04327f..e4050bef 100644 --- a/SciLean/Data/ArrayType/Notation.lean +++ b/SciLean/Data/ArrayType/Notation.lean @@ -27,8 +27,11 @@ abbrev introElemNotation {Cont Idx Elem} [DecidableEq Idx] [ArrayType Cont Idx E open Lean.TSyntax.Compat in -- macro "⊞ " x:term " => " b:term:51 : term => `(introElemNotation fun $x => $b) -- macro "⊞ " x:term " : " X:term " => " b:term:51 : term => `(introElemNotation fun ($x : $X) => $b) + +-- The `by exact` is a hack to make certain case work +-- see: https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/uncurry.20fails.20with.20.60Icc.60 open Term Function in -macro "⊞ " xs:funBinder* " => " b:term:51 : term => `(introElemNotation (HasUncurry.uncurry fun $xs* => $b)) +macro "⊞ " xs:funBinder* " => " b:term:51 : term => `(introElemNotation (HasUncurry.uncurry (by exact (fun $xs* => $b)))) @[app_unexpander introElemNotation] @@ -153,11 +156,11 @@ partial def expand' (l : List (TSyntax `dimSpec)) : TermElabM Expr := match t with | `(dimSpec| $n:term) => do try - let n ← elabTerm n q(USize) + let n ← elabTerm n q(Nat) return ← mkAppM ``Fin #[n] catch _ => return ← elabTerm n none - | `(dimSpec| [$n:term : $m:term]) => do elabTerm (← `(Idx' $n $m)) none + | `(dimSpec| [$n:term : $m:term]) => do elabTerm (← `(↑(Set.Icc ($n : Int) ($m : Int)))) q(Type) | `(dimSpec| [$ds:dimSpec,*]) => expand' ds.getElems.toList | _ => throwError "unexpected type power syntax" | t :: l' => do @@ -184,7 +187,7 @@ elab_rules (kind:=typeIntPower) : term let Y ← expand' ns.getElems.toList let C ← mkFreshTypeMVar - let inst ← synthInstance <| mkAppN (← mkConst ``ArrayTypeNotation) #[C,Y,X] + let inst ← synthInstance <| mkAppN (← mkConstWithFreshMVarLevels ``ArrayTypeNotation) #[C,Y,X] let C ← whnfR (← instantiateMVars C) return ← instantiateMVars <| ← mkAppOptM ``arrayTypeCont #[Y,X,C,inst] diff --git a/SciLean/Data/DataArray/DataArray.lean b/SciLean/Data/DataArray/DataArray.lean index 392919e6..29493c31 100644 --- a/SciLean/Data/DataArray/DataArray.lean +++ b/SciLean/Data/DataArray/DataArray.lean @@ -65,7 +65,7 @@ def DataArray.reserve (arr : DataArray α) (capacity : Nat) : DataArray α := let newBytes := pd.bytes capacity let mut arr' : DataArray α := ⟨ByteArray.mkArray newBytes 0, arr.size, sorry_proof⟩ -- copy over the old data - for i in fullRange (Fin arr.size) do + for i in IndexType.univ (Fin arr.size) do arr' := arr'.set ⟨i.1,sorry_proof⟩ (arr.get i) arr' diff --git a/SciLean/Data/DataArray/PlainDataType.lean b/SciLean/Data/DataArray/PlainDataType.lean index 03e6f6ef..c01f9b36 100644 --- a/SciLean/Data/DataArray/PlainDataType.lean +++ b/SciLean/Data/DataArray/PlainDataType.lean @@ -1,5 +1,6 @@ import SciLean.Util.SorryProof -import SciLean.Data.Index +-- import SciLean.Data.Index +import LeanColls.Classes.IndexType namespace SciLean open LeanColls @@ -302,101 +303,102 @@ instance : PlainDataType Bool where btype := .inl Bool.bitType ---------------- Idx n ------------------------------------------------ +--------------- Fin n ------------------------------------------------ ---------------------------------------------------------------------- -/-- Number of bits necessary to store `Idx n` -/ -def Idx.bitSize (n : USize) : USize := (USize.log2 n + (n - (1 <<< (USize.log2 n)) != 0).toUInt64.toUSize) -def Idx.byteSize (n : USize) : USize := (Idx.bitSize n + 7) / 8 +/-- Number of bits necessary to store `Fin n` -/ +def Fin.bitSize (n : Nat) : Nat := (Nat.log2 n + (n - (1 <<< (Nat.log2 n)) != 0).toUInt64.toNat) +def Fin.byteSize (n : Nat) : Nat := (Fin.bitSize n + 7) / 8 --- INCONSISTENT: This breaks consistency with (n=0) as we could make `Idx 0` from a byte +-- INCONSISTENT: This breaks consistency with (n=0) as we could make `Fin 0` from a byte -- Adding assumption (n≠0) is really annoying, what to do about this? -def Idx.bitType (n : USize) (_ : n ≤ 256) : BitType (Idx n) where +def Fin.bitType (n : Nat) (_ : n ≤ 256) : BitType (Fin n) where bits := (bitSize n).toUInt8 h_size := sorry_proof - fromByte b := ⟨b.toUSize % n, sorry_proof⟩ --- The modulo here is just in case to remove junk bit values, also we need `n≠0` for consistency + fromByte b := ⟨b.toNat % n, sorry_proof⟩ --- The modulo here is just in case to remove junk bit values, also we need `n≠0` for consistency toByte b := b.1.toUInt8 fromByte_toByte := sorry_proof -def Idx.byteType (n : USize) (_ : 256 < n) : ByteType (Idx n) where - bytes := byteSize n +def Fin.byteType (n : Nat) (_ : 256 < n) : ByteType (Fin n) where + bytes := (byteSize n).toUSize h_size := sorry_proof fromByteArray b i _ := Id.run do let bytes := byteSize n - let ofByte := i * bytes + let ofByte := i * bytes.toUSize let mut val : USize := 0 - for j in fullRange (Idx bytes) do - val := val + ((b[ofByte+j.1]'sorry_proof).toUSize <<< (j.1*(8:USize))) - ⟨val, sorry_proof⟩ + for j in IndexType.univ (Fin bytes) do + val := val + ((b[ofByte+j.1.toUSize]'sorry_proof).toUSize <<< (j.1.toUSize*(8:USize))) + ⟨val.toNat, sorry_proof⟩ toByteArray b i _ val := Id.run do let bytes := byteSize n - let ofByte := i * bytes + let ofByte := i * bytes.toUSize let mut b := b - for j in fullRange (Idx bytes) do - b := b.uset (ofByte+j.1) (val.1 >>> (j.1*(8:USize))).toUInt8 sorry_proof + for j in IndexType.univ (Fin bytes) do + b := b.uset (ofByte+j.1.toUSize) (val.1.toUSize >>> (j.1.toUSize*(8:USize))).toUInt8 sorry_proof b toByteArray_size := sorry_proof fromByteArray_toByteArray := sorry_proof fromByteArray_toByteArray_other := sorry_proof --- INCONSISTENT: This breaks consistency see Idx.bitType -instance (n) : PlainDataType (Idx n) where +-- INCONSISTENT: This breaks consistency see Fin.bitType +instance (n) : PlainDataType (Fin n) where btype := if h : n ≤ 256 - then .inl (Idx.bitType n h) - else .inr (Idx.byteType n (by simp at h; apply h)) + then .inl (Fin.bitType n h) + else .inr (Fin.byteType n (by simp at h; apply h)) --------------- Index ---------------------------------------------- ----------------------------------------------------------------------- +-- TODO: change to IndexType +-- -------------- Index ---------------------------------------------- +-- ---------------------------------------------------------------------- -def Index.bitType (α : Type) [Index α] (h : Index.size α ≤ 256) : BitType α where - bits := Idx.bitSize (Index.size α) |>.toUInt8 - h_size := sorry_proof - fromByte b := fromIdx <| (Idx.bitType (Index.size α) h).fromByte b - toByte a := (Idx.bitType (Index.size α) h).toByte (toIdx a) - fromByte_toByte := sorry_proof +-- def Index.bitType (α : Type) [Index α] (h : Index.size α ≤ 256) : BitType α where +-- bits := Idx.bitSize (Index.size α) |>.toUInt8 +-- h_size := sorry_proof +-- fromByte b := fromIdx <| (Idx.bitType (Index.size α) h).fromByte b +-- toByte a := (Idx.bitType (Index.size α) h).toByte (toIdx a) +-- fromByte_toByte := sorry_proof -def Index.byteType (α : Type) [Index α] (hn : 256 < Index.size α ) : ByteType α where - bytes := Idx.byteSize (Index.size α) - h_size := sorry_proof +-- def Index.byteType (α : Type) [Index α] (hn : 256 < Index.size α ) : ByteType α where +-- bytes := Idx.byteSize (Index.size α) +-- h_size := sorry_proof - fromByteArray b i h := fromIdx <| (Idx.byteType (Index.size α) hn).fromByteArray b i h - toByteArray b i h a := (Idx.byteType (Index.size α) hn).toByteArray b i h (toIdx a) +-- fromByteArray b i h := fromIdx <| (Idx.byteType (Index.size α) hn).fromByteArray b i h +-- toByteArray b i h a := (Idx.byteType (Index.size α) hn).toByteArray b i h (toIdx a) - toByteArray_size := sorry_proof - fromByteArray_toByteArray := sorry_proof - fromByteArray_toByteArray_other := sorry_proof +-- toByteArray_size := sorry_proof +-- fromByteArray_toByteArray := sorry_proof +-- fromByteArray_toByteArray_other := sorry_proof -/-- Index is `PlainDataType` via conversion from/to `Idx n` +-- /-- Index is `PlainDataType` via conversion from/to `Idx n` -**Instance diamond** This instance `instPlainDataTypeProd` is prefered over this one. +-- **Instance diamond** This instance `instPlainDataTypeProd` is prefered over this one. -This instance makes a diamond together with `instPlainDataTypeProd`. Using this instance is more computationally intensive when writting and reading from `DataArra` but it consumes less memory. The `instPlainDataTypeProd` is doing the exact opposite. +-- This instance makes a diamond together with `instPlainDataTypeProd`. Using this instance is more computationally intensive when writting and reading from `DataArra` but it consumes less memory. The `instPlainDataTypeProd` is doing the exact opposite. -Example: `Idx (2^4+1) × Idx (2^4-1)` +-- Example: `Idx (2^4+1) × Idx (2^4-1)` -As Product: - The type `Idx (2^4+1)` needs 5 bits. - The type `Idx (2^4-1)` needs 4 bits. - Thus `Idx (2^4+1) × Idx (2^4-1)` needs 9 bits, thus 2 bytes, as `instPlainDataTypeProd` +-- As Product: +-- The type `Idx (2^4+1)` needs 5 bits. +-- The type `Idx (2^4-1)` needs 4 bits. +-- Thus `Idx (2^4+1) × Idx (2^4-1)` needs 9 bits, thus 2 bytes, as `instPlainDataTypeProd` -As Index: - `Idx (2^4+1) × Idx (2^4-1) ≈ Idx (2^8-1)` - The type `Idx (2^8-1)` needs 8 bits thus only a single byte as `instPlainDataTypeIndex` +-- As Index: +-- `Idx (2^4+1) × Idx (2^4-1) ≈ Idx (2^8-1)` +-- The type `Idx (2^8-1)` needs 8 bits thus only a single byte as `instPlainDataTypeIndex` --/ -instance (priority := low) instPlainDataTypeIndex {α : Type} [Index α] : PlainDataType α where - btype := - if h : (Index.size α) ≤ 256 - then .inl (Index.bitType α h) - else .inr (Index.byteType α (by simp at h; apply h)) +-- -/ +-- instance (priority := low) instPlainDataTypeIndex {α : Type} [Index α] : PlainDataType α where +-- btype := +-- if h : (Index.size α) ≤ 256 +-- then .inl (Index.bitType α h) +-- else .inr (Index.byteType α (by simp at h; apply h)) -------------- Float ------------------------------------------------- ---------------------------------------------------------------------- diff --git a/SciLean/Data/EnumType.lean b/SciLean/Data/EnumType.lean index 83c9fba2..f7538622 100644 --- a/SciLean/Data/EnumType.lean +++ b/SciLean/Data/EnumType.lean @@ -3,8 +3,6 @@ import Mathlib.Algebra.Group.Defs import SciLean.Util.SorryProof import SciLean.Data.ColProd -import SciLean.Data.Idx -import SciLean.Data.IndexType import LeanColls @@ -81,56 +79,56 @@ namespace EnumType forIn := Fin.forIn } - @[inline] - partial def Idx.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Idx n → β → m (ForInStep β)) := do - -- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler - -- seems to have much better time optimizing code with `StateT` - let rec @[specialize] forLoop (i : USize) (b : β) : StateT Bool m β := do - if h : i < n then - match (← f ⟨i,h⟩ b) with - | ForInStep.done b => set true; pure b - | ForInStep.yield b => forLoop (i + 1) b - else - pure b - let (val,b) ← forLoop 0 init false - if b then - return (ForInStep.done val) - else - return (ForInStep.yield val) + -- @[inline] + -- partial def Idx.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Idx n → β → m (ForInStep β)) := do + -- -- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler + -- -- seems to have much better time optimizing code with `StateT` + -- let rec @[specialize] forLoop (i : USize) (b : β) : StateT Bool m β := do + -- if h : i < n then + -- match (← f ⟨i,h⟩ b) with + -- | ForInStep.done b => set true; pure b + -- | ForInStep.yield b => forLoop (i + 1) b + -- else + -- pure b + -- let (val,b) ← forLoop 0 init false + -- if b then + -- return (ForInStep.done val) + -- else + -- return (ForInStep.yield val) - @[inline] - partial instance : EnumType (Idx n) := - { - decEq := by infer_instance + -- @[inline] + -- partial instance : EnumType (Idx n) := + -- { + -- decEq := by infer_instance - forIn := Idx.forIn - } + -- forIn := Idx.forIn + -- } - @[inline] - partial def Idx'.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Idx' a b → β → m (ForInStep β)) := do - -- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler - -- seems to have much better time optimizing code with `StateT` - let rec @[specialize] forLoop (i : Int64) (val : β) : StateT Bool m β := do - if _h : i ≤ b then - match (← f ⟨i,sorry_proof⟩ val) with - | ForInStep.done val => set true; pure val - | ForInStep.yield val => forLoop (i + 1) val - else - pure val - let (val,b) ← forLoop a init false - if b then - return (ForInStep.done val) - else - return (ForInStep.yield val) + -- @[inline] + -- partial def Idx'.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Idx' a b → β → m (ForInStep β)) := do + -- -- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler + -- -- seems to have much better time optimizing code with `StateT` + -- let rec @[specialize] forLoop (i : Int64) (val : β) : StateT Bool m β := do + -- if _h : i ≤ b then + -- match (← f ⟨i,sorry_proof⟩ val) with + -- | ForInStep.done val => set true; pure val + -- | ForInStep.yield val => forLoop (i + 1) val + -- else + -- pure val + -- let (val,b) ← forLoop a init false + -- if b then + -- return (ForInStep.done val) + -- else + -- return (ForInStep.yield val) - @[inline] - partial instance : EnumType (Idx' a b) := - { - decEq := by infer_instance + -- @[inline] + -- partial instance : EnumType (Idx' a b) := + -- { + -- decEq := by infer_instance - forIn := Idx'.forIn - } + -- forIn := Idx'.forIn + -- } -- /-- Embeds `ForInStep β` to `FoInStep (ForInStep β)`, useful for exiting from double for loops. diff --git a/SciLean/Data/Function.lean b/SciLean/Data/Function.lean index 5743e4e9..cb527539 100644 --- a/SciLean/Data/Function.lean +++ b/SciLean/Data/Function.lean @@ -1,5 +1,6 @@ import Mathlib.Logic.Function.Basic -import SciLean.Data.Index +-- import SciLean.Data.Index +import SciLean.Data.IndexType def Function.Inverse (g : β → α) (f : α → β) := Function.LeftInverse g f ∧ Function.RightInverse g f @@ -66,28 +67,28 @@ theorem Function.modify_noteq {a a' : α} (h : a ≠ a') (g : β a' → β a') ( end FunctionModify -def Function.repeatIdx (f : ι → α → α) (init : α) : α := Id.run do - let mut x := init - for i in IndexType.univ ι do - x := f i x - x +-- def Function.repeatIdx (f : ι → α → α) (init : α) : α := Id.run do +-- let mut x := init +-- for i in IndexType.univ ι do +-- x := f i x +-- x -def Function.repeat (n : Nat) (f : α → α) (init : α) : α := - repeatIdx (fun (_ : Fin n) x => f x) init +-- def Function.repeat (n : Nat) (f : α → α) (init : α) : α := +-- repeatIdx (fun (_ : Fin n) x => f x) init -@[simp] -theorem Function.repeatIdx_update {α : Type _} (f : ι → α → α) (g : ι → α) - : repeatIdx (fun i g' => Function.update g' i (f i (g' i))) g - = - fun i => f i (g i) := sorry_proof +-- @[simp] +-- theorem Function.repeatIdx_update {α : Type _} (f : ι → α → α) (g : ι → α) +-- : repeatIdx (fun i g' => Function.update g' i (f i (g' i))) g +-- = +-- fun i => f i (g i) := sorry_proof -/-- Specialized formulation of `Function.repeatIdx_update` which is sometimes more -succesfull with unification -/ -@[simp] -theorem Function.repeatIdx_update' {α : Type _} (f : ι → α) (g : ι → α) (op : α → α → α) - : repeatIdx (fun i g' => Function.update g' i (op (g' i) (f i))) g - = - fun i => op (g i) (f i) := -by - apply Function.repeatIdx_update (f := fun i x => op x (f i)) +--/-- Specialized formulation of `Function.repeatIdx_update` which is sometimes more +-- succesfull with unification -/ +-- @[simp] +-- theorem Function.repeatIdx_update' {α : Type _} (f : ι → α) (g : ι → α) (op : α → α → α) +-- : repeatIdx (fun i g' => Function.update g' i (op (g' i) (f i))) g +-- = +-- fun i => op (g i) (f i) := +-- by +-- apply Function.repeatIdx_update (f := fun i x => op x (f i)) diff --git a/SciLean/Data/Idx.lean b/SciLean/Data/Idx.lean deleted file mode 100644 index 44fc1220..00000000 --- a/SciLean/Data/Idx.lean +++ /dev/null @@ -1,156 +0,0 @@ -import Mathlib.Data.Fintype.Basic -import SciLean.Util.SorryProof -import SciLean.Data.Int64 - -namespace SciLean - -/-- -Similar to `Fin n` but uses `USize` internally instead of `Nat` - -WARRNING: Needs serious revision such that overflows are handled correctly! --/ -structure Idx (n : USize) where - val : USize - property : val < n - -- Maybe add this assumption then adding two `Idx n` won't cause overflow - -- not_too_big : n < (USize.size/2).toUSize -deriving DecidableEq - -instance : ToString (Idx n) := ⟨λ i => toString i.1⟩ - -instance {n} : LT (Idx n) where - lt a b := a.val < b.val - -instance {n} : LE (Idx n) where - le a b := a.val ≤ b.val - -instance Idx.decLt {n} (a b : Idx n) : Decidable (a < b) := USize.decLt .. -instance Idx.decLe {n} (a b : Idx n) : Decidable (a ≤ b) := USize.decLe .. - -namespace Idx - -def elim0.{u} {α : Sort u} : Idx 0 → α - | ⟨_, h⟩ => absurd h (Nat.not_lt_zero _) - -variable {n : USize} - -protected def ofUSize {n : USize} (a : USize) (_ : n > 0) : Idx n := - ⟨a % n, sorry_proof⟩ - -private theorem mlt {b : USize} : {a : USize} → a < n → b % n < n := sorry_proof - --- shifting index with wrapping --- We want these operations to be invertible in `x` for every `y`. Is that the case? --- Maybe we need to require that `n < USize.size/2` -@[default_instance] -instance : HAdd (Idx n) USize (Idx n) := ⟨λ x y => ⟨(x.1 + y)%n, sorry_proof⟩⟩ -@[default_instance] -instance : HSub (Idx n) USize (Idx n) := ⟨λ x y => ⟨((x.1 + n) - (y + n))%n, sorry_proof⟩⟩ -@[default_instance] -instance : HMul USize (Idx n) (Idx n) := ⟨λ x y => ⟨(x * y.1)%n, sorry_proof⟩⟩ - -@[default_instance] -instance : HAdd (Idx n) Int64 (Idx n) := ⟨λ x y => ⟨(x.1 + y.1 + n)%n, sorry_proof⟩⟩ -@[default_instance] -instance : HSub (Idx n) Int64 (Idx n) := ⟨λ x y => ⟨(x.1 - (y.1 + n))%n, sorry_proof⟩⟩ - -@[default_instance] -instance : VAdd Int64 (Idx n) := ⟨λ x y => y + x⟩ - -def toFin {n} (i : Idx n) : Fin n.toNat := ⟨i.1.toNat, sorry_proof⟩ -def toFin' {n : Nat} (i : Idx n.toUSize) : Fin n := ⟨i.1.toNat, sorry_proof⟩ - --- @[extern c inline "(double)#1"] -def _root_.USize.toFloat (n : USize) : Float := n.toUInt64.toFloat -def toFloat {n} (i : Idx n) : Float := i.1.toFloat - -@[macro_inline] -def cast (i : Idx n) (h : n = m) : Idx m := ⟨i.1, by rw[← h]; apply i.2⟩ - -@[macro_inline] -def cast' (i : Idx n) (h : m = n) : Idx m := ⟨i.1, by rw[h]; apply i.2⟩ - -def shiftPos (x : Idx n) (s : USize) := x + s -def shiftNeg (x : Idx n) (s : USize) := x - s -def shift (x : Idx n) (s : Int) := - match s with - | .ofNat n => x.shiftPos n.toUSize - | .negSucc n => x.shiftNeg (n+1).toUSize - -/-- Splits index `i : Idx (n*m)` into `(i / n, i % n)`-/ -def prodSplit (i : Idx (n*m)) : Idx n × Idx m := - (⟨i.1 / n, sorry_proof⟩, ⟨i.1 % n, sorry_proof⟩) - -/-- Splits index `i : Idx (n*m)` into `(i / m, i % m)`-/ -def prodSplit' (i : Idx (n*m)) : Idx m × Idx n := - (⟨i.1 % m, sorry_proof⟩, ⟨i.1 / m, sorry_proof⟩) - --- This does not work as intended :( - -instance : OfNat (Idx (no_index (n+1))) i where - ofNat := Idx.ofUSize i.toUSize sorry_proof - -instance : Inhabited (Idx (no_index (n+1))) where - default := 0 - -instance : Fintype (Idx n) where - elems := { - val := Id.run do - let mut l : List (Idx n) := [] - for i in [0:n.toNat] do - l := ⟨i.toUSize, sorry_proof⟩ :: l - Multiset.ofList l.reverse - nodup := sorry_proof - } - complete := sorry_proof - - -instance (n : Nat) : Nonempty (Idx (no_index (OfNat.ofNat (n+1)))) := sorry_proof -instance (n : Nat) : OfNat (Idx (no_index (OfNat.ofNat (n+1)))) i := ⟨(i % (n+1)).toUSize, sorry_proof⟩ - - -end Idx - --------------------------------------------------------------------------------- - - - -/-- `Idx' a b = {x : Int64 // a ≤ x ∧ x ≤ b}` - -WARRNING: Needs serious revision such that overflows are handled correctly! --/ -structure Idx' (a b : Int64) where - val : Int64 - property : a ≤ val ∧ val ≤ b - -- Maybe add this assumption then adding two `Idx n` won't cause overflow - -- not_too_big : n < (USize.size/2).toUSize -deriving DecidableEq - -instance : ToString (Idx' a b) := ⟨λ i => toString i.1⟩ - -instance {a b} : LT (Idx' a b) where - lt a b := a.val < b.val - -instance {a b} : LE (Idx' a b) where - le a b := a.val ≤ b.val - -namespace Idx' - -variable {a b : Int64} - -def toFloat (i : Idx' a b) : Float := i.1.toFloat - -instance : Fintype (Idx n) where - elems := { - val := Id.run do - let mut l : List (Idx n) := [] - for i in [0:n.toNat] do - l := ⟨i.toUSize, sorry_proof⟩ :: l - Multiset.ofList l.reverse - nodup := sorry_proof - } - complete := sorry_proof - -instance : Inhabited (Idx' (no_index (-a)) (no_index a)) where - default := ⟨0, sorry_proof⟩ -instance : Nonempty (Idx' (no_index (-a)) (no_index a)) := by infer_instance diff --git a/SciLean/Data/Index.lean b/SciLean/Data/Index.lean deleted file mode 100644 index 591c0334..00000000 --- a/SciLean/Data/Index.lean +++ /dev/null @@ -1,173 +0,0 @@ -import SciLean.Util.SorryProof -import SciLean.Data.EnumType - -namespace SciLean - -class Index (ι : Type u) extends EnumType ι where - size : USize - -- This is used to assert that the number of elements is smaller then `USize.size` - -- The point is that we want to have an instance `Index (ι×κ)` from `Index ι` and `Index κ` - -- without proving `numOf ι * numOf κ < USize.size-1`. - -- Thus if `numOf ι * numOf κ ≥ USize.size` we set `isValid` to `false` and - -- give up any formal guarantees and we also panic. - isValid : Bool - - fromIdx : Idx size → ι - toIdx : ι → Idx size - - fromIdx_toIdx : isValid = true → fromIdx ∘ toIdx = id - toIdx_fromIdx : isValid = true → toIdx ∘ fromIdx = id - - -export Index (toIdx fromIdx) - -namespace Index - --- @[macro_inline] -instance : Index Empty where - size := 0 - isValid := true - - fromIdx := λ a => absurd (a := a.1<0) a.2 (by intro h; cases h) - toIdx := λ a => (by induction a; done) - - fromIdx_toIdx := sorry_proof - toIdx_fromIdx := sorry_proof - --- @[macro_inline] -instance : Index Unit where - size := 1 - isValid := true - - fromIdx := λ _ => () - toIdx := λ _ => ⟨0, sorry_proof⟩ - - fromIdx_toIdx := sorry_proof - toIdx_fromIdx := sorry_proof - --- @[macro_inline] -instance : Index (Idx n) where - size := n - isValid := true - - fromIdx := id - toIdx := id - - fromIdx_toIdx := by simp - toIdx_fromIdx := by simp - --- @[macro_inline] -instance : Index (Idx' a b) where - size := let n := b - a; if 0 < n then n.toUSize else 0 - isValid := true - - fromIdx i := ⟨i.1.toInt64 + a, sorry_proof⟩ - toIdx i := ⟨(i.1 - a).1, sorry_proof⟩ - - fromIdx_toIdx := sorry_proof - toIdx_fromIdx := sorry_proof - - --- Row major ordering, this respects `<` defined on `ι × κ` -@[macro_inline] -instance [Index ι] [Index κ] : Index (ι×κ) where - size := (min ((size ι).toNat * (size κ).toNat) (USize.size -1)).toUSize - isValid := - if (size ι).toNat * (size κ).toNat < USize.size - 1 then - Index.isValid ι && Index.isValid κ - else - -- this is using the fact that `(default : Bool) = false` - panic! s!"Attempting to create `Index (ι×κ)` for too big `ι` and `κ`.\n `size ι = {size ι}`\n `size κ = {size κ}`" - - -- This has still some issues when overflow happends - -- numOf := numOf ι * numOf κ - fromIdx := λ i => (fromIdx ⟨i.1 / size κ, sorry_proof⟩, fromIdx ⟨i.1 % size κ, sorry_proof⟩) - toIdx := λ (i,j) => ⟨(toIdx i).1 * size κ + (toIdx j).1, sorry_proof⟩ - - fromIdx_toIdx := λ _ => sorry_proof - toIdx_fromIdx := λ _ => sorry_proof - - --- Row major ordering, this respects `<` defined on `ι × κ` --- @[macro_inline] -instance [Index ι] [Index κ] : Index (ι×ₗκ) where - size := (min ((size ι).toNat * (size κ).toNat) (USize.size -1)).toUSize - isValid := - if (size ι).toNat * (size κ).toNat < USize.size - 1 then - Index.isValid ι && Index.isValid κ - else - -- this is using the fact that `(default : Bool) = false` - panic! s!"Attempting to create `Index (ι×ₗκ)` for too big `ι` and `κ`.\n `size ι = {size ι}`\n `size κ = {size κ}`" - - -- This has still some issues when overflow happends - -- numOf := numOf ι * numOf κ - fromIdx := λ i => (fromIdx ⟨i.1 % size ι, sorry_proof⟩, fromIdx ⟨i.1 / size ι, sorry_proof⟩) - toIdx := λ (i,j) => ⟨(toIdx j).1 * size ι + (toIdx i).1, sorry_proof⟩ - - fromIdx_toIdx := λ _ => sorry_proof - toIdx_fromIdx := λ _ => sorry_proof - - instance [Index ι] [Index κ] : Index (ι ⊕ κ) where - size := (min ((size ι).toNat + (size κ).toNat) (USize.size -1)).toUSize - isValid := - if (size ι).toNat + (size κ).toNat < USize.size - 1 then - Index.isValid ι && Index.isValid κ - else - -- this is using the fact that `(default : Bool) = false` - panic! s!"Attempting to create `Index (ι⊕κ)` for too big `ι` and `κ`.\n `size ι = {size ι}`\n `size κ = {size κ}`" - - fromIdx := λ i => - if i.1 < size ι - then Sum.inl $ fromIdx ⟨i.1, sorry_proof⟩ - else Sum.inr $ fromIdx ⟨i.1 - size ι, sorry_proof⟩ - toIdx := λ ij => - match ij with - | Sum.inl i => ⟨(toIdx i).1, sorry_proof⟩ - | Sum.inr j => ⟨(toIdx j).1 + size ι, sorry_proof⟩ - - fromIdx_toIdx := λ _ => sorry_proof - toIdx_fromIdx := λ _ => sorry_proof - - - -- TODO: revive parallel sum/join . I ditched ranges as the required decidable order, which is too much to ask sometimes - - -- /-- - -- Joins all values `f i` from left to right with operation `op` - - -- The operation `op` is assumed to be associative and `unit` is the left unit of this operation i.e. `∀ a, op unit a = a` - - -- WARRNING: This does not work correctly for small `m`. FIX THIS!!!! - -- -/ - -- def parallelJoin {α ι} [Index ι] (m : USize) (f : ι → α) (op : α → α → α) (unit : α) : α := Id.run do - -- dbg_trace "!!!FIX ME!!! Index.parallelJoin is not implemented correctly!" - -- let n := size ι - -- if n == 0 then - -- return unit - -- let mut tasks : Array (Task α) := Array.mkEmpty m.toNat - -- let m := max 1 (min m n) -- cap min/max of `m` - -- let stride : USize := (n+m-1)/m - -- let mut start : Idx n := ⟨0,sorry_proof⟩ - -- let mut stop : Idx n := ⟨stride-1, sorry_proof⟩ - -- for i in fullRange (Idx m) do - -- let r : EnumType.Range ι := some (fromIdx start, fromIdx stop) - -- let partialJoin : Unit → α := λ _ => Id.run do - -- let mut a := unit - -- for i in r do - -- a := op a (f i) - -- a - -- tasks := tasks.push (Task.spawn partialJoin) - -- start := ⟨min (start.1 + stride) (n-1), sorry_proof⟩ - -- stop := ⟨min (stop.1 + stride) (n-1), sorry_proof⟩ - - -- let mut a := unit - -- for t in tasks do - -- a := op a t.get - -- a - - - -- open EnumType in - -- /-- - -- Split the sum `∑ i, f i` into `m` chuncks and compute them in parallel - -- -/ - -- def parallelSum {X ι} [Zero X] [Add X] [Index ι] (m : USize) (f : ι → X) : X := - -- parallelJoin m f (λ x y => x + y) 0 diff --git a/SciLean/Data/IndexType.lean b/SciLean/Data/IndexType.lean index 936ece6e..d6a175b6 100644 --- a/SciLean/Data/IndexType.lean +++ b/SciLean/Data/IndexType.lean @@ -16,6 +16,44 @@ instance : LawfulIndexType (Fin n) where toFin_leftInv := by intro x; rfl toFin_rightInv := by intro x; rfl +open Set + +instance (a b : Int) : IndexType (Icc a b) where + card := ((b + 1) - a).toNat + toFin i := ⟨(i.1 - a).toNat, sorry_proof⟩ + fromFin i := ⟨a + i.1, sorry_proof⟩ + +instance (a b : Int) : LawfulIndexType (Icc a b) where + toFin_leftInv := by intro x; simp[IndexType.toFin, IndexType.fromFin] + toFin_rightInv := by intro x; sorry_proof + +instance (a b : Int) : IndexType (Ioo a b) where + card := (b - (a + 1)).toNat + toFin i := ⟨(i.1 - (a + 1)).toNat, sorry_proof⟩ + fromFin i := ⟨(a + 1) + i.1, sorry_proof⟩ + +instance (a b : Int) : LawfulIndexType (Ioo a b) where + toFin_leftInv := by intro x; simp[IndexType.toFin, IndexType.fromFin] + toFin_rightInv := by intro x; sorry_proof + +instance (a b : Int) : IndexType (Ioc a b) where + card := ((b + 1) - (a+1)).toNat + toFin i := ⟨(i.1 - (a + 1)).toNat, sorry_proof⟩ + fromFin i := ⟨(a + 1) + i.1, sorry_proof⟩ + +instance (a b : Int) : LawfulIndexType (Ioc a b) where + toFin_leftInv := by intro x; simp[IndexType.toFin, IndexType.fromFin] + toFin_rightInv := by intro x; sorry_proof + +instance (a b : Int) : IndexType (Ico a b) where + card := (b - a).toNat + toFin i := ⟨(i.1 - a).toNat, sorry_proof⟩ + fromFin i := ⟨a + i.1, sorry_proof⟩ + +instance (a b : Int) : LawfulIndexType (Ico a b) where + toFin_leftInv := by intro x; simp[IndexType.toFin, IndexType.fromFin] + toFin_rightInv := by intro x; sorry_proof + namespace SciLean -- use lean colls diff --git a/SciLean/Tactic/DeduceBy.lean b/SciLean/Tactic/DeduceBy.lean deleted file mode 100644 index fd0e37d3..00000000 --- a/SciLean/Tactic/DeduceBy.lean +++ /dev/null @@ -1,172 +0,0 @@ -import Lean -import Qq - -import Mathlib.Tactic.NormNum.Basic -import Mathlib.Tactic.NormNum.Result -import Mathlib.Data.UInt -import Mathlib.Tactic.Ring - -import SciLean.Util.RewriteBy -import SciLean.Data.Index - -/- -This norm num extension is by Kyle Miller -source: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/norm_num.20for.20USize/near/405939157 --/ -namespace Mathlib.Meta.NormNum -open Qq Lean Meta - -@[simp] -theorem succ_pow_numBits : - Nat.succ (2 ^ System.Platform.numBits - 1) = 2 ^ System.Platform.numBits := by - obtain (hbits | hbits) := System.Platform.numBits_eq <;> norm_num [hbits] - -theorem isNat_USizeDiv : {a b : USize} → {a' b' c : ℕ} → - IsNat a a' → IsNat b b' → - Nat.div (a' % 2^32) (b' % 2^32) = c → - Nat.div (a' % 2^64) (b' % 2^64) = c → - IsNat (a / b) c - | _, _, a', b', _, ⟨rfl⟩, ⟨rfl⟩, h32, h64 => by - constructor - unfold_projs - unfold USize.div - unfold_projs - unfold Fin.div - unfold_projs - simp [USize.size] - obtain (hbits | hbits) := System.Platform.numBits_eq - <;> · simp [hbits, *] - generalize_proofs h - apply USize.eq_of_val_eq - ext; rename_i x; change x = (x : Fin USize.size) - rw [Fin.val_cast_of_lt h] - -@[norm_num (_ : USize) / (_ : USize)] -def evalUSizeDiv : NormNumExt where eval {u α} e := do - let .app (.app f (a : Q($α))) (b : Q($α)) ← whnfR e | failure - haveI' : u =QL 0 := ⟨⟩; haveI' : $α =Q USize := ⟨⟩ - haveI' : $e =Q $a / $b := ⟨⟩ - guard <|← withNewMCtxDepth <| isDefEq f q(HDiv.hDiv (α := USize)) - let sUSize : Q(AddMonoidWithOne USize) := q(inferInstance) - let ⟨na, pa⟩ ← deriveNat a sUSize; let ⟨nb, pb⟩ ← deriveNat b sUSize - have nc32 : Q(ℕ) := mkRawNatLit ((na.natLit! % 2^32) / (nb.natLit! % 2^32)) - have nc64 : Q(ℕ) := mkRawNatLit ((na.natLit! % 2^64) / (nb.natLit! % 2^64)) - guard <| nc32 == nc64 - haveI' : $nc32 =Q $nc64 := ⟨⟩ - have pf32 : Q(Nat.div ($na % 2 ^ 32) ($nb % 2 ^ 32) = $nc32) := (q(Eq.refl $nc32) : Expr) - have pf64 : Q(Nat.div ($na % 2 ^ 64) ($nb % 2 ^ 64) = $nc64) := (q(Eq.refl $nc64) : Expr) - haveI' : Nat.div ($na % 2 ^ 64) ($nb % 2 ^ 64) =Q $nc64 := ⟨⟩ - return .isNat sUSize nc64 q(isNat_USizeDiv $pa $pb $pf32 $pf64) - -end Mathlib.Meta.NormNum - - -namespace SciLean - -syntax (name:=deduceBy) "deduce_by " conv : tactic - -namespace DeduceBy -open Qq Lean Meta - -/-- -Assuming that `a` has mvar `m` and `b` is an expression. - -Return mvar `m` and value `x` for it such that `a=b` is likely to hold. - -Examples: -- `a = 4 * ?m + 2`, `b = 2*n` => `(?m, (2*n-2)/4)` --/ -partial def invertNat (a b : Q(Nat)) : MetaM (Q(Nat) × Q(Nat)) := do - if a.isMVar then - return (a,b) - else - match a with - | ~q($x * $y) => - if x.hasMVar - then invertNat x q($b / $y) - else invertNat y q($b / $x) - | ~q($x / $y) => - if x.hasMVar - then invertNat x q($b * $y) - else invertNat y q($x / $b) - | ~q($x + $y) => - if x.hasMVar - then invertNat x q($b - $y) - else invertNat y q($b - $x) - | ~q($x - $y) => - if x.hasMVar - then invertNat x q($b + $y) - else invertNat y q($x - $b) - | _ => - throwError s!"`decuce_by` does not support Nat operation {← ppExpr a}" - -/-- -Assuming that `a` has mvar `m` and `b` is an expression. - -Return mvar `m` and value `x` for it such that `a=b` is likely to hold. - -Examples: -- `a = 4 * ?m + 2`, `b = 2*n` => `(?m, (2*n-2)/4)` --/ -partial def invertUSize (a b : Q(USize)) : MetaM (Q(USize) × Q(USize)) := do - if a.isMVar then - return (a,b) - else - match a with - | ~q($x * $y) => - if x.hasMVar - then invertUSize x q($b / $y) - else invertUSize y q($b / $x) - | ~q($x / $y) => - if x.hasMVar - then invertUSize x q($b * $y) - else invertUSize y q($x / $b) - | ~q($x + $y) => - if x.hasMVar - then invertUSize x q($b - $y) - else invertUSize y q($b - $x) - | ~q($x - $y) => - if x.hasMVar - then invertUSize x q($b + $y) - else invertUSize y q($x - $b) - | _ => - throwError s!"`decuce_by` does not support USize operation {← ppExpr a}" - - -open Lean Meta Elab Tactic Qq -@[tactic deduceBy] -partial def deduceByTactic : Tactic -| `(tactic| deduce_by $t:conv) => do - - let goal ← getMainGoal - let .some (_,lhs,rhs) := Expr.eq? (← goal.getType) | throwError "expected `?m = e`, got {← ppExpr (← goal.getType)}" - - -- if ¬lhs.isMVar then - -- throwError "lhs is not mvar" - - -- now we assume that a is mvar and b is and expression we want to simplify - let (goal,a,b) ← - if lhs.hasMVar then - pure (goal,lhs,rhs) - else - let goal' ← mkFreshExprMVar (← mkEq rhs lhs) - goal.assign (← mkEqSymm goal') - pure (goal'.mvarId!,rhs,lhs) - - let A ← inferType a - if A == q(Nat) || A == q(USize) then - let (m,x) ← - if A == q(Nat) - then invertNat a b - else invertUSize a b - let (x', _) ← elabConvRewrite x t - m.mvarId!.assign x' - let subgoals ← evalTacticAt (← `(tactic| (conv => (conv => lhs; $t); (conv => rhs; $t)))) goal - if subgoals.length ≠ 0 then - throwError "`decide_by` failed to show {← ppExpr (← goal.getType)}" - - else - let (b',prf) ← elabConvRewrite b t - a.mvarId!.assign b' - goal.assign prf -| _ => throwUnsupportedSyntax