From 9f7d512b22577b2972b574eb7d8ff49bc70d24fa Mon Sep 17 00:00:00 2001 From: lecopivo Date: Mon, 26 Feb 2024 00:56:41 +0100 Subject: [PATCH] fix Idx --- SciLean/Data/Idx.lean | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/SciLean/Data/Idx.lean b/SciLean/Data/Idx.lean index 47397b03..503299db 100644 --- a/SciLean/Data/Idx.lean +++ b/SciLean/Data/Idx.lean @@ -39,7 +39,7 @@ protected def ofUSize {n : USize} (a : USize) (_ : n > 0) : Idx n := private theorem mlt {b : USize} : {a : USize} → a < n → b % n < n := sorry_proof --- shifting index with wrapping +-- shifting index with wrapping -- We want these operations to be invertible in `x` for every `y`. Is that the case? -- Maybe we need to require that `n < USize.size/2` @[default_instance] @@ -60,7 +60,6 @@ instance : VAdd Int64 (Idx n) := ⟨λ x y => y + x⟩ def toFin {n} (i : Idx n) : Fin n.toNat := ⟨i.1.toNat, sorry_proof⟩ def toFin' {n : Nat} (i : Idx n.toUSize) : Fin n := ⟨i.1.toNat, sorry_proof⟩ -@[extern c inline "(double)#1"] def _root_.USize.toFloat (n : USize) : Float := n.toNat.toFloat def toFloat {n} (i : Idx n) : Float := i.1.toFloat @@ -72,19 +71,19 @@ def cast' (i : Idx n) (h : m = n) : Idx m := ⟨i.1, by rw[h]; apply i.2⟩ def shiftPos (x : Idx n) (s : USize) := x + s def shiftNeg (x : Idx n) (s : USize) := x - s -def shift (x : Idx n) (s : Int) := +def shift (x : Idx n) (s : Int) := match s with | .ofNat n => x.shiftPos n.toUSize | .negSucc n => x.shiftNeg (n+1).toUSize /-- Splits index `i : Idx (n*m)` into `(i / n, i % n)`-/ -def prodSplit (i : Idx (n*m)) : Idx n × Idx m := +def prodSplit (i : Idx (n*m)) : Idx n × Idx m := (⟨i.1 / n, sorry_proof⟩, ⟨i.1 % n, sorry_proof⟩) /-- Splits index `i : Idx (n*m)` into `(i / m, i % m)`-/ -def prodSplit' (i : Idx (n*m)) : Idx m × Idx n := +def prodSplit' (i : Idx (n*m)) : Idx m × Idx n := (⟨i.1 % m, sorry_proof⟩, ⟨i.1 / m, sorry_proof⟩) - + -- This does not work as intended :( instance : OfNat (Idx (no_index (n+1))) i where @@ -139,7 +138,7 @@ namespace Idx' variable {a b : Int64} def toFloat (i : Idx' a b) : Float := i.1.toFloat - + instance : Fintype (Idx n) where elems := { val := Id.run do @@ -154,4 +153,3 @@ instance : Fintype (Idx n) where instance : Inhabited (Idx' (no_index (-a)) (no_index a)) where default := ⟨0, sorry_proof⟩ instance : Nonempty (Idx' (no_index (-a)) (no_index a)) := by infer_instance -