Skip to content

Commit

Permalink
Implement a faster and unbiased version of shuffle
Browse files Browse the repository at this point in the history
Naive implementation using sorting lists is too slow. Here we add basic
support of mutable boxed arrays and use it to implement Fisher-Yates shuffle
algorithm that proves to be at least x20 more efficient

```
  shuffle
    uniformShuffleList:  OK (0.72s)
      5.43 ms ± 155 μs
    uniformShuffleListM: OK (0.21s)
      14.2 ms ± 818 μs
    naiveShuffleListM:   OK (0.92s)
      313  ms ± 5.7 ms
```
  • Loading branch information
lehins committed Dec 25, 2024
1 parent 6b30bd9 commit a79f427
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 37 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
* Add `Seed`, `SeedGen`, `seedSize`, `mkSeed` and `unSeed`:
[#162](https://github.com/haskell/random/pull/162)
* Add `SplitGen` and `splitGen`: [#160](https://github.com/haskell/random/pull/160)
* Add `shuffleList` and `shuffleListM`: [#140](https://github.com/haskell/random/pull/140)
* Add `unifromShuffleList` and `unifromShuffleListM`: [#140](https://github.com/haskell/random/pull/140)
* Add `uniformWordR`: [#140](https://github.com/haskell/random/pull/140)
* Add `mkStdGen64`: [#155](https://github.com/haskell/random/pull/155)
* Add `uniformListRM`, `uniformList`, `uniformListR`, `uniforms` and `uniformRs`:
[#154](https://github.com/haskell/random/pull/154)
Expand Down
18 changes: 17 additions & 1 deletion bench/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Main (main) where
import Control.Monad
import Control.Monad.State.Strict
import Data.Int
import Data.List (sortOn)
import Data.Proxy
import Data.Typeable
import Data.Word
Expand Down Expand Up @@ -263,9 +264,15 @@ main = do
, env getStdGen $ \gen ->
bench "uniformByteArray 100MB" $ nf (\n -> uniformByteArray False n gen) sz100MiB
, env getStdGen $ \gen ->
bench "genByteString 100MB" $ nf (\k -> genByteString k gen) sz100MiB
bench "genByteString 100MB" $ nf (`genByteString` gen) sz100MiB
]
]
, env (pure [0 :: Integer .. 200000]) $ \xs ->
bgroup "shuffle"
[ env getStdGen $ bench "uniformShuffleList" . nf (uniformShuffleList xs)
, env getStdGen $ bench "uniformShuffleListM" . nf (`runStateGen` uniformShuffleListM xs)
, env getStdGen $ bench "naiveShuffleListM" . nf (`runStateGen` naiveShuffleListM xs)
]
]

pureUniformRFullBench ::
Expand Down Expand Up @@ -351,3 +358,12 @@ fillMutablePrimArrayM f ma g = do
go 0
unsafeFreezePrimArray ma
#endif


naiveShuffleListM :: StatefulGen g m => [a] -> g -> m [a]
naiveShuffleListM xs gen = do
is <- uniformListM n gen
pure $ map snd $ sortOn fst $ zip (is :: [Int]) xs
where
!n = length xs
{-# INLINE naiveShuffleListM #-}
17 changes: 9 additions & 8 deletions src/System/Random.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ module System.Random
, uniformRs
, uniformList
, uniformListR
, shuffleList
, uniformShuffleList
-- ** Bytes
, uniformByteArray
, uniformByteString
Expand Down Expand Up @@ -94,6 +94,7 @@ import Data.IORef
import Data.Word
import Foreign.C.Types
import GHC.Exts
import System.Random.Array (shuffleListST)
import System.Random.GFinite (Finite)
import System.Random.Internal
import System.Random.Seed
Expand Down Expand Up @@ -294,18 +295,18 @@ uniformListR :: (UniformRange a, RandomGen g) => Int -> (a, a) -> g -> ([a], g)
uniformListR n r g = runStateGen g (uniformListRM n r)
{-# INLINE uniformListR #-}

-- | Shuffle elements of a list in a random order.
-- | Shuffle elements of a list in a uniformly random order.
--
-- ====__Examples__
--
-- >>> let gen = mkStdGen 2023
-- >>> shuffleList ['a'..'z'] gen
-- ("renlhfqmgptwksdiyavbxojzcu",StdGen {unStdGen = SMGen 9882508430712573120 1920468677557965761})
-- >>> uniformShuffleList "ELVIS" $ mkStdGen 252
-- ("LIVES",StdGen {unStdGen = SMGen 17676540583805057877 5302934877338729551})
--
-- @since 1.3.0
shuffleList :: RandomGen g => [a] -> g -> ([a], g)
shuffleList xs g = runStateGen g (shuffleListM xs)
{-# INLINE shuffleList #-}
uniformShuffleList :: RandomGen g => [a] -> g -> ([a], g)
uniformShuffleList xs g =
runStateGenST g $ \gen -> shuffleListST (`uniformWordR` gen) xs
{-# INLINE uniformShuffleList #-}

-- | Generates a 'ByteString' of the specified size using a pure pseudo-random
-- number generator. See 'uniformByteStringM' for the monadic version.
Expand Down
156 changes: 156 additions & 0 deletions src/System/Random/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,17 @@ module System.Random.Array
, byteArrayToShortByteString
, getSizeOfMutableByteArray
, shortByteStringToByteString
-- ** MutableArray
, Array (..)
, MutableArray (..)
, newMutableArray
, freezeMutableArray
, writeArray
, shuffleListM
, shuffleListST
) where

import Control.Monad.Trans (lift, MonadTrans)
import Control.Monad (when)
import Control.Monad.ST
import Data.Array.Byte (ByteArray(..), MutableByteArray(..))
Expand All @@ -54,6 +63,10 @@ import Data.ByteString (ByteString)
wordSizeInBits :: Int
wordSizeInBits = finiteBitSize (0 :: Word)

----------------
-- Byte Array --
----------------

-- Architecture independent helpers:

sizeOfByteArray :: ByteArray -> Int
Expand Down Expand Up @@ -204,3 +217,146 @@ pinnedByteArrayToForeignPtr ba# =
ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#))
{-# INLINE pinnedByteArrayToForeignPtr #-}
#endif

-----------------
-- Boxed Array --
-----------------

data Array a = Array (Array# a)

data MutableArray s a = MutableArray (MutableArray# s a)

newMutableArray :: Int -> a -> ST s (MutableArray s a)
newMutableArray (I# n#) a =
ST $ \s# ->
case newArray# n# a s# of
(# s'#, ma# #) -> (# s'#, MutableArray ma# #)
{-# INLINE newMutableArray #-}

freezeMutableArray :: MutableArray s a -> ST s (Array a)
freezeMutableArray (MutableArray ma#) =
ST $ \s# ->
case unsafeFreezeArray# ma# s# of
(# s'#, a# #) -> (# s'#, Array a# #)
{-# INLINE freezeMutableArray #-}

sizeOfMutableArray :: MutableArray s a -> Int
sizeOfMutableArray (MutableArray ma#) = I# (sizeofMutableArray# ma#)
{-# INLINE sizeOfMutableArray #-}

readArray :: MutableArray s a -> Int -> ST s a
readArray (MutableArray ma#) (I# i#) = ST (readArray# ma# i#)
{-# INLINE readArray #-}

writeArray :: MutableArray s a -> Int -> a -> ST s ()
writeArray (MutableArray ma#) (I# i#) a = st_ (writeArray# ma# i# a)
{-# INLINE writeArray #-}

swapArray :: MutableArray s a -> Int -> Int -> ST s ()
swapArray ma i j = do
x <- readArray ma i
y <- readArray ma j
writeArray ma j x
writeArray ma i y
{-# INLINE swapArray #-}

-- | Write contents of the list into the mutable array. Make sure that array is big
-- enough or segfault will happen.
fillMutableArrayFromList :: MutableArray s a -> [a] -> ST s ()
fillMutableArrayFromList ma = go 0
where
go _ [] = pure ()
go i (x:xs) = writeArray ma i x >> go (i + 1) xs
{-# INLINE fillMutableArrayFromList #-}

readListFromMutableArray :: MutableArray s a -> ST s [a]
readListFromMutableArray ma = go (len - 1) []
where
len = sizeOfMutableArray ma
go i !acc
| i >= 0 = do
x <- readArray ma i
go (i - 1) (x : acc)
| otherwise = pure acc
{-# INLINE readListFromMutableArray #-}


-- | Generate a list of indices that will be used for swapping elements in uniform shuffling:
--
-- @
-- [ (0, n - 1)
-- , (0, n - 2)
-- , (0, n - 3)
-- , ...
-- , (0, 3)
-- , (0, 2)
-- , (0, 1)
-- ]
-- @
genSwapIndices
:: Monad m
=> (Word -> m Word)
-- ^ Action that generates a Word in the supplied range.
-> Word
-- ^ Number of index swaps to generate.
-> m [Int]
genSwapIndices genWordR n = go 1 []
where
go i !acc
| i >= n = pure acc
| otherwise = do
x <- genWordR i
let !xi = fromIntegral x
go (i + 1) (xi : acc)
{-# INLINE genSwapIndices #-}


-- | Implementation of mutable version of Fisher-Yates shuffle. Unfortunately, we cannot generally
-- interleave pseudo-random number generation and mutation of `ST` monad, therefore we have to
-- pre-generate all of the index swaps with `genSwapIndices` and store them in a list before we can
-- perform the actual swaps.
shuffleListM :: Monad m => (Word -> m Word) -> [a] -> m [a]
shuffleListM genWordR ls
| len <= 1 = pure ls
| otherwise = do
swapIxs <- genSwapIndices genWordR (fromIntegral len)
pure $ runST $ do
ma <- newMutableArray len $ error "Impossible: shuffleListM"
fillMutableArrayFromList ma ls

-- Shuffle elements of the mutable array according to the uniformly generated index swap list
let goSwap _ [] = pure ()
goSwap i (j:js) = swapArray ma i j >> goSwap (i - 1) js
goSwap (len - 1) swapIxs

readListFromMutableArray ma
where
len = length ls
{-# INLINE shuffleListM #-}

-- | This is a ~x2-x3 more efficient version of `shuffleListM`. It is more efficient because it does
-- not need to pregenerate a list of indices and instead generates them on demand. Because of this the
-- result that will be produced will differ for the same generator, since the order in which index
-- swaps are generated is reversed.
--
-- Unfortunately, most stateful generator monads can't handle `MonadTrans`, so this version is only
-- used for implementing the pure shuffle.
shuffleListST :: (Monad (t (ST s)), MonadTrans t) => (Word -> t (ST s) Word) -> [a] -> t (ST s) [a]
shuffleListST genWordR ls
| len <= 1 = pure ls
| otherwise = do
ma <- lift $ newMutableArray len $ error "Impossible: shuffleListST"
lift $ fillMutableArrayFromList ma ls

-- Shuffle elements of the mutable array according to the uniformly generated index swap
let goSwap i =
when (i > 0) $ do
j <- genWordR $ (fromIntegral :: Int -> Word) i
lift $ swapArray ma i ((fromIntegral :: Word -> Int) j)
goSwap (i - 1)
goSwap (len - 1)

lift $ readListFromMutableArray ma
where
len = length ls
{-# INLINE shuffleListST #-}
44 changes: 20 additions & 24 deletions src/System/Random/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ module System.Random.Internal
, Uniform(..)
, uniformViaFiniteM
, UniformRange(..)
, uniformWordR
, uniformDouble01M
, uniformDoublePositive01M
, uniformFloat01M
Expand All @@ -65,7 +66,6 @@ module System.Random.Internal
, uniformEnumRM
, uniformListM
, uniformListRM
, shuffleListM
, isInRangeOrd
, isInRangeEnum

Expand Down Expand Up @@ -108,7 +108,6 @@ import Data.ByteString (ByteString)
import Data.ByteString.Short.Internal (ShortByteString(SBS))
import Data.IORef (IORef, newIORef)
import Data.Int
import Data.List (sortOn)
import Data.Word
import Foreign.C.Types
import Foreign.Storable (Storable)
Expand Down Expand Up @@ -221,7 +220,6 @@ class RandomGen g where
-- /Note/ - This function will be removed from the type class in the next major release as
-- it is no longer needed because of `unsafeUniformFillMutableByteArray`.
--
--
-- @since 1.2.0
genShortByteString :: Int -> g -> (ShortByteString, g)
genShortByteString n g =
Expand Down Expand Up @@ -273,10 +271,10 @@ class RandomGen g where
{-# DEPRECATED split "In favor of `splitGen`" #-}

-- | Pseudo-random generators that can be split into two separate and independent
-- psuedo-random generators can have an instance for this type class.
-- psuedo-random generators should provide an instance for this type class.
--
-- Historically this functionality was included in the `RandomGen` type class in the
-- `split` function, however, few pseudo-random generators posses this property of
-- `split` function, however, few pseudo-random generators possess this property of
-- splittability. This lead the old `split` function being usually implemented in terms of
-- `error`.
--
Expand Down Expand Up @@ -784,25 +782,6 @@ uniformListRM :: (StatefulGen g m, UniformRange a) => Int -> (a, a) -> g -> m [a
uniformListRM n range gen = replicateM n (uniformRM range gen)
{-# INLINE uniformListRM #-}

-- | Shuffle elements of a list in a random order.
--
-- ====__Examples__
--
-- >>> import System.Random.Stateful
-- >>> let pureGen = mkStdGen 2023
-- >>> g <- newIOGenM pureGen
-- >>> shuffleListM ['a'..'z'] g :: IO String
-- "renlhfqmgptwksdiyavbxojzcu"
--
-- @since 1.3.0
shuffleListM :: StatefulGen g m => [a] -> g -> m [a]
shuffleListM xs gen = do
is <- uniformListM n gen
pure $ map snd $ sortOn fst $ zip (is :: [Int]) xs
where
!n = length xs
{-# INLINE shuffleListM #-}

-- | The standard pseudo-random number generator.
newtype StdGen = StdGen { unStdGen :: SM.SMGen }
deriving (Show, RandomGen, SplitGen, NFData)
Expand Down Expand Up @@ -1128,6 +1107,23 @@ instance UniformRange Word where
{-# INLINE uniformRM #-}
isInRange = isInRangeOrd

-- | Architecture specific `Word` generation in the specified lower range
--
-- @since 1.3.0
uniformWordR ::
StatefulGen g m
=> Word
-- ^ Maximum value to generate
-> g
-- ^ Stateful generator
-> m Word
uniformWordR r
| wordSizeInBits == 64 =
fmap (fromIntegral :: Word64 -> Word) . uniformWord64R ((fromIntegral :: Word -> Word64) r)
| otherwise =
fmap (fromIntegral :: Word32 -> Word) . uniformWord32R ((fromIntegral :: Word -> Word32) r)
{-# INLINE uniformWordR #-}

instance Uniform Word8 where
uniformM = uniformWord8
{-# INLINE uniformM #-}
Expand Down
Loading

0 comments on commit a79f427

Please sign in to comment.