Skip to content

Commit

Permalink
performance issue fix when iterating over arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 6, 2023
1 parent 8b59d2e commit 3a7aaf4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
50 changes: 34 additions & 16 deletions SciLean/Data/EnumType.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,21 @@ namespace EnumType
forIn := λ init f => f () init
}

@[inline] partial def Fin.forIn {m : TypeType} [Monad m] {β : Type} (init : β) (f : Fin n → β → m (ForInStep β)) :=
let rec @[specialize] forLoop (i : Nat) (b : β) (_ := (⟨init⟩ : Inhabited β)) : m (ForInStep β) := do
@[inline] partial def Fin.forIn {m : TypeType} [Monad m] {β : Type} (init : β) (f : Fin n → β → m (ForInStep β)) := do
-- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler
-- seems to have much better time optimizing code with `StateT`
let rec @[specialize] forLoop (i : Nat) (b : β) : StateT Bool m β := do
if h : i < n then
match (← f ⟨i,h⟩ b) with
| ForInStep.done b => pure (.done b)
| ForInStep.done b => set true; pure b
| ForInStep.yield b => forLoop (i + 1) b
else
pure (.yield b)
forLoop 0 init

pure b
let (val,b) ← forLoop 0 init false
if b then
return (ForInStep.done val)
else
return (ForInStep.yield val)

@[inline]
instance : EnumType (Fin n) :=
Expand All @@ -71,15 +76,21 @@ namespace EnumType
}

@[inline]
partial def Idx.forIn {m : TypeType} [Monad m] {β : Type} (init : β) (f : Idx n → β → m (ForInStep β)) :=
let rec @[specialize] forLoop (i : USize) (b : β) (_ := (⟨init⟩ : Inhabited β)) : m (ForInStep β) := do
partial def Idx.forIn {m : TypeType} [Monad m] {β : Type} (init : β) (f : Idx n → β → m (ForInStep β)) := do
-- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler
-- seems to have much better time optimizing code with `StateT`
let rec @[specialize] forLoop (i : USize) (b : β) : StateT Bool m β := do
if h : i < n then
match (← f ⟨i,h⟩ b) with
| ForInStep.done b => pure (.done b)
| ForInStep.done b => set true; pure b
| ForInStep.yield b => forLoop (i + 1) b
else
pure (.yield b)
forLoop 0 init
pure b
let (val,b) ← forLoop 0 init false
if b then
return (ForInStep.done val)
else
return (ForInStep.yield val)

@[inline]
partial instance : EnumType (Idx n) :=
Expand All @@ -90,15 +101,22 @@ namespace EnumType
}

@[inline]
partial def Idx'.forIn {m : TypeType} [Monad m] {β : Type} (init : β) (f : Idx' a b → β → m (ForInStep β)) :=
let rec @[specialize] forLoop (i : Int64) (val : β) (_ := (⟨init⟩ : Inhabited β)) : m (ForInStep β) := do
partial def Idx'.forIn {m : TypeType} [Monad m] {β : Type} (init : β) (f : Idx' a b → β → m (ForInStep β)) := do
-- Here we use `StateT Bool m β` instead of `m (ForInStep β)` as compiler
-- seems to have much better time optimizing code with `StateT`
let rec @[specialize] forLoop (i : Int64) (val : β) : StateT Bool m β := do
if _h : i ≤ b then
match (← f ⟨i,sorry_proof⟩ val) with
| ForInStep.done val => pure (.done val)
| ForInStep.done val => set true; pure val
| ForInStep.yield val => forLoop (i + 1) val
else
pure (.yield val)
forLoop a init
pure val
let (val,b) ← forLoop a init false
if b then
return (ForInStep.done val)
else
return (ForInStep.yield val)


@[inline]
partial instance : EnumType (Idx' a b) :=
Expand Down
7 changes: 5 additions & 2 deletions SciLean/Data/Index.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export Index (toIdx fromIdx)

namespace Index

@[macro_inline]
instance : Index Empty where
size := 0
isValid := true
Expand All @@ -33,7 +34,7 @@ instance : Index Empty where
fromIdx_toIdx := sorry_proof
toIdx_fromIdx := sorry_proof


@[macro_inline]
instance : Index Unit where
size := 1
isValid := true
Expand All @@ -44,7 +45,7 @@ instance : Index Unit where
fromIdx_toIdx := sorry_proof
toIdx_fromIdx := sorry_proof


@[macro_inline]
instance : Index (Idx n) where
size := n
isValid := true
Expand All @@ -67,6 +68,7 @@ instance : Index (Idx' a b) where


-- Row major ordering, this respects `<` defined on `ι × κ`
@[macro_inline]
instance [Index ι] [Index κ] : Index (ι×κ) where
size := (min ((size ι).toNat * (size κ).toNat) (USize.size -1)).toUSize
isValid :=
Expand All @@ -86,6 +88,7 @@ instance [Index ι] [Index κ] : Index (ι×κ) where


-- Row major ordering, this respects `<` defined on `ι × κ`
@[macro_inline]
instance [Index ι] [Index κ] : Index (ι×ₗκ) where
size := (min ((size ι).toNat * (size κ).toNat) (USize.size -1)).toUSize
isValid :=
Expand Down

0 comments on commit 3a7aaf4

Please sign in to comment.