diff --git a/SciLean/Data/EnumType.lean b/SciLean/Data/EnumType.lean index a0eacd8a..46f6050f 100644 --- a/SciLean/Data/EnumType.lean +++ b/SciLean/Data/EnumType.lean @@ -52,16 +52,21 @@ namespace EnumType forIn := λ init f => f () init } - @[inline] partial def Fin.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Fin n → β → m (ForInStep β)) := - let rec @[specialize] forLoop (i : Nat) (b : β) (_ := (⟨init⟩ : Inhabited β)) : m (ForInStep β) := do + @[inline] partial def Fin.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Fin n → β → m (ForInStep β)) := do + -- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler + -- seems to have much better time optimizing code with `StateT` + let rec @[specialize] forLoop (i : Nat) (b : β) : StateT Bool m β := do if h : i < n then match (← f ⟨i,h⟩ b) with - | ForInStep.done b => pure (.done b) + | ForInStep.done b => set true; pure b | ForInStep.yield b => forLoop (i + 1) b else - pure (.yield b) - forLoop 0 init - + pure b + let (val,b) ← forLoop 0 init false + if b then + return (ForInStep.done val) + else + return (ForInStep.yield val) @[inline] instance : EnumType (Fin n) := @@ -71,15 +76,21 @@ namespace EnumType } @[inline] - partial def Idx.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Idx n → β → m (ForInStep β)) := - let rec @[specialize] forLoop (i : USize) (b : β) (_ := (⟨init⟩ : Inhabited β)) : m (ForInStep β) := do + partial def Idx.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Idx n → β → m (ForInStep β)) := do + -- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler + -- seems to have much better time optimizing code with `StateT` + let rec @[specialize] forLoop (i : USize) (b : β) : StateT Bool m β := do if h : i < n then match (← f ⟨i,h⟩ b) with - | ForInStep.done b => pure (.done b) + | ForInStep.done b => set true; pure b | ForInStep.yield b => forLoop (i + 1) b else - pure (.yield b) - forLoop 0 init + pure b + let (val,b) ← forLoop 0 init false + if b then + return (ForInStep.done val) + else + return (ForInStep.yield val) @[inline] partial instance : EnumType (Idx n) := @@ -90,15 +101,22 @@ namespace EnumType } @[inline] - partial def Idx'.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Idx' a b → β → m (ForInStep β)) := - let rec @[specialize] forLoop (i : Int64) (val : β) (_ := (⟨init⟩ : Inhabited β)) : m (ForInStep β) := do + partial def Idx'.forIn {m : Type → Type} [Monad m] {β : Type} (init : β) (f : Idx' a b → β → m (ForInStep β)) := do + -- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler + -- seems to have much better time optimizing code with `StateT` + let rec @[specialize] forLoop (i : Int64) (val : β) : StateT Bool m β := do if _h : i ≤ b then match (← f ⟨i,sorry_proof⟩ val) with - | ForInStep.done val => pure (.done val) + | ForInStep.done val => set true; pure val | ForInStep.yield val => forLoop (i + 1) val else - pure (.yield val) - forLoop a init + pure val + let (val,b) ← forLoop a init false + if b then + return (ForInStep.done val) + else + return (ForInStep.yield val) + @[inline] partial instance : EnumType (Idx' a b) := diff --git a/SciLean/Data/Index.lean b/SciLean/Data/Index.lean index 0cf5c15a..8e5ec982 100644 --- a/SciLean/Data/Index.lean +++ b/SciLean/Data/Index.lean @@ -23,6 +23,7 @@ export Index (toIdx fromIdx) namespace Index +@[macro_inline] instance : Index Empty where size := 0 isValid := true @@ -33,7 +34,7 @@ instance : Index Empty where fromIdx_toIdx := sorry_proof toIdx_fromIdx := sorry_proof - +@[macro_inline] instance : Index Unit where size := 1 isValid := true @@ -44,7 +45,7 @@ instance : Index Unit where fromIdx_toIdx := sorry_proof toIdx_fromIdx := sorry_proof - +@[macro_inline] instance : Index (Idx n) where size := n isValid := true @@ -67,6 +68,7 @@ instance : Index (Idx' a b) where -- Row major ordering, this respects `<` defined on `ι × κ` +@[macro_inline] instance [Index ι] [Index κ] : Index (ι×κ) where size := (min ((size ι).toNat * (size κ).toNat) (USize.size -1)).toUSize isValid := @@ -86,6 +88,7 @@ instance [Index ι] [Index κ] : Index (ι×κ) where -- Row major ordering, this respects `<` defined on `ι × κ` +@[macro_inline] instance [Index ι] [Index κ] : Index (ι×ₗκ) where size := (min ((size ι).toNat * (size κ).toNat) (USize.size -1)).toUSize isValid :=