diff --git a/CHANGELOG.md b/CHANGELOG.md index bf876a2f..f6b35d99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,9 @@ # 1.3.0 -* Add `Uniform` instance for `Maybe` and `Either` -* Add `SplitGen` and `splitGen` +* Add `Uniform` instance for `Maybe` and `Either`: [#167](https://github.com/haskell/random/pull/167) +* Add `Seed`, `SeedGen`, `seedSize`, `mkSeed` and `unSeed`: + [#162](https://github.com/haskell/random/pull/162) +* Add `SplitGen` and `splitGen`: [#160](https://github.com/haskell/random/pull/160) * Add `shuffleList` and `shuffleListM`: [#140](https://github.com/haskell/random/pull/140) * Add `mkStdGen64`: [#155](https://github.com/haskell/random/pull/155) * Add `uniformListRM`, `uniformList`, `uniformListR`, `uniforms` and `uniformRs`: diff --git a/bench-legacy/SimpleRNGBench.hs b/bench-legacy/SimpleRNGBench.hs index b941a1b8..dfffeb86 100644 --- a/bench-legacy/SimpleRNGBench.hs +++ b/bench-legacy/SimpleRNGBench.hs @@ -1,8 +1,9 @@ -{-# LANGUAGE BangPatterns, ScopedTypeVariables, ForeignFunctionInterface #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fwarn-unused-imports #-} -- | A simple script to do some very basic timing of the RNGs. - module Main where import System.Exit (exitSuccess, exitFailure) diff --git a/random.cabal b/random.cabal index 8e238dc3..b8246409 100644 --- a/random.cabal +++ b/random.cabal @@ -85,6 +85,7 @@ library exposed-modules: System.Random System.Random.Internal + System.Random.Seed System.Random.Stateful other-modules: System.Random.GFinite @@ -131,6 +132,7 @@ test-suite spec other-modules: Spec.Range Spec.Run + Spec.Seed Spec.Stateful default-language: Haskell2010 diff --git a/src/System/Random.hs b/src/System/Random.hs index 1e2e497c..a0bce1b3 100644 --- a/src/System/Random.hs +++ b/src/System/Random.hs @@ -37,6 +37,8 @@ module System.Random , Uniform , UniformRange , Finite + -- ** Seed + , module System.Random.Seed -- * Generators for sequences of pseudo-random bytes -- ** Lists , uniforms @@ -94,6 +96,7 @@ import Foreign.C.Types import GHC.Exts import System.Random.GFinite (Finite) import System.Random.Internal +import System.Random.Seed import qualified System.Random.SplitMix as SM -- $introduction diff --git a/src/System/Random/GFinite.hs b/src/System/Random/GFinite.hs index 6b2b1e1e..d179a0ba 100644 --- a/src/System/Random/GFinite.hs +++ b/src/System/Random/GFinite.hs @@ -1,10 +1,3 @@ --- | --- Module : System.Random.GFinite --- Copyright : (c) Andrew Lelechenko 2020 --- License : BSD-style (see the file LICENSE in the 'random' repository) --- Maintainer : libraries@haskell.org --- - {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} @@ -12,6 +5,12 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} +-- | +-- Module : System.Random.GFinite +-- Copyright : (c) Andrew Lelechenko 2020 +-- License : BSD-style (see the file LICENSE in the 'random' repository) +-- Maintainer : libraries@haskell.org +-- module System.Random.GFinite ( Cardinality(..) , Finite(..) diff --git a/src/System/Random/Internal.hs b/src/System/Random/Internal.hs index f6bba9e2..0a102de7 100644 --- a/src/System/Random/Internal.hs +++ b/src/System/Random/Internal.hs @@ -3,18 +3,18 @@ {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE GHCForeignImportPrim #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE Trustworthy #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnliftedFFITypes #-} -{-# LANGUAGE TypeFamilyDependencies #-} {-# OPTIONS_HADDOCK hide, not-home #-} -- | @@ -29,6 +29,8 @@ module System.Random.Internal (-- * Pure and monadic pseudo-random number generator interfaces RandomGen(..) , SplitGen(..) + , Seed(..) + -- * Stateful , StatefulGen(..) , FrozenGen(..) , ThawedGen(..) @@ -77,12 +79,20 @@ module System.Random.Internal , genByteArrayST , genShortByteStringIO , genShortByteStringST + , defaultUnsafeFillMutableByteArrayT , defaultUnsafeUniformFillMutableByteArray -- ** Helpers for dealing with MutableByteArray , newMutableByteArray , newPinnedMutableByteArray , freezeMutableByteArray , writeWord8 + , writeWord64LE + , indexWord8 + , indexWord64LE + , indexByteSliceWord64LE + , sizeOfByteArray + , shortByteStringToByteArray + , byteArrayToShortByteString ) where import Control.Arrow @@ -95,7 +105,8 @@ import Control.Monad.State.Strict (MonadState(..), State, StateT(..), execStateT import Control.Monad.Trans (lift, MonadTrans) import Data.Array.Byte (ByteArray(..), MutableByteArray(..)) import Data.Bits -import Data.ByteString.Short.Internal (ShortByteString(SBS), fromShort) +import Data.ByteString.Short.Internal (ShortByteString(SBS)) +import qualified Data.ByteString.Short.Internal as SBS (fromShort) import Data.IORef (IORef, newIORef) import Data.Int import Data.List (sortOn) @@ -123,6 +134,19 @@ import Data.ByteString (ByteString) -- Needed for WORDS_BIGENDIAN #include "MachDeps.h" +-- | This is a binary form of pseudo-random number generator's state. It is designed to be +-- safe and easy to use for input/output operations like restoring from file, transmitting +-- over the network, etc. +-- +-- Constructor is not exported, becasue it is important for implementation to enforce the +-- invariant of the underlying byte array being of the exact same length as the generator has +-- specified in `System.Random.Seed.SeedSize`. Use `System.Random.Seed.mkSize` and +-- `System.Random.Seed.unSize` to get access to the raw bytes in a safe manner. +-- +-- @since 1.3.0 +newtype Seed g = Seed ByteArray + deriving (Eq, Ord, Show) + -- | 'RandomGen' is an interface to pure pseudo-random number generators. -- @@ -280,7 +304,7 @@ class RandomGen g => SplitGen g where -- -- @since 1.2.0 class Monad m => StatefulGen g m where - {-# MINIMAL (uniformWord32|uniformWord64) #-} + {-# MINIMAL uniformWord32|uniformWord64 #-} -- | @uniformWord32R upperBound g@ generates a 'Word32' that is uniformly -- distributed over the range @[0, upperBound]@. -- @@ -492,7 +516,7 @@ genByteArrayST isPinned n0 action = do mba <- if isPinned then newPinnedMutableByteArray n else newMutableByteArray n - runIdentityT $ defaultUnsafeUniformFillMutableByteArrayT mba 0 n (lift action) + runIdentityT $ defaultUnsafeFillMutableByteArrayT mba 0 n (lift action) freezeMutableByteArray mba {-# INLINE genByteArrayST #-} @@ -520,14 +544,14 @@ uniformFillMutableByteArray mba i0 n g = do unsafeUniformFillMutableByteArray mba offset numBytes g {-# INLINE uniformFillMutableByteArray #-} -defaultUnsafeUniformFillMutableByteArrayT :: +defaultUnsafeFillMutableByteArrayT :: (Monad (t (ST s)), MonadTrans t) => MutableByteArray s -> Int -> Int -> t (ST s) Word64 -> t (ST s) () -defaultUnsafeUniformFillMutableByteArrayT mba offset n gen64 = do +defaultUnsafeFillMutableByteArrayT mba offset n gen64 = do let !n64 = n `quot` 8 !endIx64 = offset + n64 * 8 !nrem = n `rem` 8 @@ -547,14 +571,14 @@ defaultUnsafeUniformFillMutableByteArrayT mba offset n gen64 = do -- still need using smaller generators (eg. uniformWord8), but that would -- result in inconsistent tail when total length is slightly varied. lift $ writeByteSliceWord64LE mba (endIx - nrem) endIx w64 -{-# INLINEABLE defaultUnsafeUniformFillMutableByteArrayT #-} -{-# SPECIALIZE defaultUnsafeUniformFillMutableByteArrayT +{-# INLINEABLE defaultUnsafeFillMutableByteArrayT #-} +{-# SPECIALIZE defaultUnsafeFillMutableByteArrayT :: MutableByteArray s -> Int -> Int -> IdentityT (ST s) Word64 -> IdentityT (ST s) () #-} -{-# SPECIALIZE defaultUnsafeUniformFillMutableByteArrayT +{-# SPECIALIZE defaultUnsafeFillMutableByteArrayT :: MutableByteArray s -> Int -> Int @@ -574,7 +598,7 @@ defaultUnsafeUniformFillMutableByteArray :: -> ST s g defaultUnsafeUniformFillMutableByteArray mba i0 n g = flip execStateT g - $ defaultUnsafeUniformFillMutableByteArrayT mba i0 n (state genWord64) + $ defaultUnsafeFillMutableByteArrayT mba i0 n (state genWord64) {-# INLINE defaultUnsafeUniformFillMutableByteArray #-} @@ -590,6 +614,9 @@ uniformByteString n g = -- Architecture independent helpers: +sizeOfByteArray :: ByteArray -> Int +sizeOfByteArray (ByteArray ba#) = I# (sizeofByteArray# ba#) + st_ :: (State# s -> State# s) -> ST s () st_ m# = ST $ \s# -> (# m# s#, () #) {-# INLINE st_ #-} @@ -631,12 +658,54 @@ writeByteSliceWord64LE mba fromByteIx toByteIx = go fromByteIx go (i + 1) (z `shiftR` 8) {-# INLINE writeByteSliceWord64LE #-} +indexWord8 :: + ByteArray + -> Int -- ^ Offset into immutable byte array in number of bytes + -> Word8 +indexWord8 (ByteArray ba#) (I# i#) = + W8# (indexWord8Array# ba# i#) +{-# INLINE indexWord8 #-} + +indexWord64LE :: + ByteArray + -> Int -- ^ Offset into immutable byte array in number of bytes + -> Word64 +#if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806) +indexWord64LE ba i = indexByteSliceWord64LE ba i (i + 8) +#else +indexWord64LE (ByteArray ba#) (I# i#) + | wordSizeInBits == 64 = W64# (indexWord8ArrayAsWord64# ba# i#) + | otherwise = + let !w32l = W32# (indexWord8ArrayAsWord32# ba# i#) + !w32u = W32# (indexWord8ArrayAsWord32# ba# (i# +# 4#)) + in (fromIntegral w32u `shiftL` 32) .|. fromIntegral w32l +#endif +{-# INLINE indexWord64LE #-} + +indexByteSliceWord64LE :: + ByteArray + -> Int -- ^ Starting offset in number of bytes + -> Int -- ^ Ending offset in number of bytes + -> Word64 +indexByteSliceWord64LE ba fromByteIx toByteIx = goWord8 fromByteIx 0 + where + r = (toByteIx - fromByteIx) `rem` 8 + nPadBits = if r == 0 then 0 else 8 * (8 - r) + goWord8 i !w64 + | i < toByteIx = goWord8 (i + 1) (shiftL w64 8 .|. fromIntegral (indexWord8 ba i)) + | otherwise = byteSwap64 (shiftL w64 nPadBits) +{-# INLINE indexByteSliceWord64LE #-} + -- On big endian machines we need to write one byte at a time for consistency with little -- endian machines. Also for GHC versions prior to 8.6 we don't have primops that can -- write with byte offset, eg. writeWord8ArrayAsWord64# and writeWord8ArrayAsWord32#, so we -- also must fallback to writing one byte a time. Such fallback results in about 3 times -- slow down, which is not the end of the world. -writeWord64LE :: MutableByteArray s -> Int -> Word64 -> ST s () +writeWord64LE :: + MutableByteArray s + -> Int -- ^ Offset into mutable byte array in number of bytes + -> Word64 -- ^ 8 bytes that will be written into the supplied array + -> ST s () #if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806) writeWord64LE mba i w64 = writeByteSliceWord64LE mba i (i + 8) w64 @@ -662,6 +731,10 @@ getSizeOfMutableByteArray (MutableByteArray mba#) = #endif {-# INLINE getSizeOfMutableByteArray #-} +shortByteStringToByteArray :: ShortByteString -> ByteArray +shortByteStringToByteArray (SBS ba#) = ByteArray ba# +{-# INLINE shortByteStringToByteArray #-} + byteArrayToShortByteString :: ByteArray -> ShortByteString byteArrayToShortByteString (ByteArray ba#) = SBS ba# {-# INLINE byteArrayToShortByteString #-} @@ -671,12 +744,12 @@ byteArrayToShortByteString (ByteArray ba#) = SBS ba# shortByteStringToByteString :: ShortByteString -> ByteString shortByteStringToByteString ba = #if __GLASGOW_HASKELL__ < 802 - fromShort ba + SBS.fromShort ba #else let !(SBS ba#) = ba in if isTrue# (isByteArrayPinned# ba#) then pinnedByteArrayToByteString ba# - else fromShort ba + else SBS.fromShort ba {-# INLINE shortByteStringToByteString #-} pinnedByteArrayToByteString :: ByteArray# -> ByteString diff --git a/src/System/Random/Seed.hs b/src/System/Random/Seed.hs new file mode 100644 index 00000000..8d6c5db1 --- /dev/null +++ b/src/System/Random/Seed.hs @@ -0,0 +1,325 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE Trustworthy #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} +{-# OPTIONS_GHC -Wno-orphans #-} +-- | +-- Module : System.Random.Seed +-- Copyright : (c) Alexey Kuleshevich 2024 +-- License : BSD-style (see the file LICENSE in the 'random' repository) +-- Maintainer : libraries@haskell.org +-- + +module System.Random.Seed + ( SeedGen(..) + , -- ** Seed + Seed + , seedSize + , mkSeed + , unSeed + , mkSeedFromByteString + , unSeedToByteString + , withSeed + , withSeedM + , withSeedFile + , seedGenTypeName + , nonEmptyToSeed + , nonEmptyFromSeed + ) where + +import Control.Monad (unless) +import qualified Control.Monad.Fail as F +import Control.Monad.IO.Class +import Control.Monad.ST +import Control.Monad.State.Strict (get, put, runStateT) +import Data.Array.Byte (ByteArray(..)) +import Data.Bits +import qualified Data.ByteString as BS +import qualified Data.ByteString.Short.Internal as SBS (fromShort, toShort) +import Data.Coerce +import Data.Functor.Identity (runIdentity) +import Data.List.NonEmpty as NE (NonEmpty(..), nonEmpty, toList) +import Data.Typeable +import Data.Word +import GHC.TypeLits (Nat, KnownNat, natVal, type (<=)) +import System.Random.Internal +import qualified System.Random.SplitMix as SM +import qualified System.Random.SplitMix32 as SM32 + + +-- | Interface for converting a pure pseudo-random number generator to and from non-empty +-- sequence of bytes. Seeds are stored in Little-Endian order regardless of the platform +-- it is being used on, which provides cross-platform compatibility, while providing +-- optimal performance for the most common platform type. +-- +-- Conversion to and from a `Seed` serves as a building block for implementing +-- serialization for any pure or frozen pseudo-random number generator. +-- +-- It is not trivial to implement platform independence. For this reason this type class +-- has two alternative ways of creating an instance for this class. The easiest way for +-- constructing a platform indepent seed is by converting the inner state of a generator +-- to and from a list of 64 bit words using `unseedGen64` and `seedGen64` respectively. In +-- that case cross-platform support will be handled automaticaly. +-- +-- >>> :set -XDataKinds -XTypeFamilies +-- >>> import Data.Word (Word8, Word32) +-- >>> import Data.Bits ((.|.), shiftR, shiftL) +-- >>> import Data.List.NonEmpty (NonEmpty ((:|))) +-- >>> data FiveByteGen = FiveByteGen Word8 Word32 deriving Show +-- >>> :{ +-- instance SeedGen FiveByteGen where +-- type SeedSize FiveByteGen = 5 +-- seedGen64 (w64 :| _) = +-- FiveByteGen (fromIntegral (w64 `shiftR` 32)) (fromIntegral w64) +-- unseedGen64 (FiveByteGen x1 x4) = +-- let w64 = (fromIntegral x1 `shiftL` 32) .|. fromIntegral x4 +-- in (w64 :| []) +-- :} +-- +-- >>> FiveByteGen 0x80 0x01020304 +-- FiveByteGen 128 16909060 +-- >>> seedGen (unseedGen (FiveByteGen 0x80 0x01020304)) +-- FiveByteGen 128 16909060 +-- >>> unseedGen (FiveByteGen 0x80 0x01020304) +-- Seed [0x04, 0x03, 0x02, 0x01, 0x80] +-- >>> unseedGen64 (FiveByteGen 0x80 0x01020304) +-- 549772722948 :| [] +-- +-- However, when performance is of utmost importance or default handling of cross platform +-- independence is not sufficient, then an adventurous developer can try implementing +-- conversion into bytes directly with `unseedGen` and `seedGen`. +-- +-- Properties that must hold: +-- +-- @ +-- > seedGen (unseedGen gen) == gen +-- @ +-- +-- @ +-- > seedGen64 (unseedGen64 gen) == gen +-- @ +-- +-- Note, that there is no requirement for every `Seed` to roundtrip, eg. this proprty does +-- not even hold for `StdGen`: +-- +-- >>> let seed = nonEmptyToSeed (0xab :| [0xff00]) :: Seed StdGen +-- >>> seed == unseedGen (seedGen seed) +-- False +-- +-- @since 1.3.0 +class (KnownNat (SeedSize g), 1 <= SeedSize g, Typeable g) => SeedGen g where + -- | Number of bytes that is required for storing the full state of a pseudo-random + -- number generator. It should be big enough to satisfy the roundtrip property: + -- + -- @ + -- > seedGen (unseedGen gen) == gen + -- @ + -- + type SeedSize g :: Nat + {-# MINIMAL (seedGen, unseedGen)|(seedGen64, unseedGen64) #-} + + -- | Convert from a binary representation to a pseudo-random number generator + -- + -- @since 1.3.0 + seedGen :: Seed g -> g + seedGen = seedGen64 . nonEmptyFromSeed + + -- | Convert to a binary representation of a pseudo-random number generator + -- + -- @since 1.3.0 + unseedGen :: g -> Seed g + unseedGen = nonEmptyToSeed . unseedGen64 + + -- | Construct pseudo-random number generator from a list of words. Whenever list does + -- not have enough bytes to satisfy the `SeedSize` requirement, it will be padded with + -- zeros. On the other hand when it has more than necessary, extra bytes will be dropped. + -- + -- For example if `SeedSize` is set to 2, then only the lower 16 bits of the first + -- element in the list will be used. + -- + -- @since 1.3.0 + seedGen64 :: NonEmpty Word64 -> g + seedGen64 = seedGen . nonEmptyToSeed + + -- | Convert pseudo-random number generator to a list of words + -- + -- In case when `SeedSize` is not a multiple of 8, then the upper bits of the last word + -- in the list will be set to zero. + -- + -- @since 1.3.0 + unseedGen64 :: g -> NonEmpty Word64 + unseedGen64 = nonEmptyFromSeed . unseedGen + +instance SeedGen StdGen where + type SeedSize StdGen = SeedSize SM.SMGen + seedGen = coerce (seedGen :: Seed SM.SMGen -> SM.SMGen) + unseedGen = coerce (unseedGen :: SM.SMGen -> Seed SM.SMGen) + +instance SeedGen g => SeedGen (StateGen g) where + type SeedSize (StateGen g) = SeedSize g + seedGen = coerce (seedGen :: Seed g -> g) + unseedGen = coerce (unseedGen :: g -> Seed g) + +instance SeedGen SM.SMGen where + type SeedSize SM.SMGen = 16 + seedGen (Seed ba) = + SM.seedSMGen (indexWord64LE ba 0) (indexWord64LE ba 8) + unseedGen g = + case SM.unseedSMGen g of + (seed, gamma) -> Seed $ runST $ do + mba <- newMutableByteArray 16 + writeWord64LE mba 0 seed + writeWord64LE mba 8 gamma + freezeMutableByteArray mba + +instance SeedGen SM32.SMGen where + type SeedSize SM32.SMGen = 8 + seedGen (Seed ba) = + let x = indexWord64LE ba 0 + seed, gamma :: Word32 + seed = fromIntegral (shiftR x 32) + gamma = fromIntegral x + in SM32.seedSMGen seed gamma + unseedGen g = + let seed, gamma :: Word32 + (seed, gamma) = SM32.unseedSMGen g + in Seed $ runST $ do + mba <- newMutableByteArray 8 + let w64 :: Word64 + w64 = shiftL (fromIntegral seed) 32 .|. fromIntegral gamma + writeWord64LE mba 0 w64 + freezeMutableByteArray mba + +instance SeedGen g => Uniform (Seed g) where + uniformM = fmap Seed . uniformByteArrayM False (seedSize @g) + +-- | Get the expected size of the `Seed` in number bytes +-- +-- @since 1.3.0 +seedSize :: forall g. SeedGen g => Int +seedSize = fromIntegral $ natVal (Proxy :: Proxy (SeedSize g)) + +-- | Construct a `Seed` from a `ByteArray` of expected length. Whenever `ByteArray` does +-- not match the `SeedSize` specified by the pseudo-random generator, this function will +-- `F.fail`. +-- +-- @since 1.3.0 +mkSeed :: forall g m. (SeedGen g, F.MonadFail m) => ByteArray -> m (Seed g) +mkSeed ba = do + unless (sizeOfByteArray ba == seedSize @g) $ do + F.fail $ "Unexpected number of bytes: " + ++ show (sizeOfByteArray ba) + ++ ". Exactly " + ++ show (seedSize @g) + ++ " bytes is required by the " + ++ show (seedGenTypeName @g) + pure $ Seed ba + +-- | Helper function that allows for operating directly on the `Seed`, while supplying a +-- function that uses the pseudo-random number generator that is constructed from that +-- `Seed`. +-- +-- ====__Example__ +-- +-- >>> :set -XTypeApplications +-- >>> import System.Random +-- >>> withSeed (nonEmptyToSeed (pure 2024) :: Seed StdGen) (uniform @Int) +-- (1039666877624726199,Seed [0xe9, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) +-- +-- @since 1.3.0 +withSeed :: SeedGen g => Seed g -> (g -> (a, g)) -> (a, Seed g) +withSeed seed f = runIdentity (withSeedM seed (pure . f)) + +-- | Same as `withSeed`, except it is useful with monadic computation and frozen generators. +-- +-- See `System.Random.Stateful.withMutableSeedGen` for a helper that also handles seeds +-- for mutable pseduo-random number generators. +-- +-- @since 1.3.0 +withSeedM :: (SeedGen g, Functor f) => Seed g -> (g -> f (a, g)) -> f (a, Seed g) +withSeedM seed f = fmap unseedGen <$> f (seedGen seed) + +-- | This is a function that shows the name of the generator type, which is useful for +-- error reporting. +-- +-- @since 1.3.0 +seedGenTypeName :: forall g. SeedGen g => String +seedGenTypeName = show (typeOf (Proxy @g)) + + +-- | Just like `mkSeed`, but uses `ByteString` as argument. Results in a memcopy of the seed. +-- +-- @since 1.3.0 +mkSeedFromByteString :: (SeedGen g, F.MonadFail m) => BS.ByteString -> m (Seed g) +mkSeedFromByteString = mkSeed . shortByteStringToByteArray . SBS.toShort + +-- | Unwrap the `Seed` and get the underlying `ByteArray` +-- +-- @since 1.3.0 +unSeed :: Seed g -> ByteArray +unSeed (Seed ba) = ba + +-- | Just like `unSeed`, but produced a `ByteString`. Results in a memcopy of the seed. +-- +-- @since 1.3.0 +unSeedToByteString :: Seed g -> BS.ByteString +unSeedToByteString = SBS.fromShort . byteArrayToShortByteString . unSeed + + +-- | Read the seed from a file and use it for constructing a pseudo-random number +-- generator. After supplied action has been applied to the constructed generator, the +-- resulting generator will be converted back to a seed and written to the same file. +-- +-- @since 1.3.0 +withSeedFile :: (SeedGen g, MonadIO m) => FilePath -> (g -> m (a, g)) -> m a +withSeedFile fileName f = do + bs <- liftIO $ BS.readFile fileName + seed <- liftIO $ mkSeedFromByteString bs + (res, seed') <- withSeedM seed f + liftIO $ BS.writeFile fileName $ unSeedToByteString seed' + pure res + +-- | Construct a seed from a list of 64-bit words. At most `SeedSize` many bytes will be used. +-- +-- @since 1.3.0 +nonEmptyToSeed :: forall g. SeedGen g => NonEmpty Word64 -> Seed g +nonEmptyToSeed xs = Seed $ runST $ do + let n = seedSize @g + mba <- newMutableByteArray n + _ <- flip runStateT (NE.toList xs) $ do + defaultUnsafeFillMutableByteArrayT mba 0 n $ do + get >>= \case + [] -> pure 0 + w:ws -> w <$ put ws + freezeMutableByteArray mba + +-- | Convert a `Seed` to a list of 64bit words. +-- +-- @since 1.3.0 +nonEmptyFromSeed :: forall g. SeedGen g => Seed g -> NonEmpty Word64 +nonEmptyFromSeed (Seed ba) = + case nonEmpty $ reverse $ goWord64 0 [] of + Just ne -> ne + Nothing -> -- Seed is at least 1 byte in size, so it can't be empty + error $ "Impossible: Seed for " + ++ seedGenTypeName @g + ++ " must be at least: " + ++ show (seedSize @g) + ++ " bytes, but got " + ++ show n + where + n = sizeOfByteArray ba + n8 = 8 * (n `quot` 8) + goWord64 i !acc + | i < n8 = goWord64 (i + 8) (indexWord64LE ba i : acc) + | i == n = acc + | otherwise = indexByteSliceWord64LE ba i n : acc diff --git a/src/System/Random/Stateful.hs b/src/System/Random/Stateful.hs index 532ef7ea..a1a3f190 100644 --- a/src/System/Random/Stateful.hs +++ b/src/System/Random/Stateful.hs @@ -7,7 +7,6 @@ {-# LANGUAGE Trustworthy #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} - -- | -- Module : System.Random.Stateful -- Copyright : (c) The University of Glasgow 2001 @@ -41,6 +40,8 @@ module System.Random.Stateful , ThawedGen(..) , withMutableGen , withMutableGen_ + , withMutableSeedGen + , withMutableSeedGen_ , randomM , randomRM , splitGenM @@ -136,6 +137,7 @@ import Control.Monad.IO.Class import Control.Monad.ST import GHC.Conc.Sync (STM, TVar, newTVar, newTVarIO, readTVar, writeTVar) import Control.Monad.State.Strict (MonadState, state) +import Data.Coerce import Data.IORef import Data.STRef import Foreign.Storable @@ -294,6 +296,21 @@ withMutableGen_ :: ThawedGen f m => f -> (MutableGen f m -> m a) -> m a withMutableGen_ fg action = thawGen fg >>= action +-- | Just like `withMutableGen`, except uses a `Seed` instead of a frozen generator. +-- +-- @since 1.3.0 +withMutableSeedGen :: (SeedGen g, ThawedGen g m) => Seed g -> (MutableGen g m -> m a) -> m (a, Seed g) +withMutableSeedGen seed f = withSeedM seed (`withMutableGen` f) + +-- | Just like `withMutableSeedGen`, except it doesn't return the final generator, only +-- the resulting value. This is slightly more efficient, since it doesn't incur overhead +-- from freezeing the mutable generator +-- +-- @since 1.3.0 +withMutableSeedGen_ :: (SeedGen g, ThawedGen g m) => Seed g -> (MutableGen g m -> m a) -> m a +withMutableSeedGen_ seed = withMutableGen_ (seedGen seed) + + -- | Generates a pseudo-random value using monadic interface and `Random` instance. -- -- ====__Examples__ @@ -353,6 +370,12 @@ newtype AtomicGenM g = AtomicGenM { unAtomicGenM :: IORef g} newtype AtomicGen g = AtomicGen { unAtomicGen :: g} deriving (Eq, Ord, Show, RandomGen, SplitGen, Storable, NFData) +-- Standalone definition due to GHC-8.0 not supporting deriving with associated type families +instance SeedGen g => SeedGen (AtomicGen g) where + type SeedSize (AtomicGen g) = SeedSize g + seedGen = coerce (seedGen :: Seed g -> g) + unseedGen = coerce (unseedGen :: g -> Seed g) + -- | Creates a new 'AtomicGenM'. -- -- @since 1.2.0 @@ -444,6 +467,11 @@ newtype IOGenM g = IOGenM { unIOGenM :: IORef g } newtype IOGen g = IOGen { unIOGen :: g } deriving (Eq, Ord, Show, RandomGen, SplitGen, Storable, NFData) +-- Standalone definition due to GHC-8.0 not supporting deriving with associated type families +instance SeedGen g => SeedGen (IOGen g) where + type SeedSize (IOGen g) = SeedSize g + seedGen = coerce (seedGen :: Seed g -> g) + unseedGen = coerce (unseedGen :: g -> Seed g) -- | Creates a new 'IOGenM'. -- @@ -515,6 +543,12 @@ newtype STGenM g s = STGenM { unSTGenM :: STRef s g } newtype STGen g = STGen { unSTGen :: g } deriving (Eq, Ord, Show, RandomGen, SplitGen, Storable, NFData) +-- Standalone definition due to GHC-8.0 not supporting deriving with associated type families +instance SeedGen g => SeedGen (STGen g) where + type SeedSize (STGen g) = SeedSize g + seedGen = coerce (seedGen :: Seed g -> g) + unseedGen = coerce (unseedGen :: g -> Seed g) + -- | Creates a new 'STGenM'. -- -- @since 1.2.0 @@ -610,6 +644,12 @@ newtype TGenM g = TGenM { unTGenM :: TVar g } newtype TGen g = TGen { unTGen :: g } deriving (Eq, Ord, Show, RandomGen, SplitGen, Storable, NFData) +-- Standalone definition due to GHC-8.0 not supporting deriving with associated type families +instance SeedGen g => SeedGen (TGen g) where + type SeedSize (TGen g) = SeedSize g + seedGen = coerce (seedGen :: Seed g -> g) + unseedGen = coerce (unseedGen :: g -> Seed g) + -- | Creates a new 'TGenM' in `STM`. -- -- @since 1.2.1 diff --git a/test/Spec.hs b/test/Spec.hs index c6234193..2ab67cf2 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,11 +1,13 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} module Main (main) where import Control.Monad (replicateM, forM_) @@ -14,6 +16,7 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Short as SBS import Data.Int import Data.List (sortOn) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Typeable import Data.Void import Data.Word @@ -33,6 +36,7 @@ import Data.Monoid ((<>)) import qualified Spec.Range as Range import qualified Spec.Run as Run +import qualified Spec.Seed as Seed import qualified Spec.Stateful as Stateful main :: IO () @@ -103,6 +107,7 @@ main = , uniformSpec (Proxy :: Proxy (Int8, Word8, Word16, Word32, Word64, Word)) , uniformSpec (Proxy :: Proxy (Int8, Int16, Word8, Word16, Word32, Word64, Word)) , Stateful.statefulGenSpec + , Seed.spec ] floatTests :: TestTree @@ -297,6 +302,11 @@ instance Monad m => Serial m Foo newtype ConstGen = ConstGen Word64 +instance SeedGen ConstGen where + type SeedSize ConstGen = 8 + seedGen64 (w :| _) = ConstGen w + unseedGen64 (ConstGen w) = pure w + instance RandomGen ConstGen where genWord64 g@(ConstGen c) = (c, g) instance SplitGen ConstGen where diff --git a/test/Spec/Seed.hs b/test/Spec/Seed.hs new file mode 100644 index 00000000..591ed611 --- /dev/null +++ b/test/Spec/Seed.hs @@ -0,0 +1,115 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} +module Spec.Seed where + +import Data.Bits +import Data.List.NonEmpty as NE +import Data.Maybe (fromJust) +import Data.Proxy +import Data.Word +import System.Random +import Test.Tasty +import Test.Tasty.SmallCheck as SC +import qualified Data.ByteString as BS +import GHC.TypeLits +import qualified GHC.Exts as GHC (IsList(..)) +import Test.SmallCheck.Series hiding (NonEmpty(..)) +import Spec.Stateful () + +newtype GenN (n :: Nat) = GenN BS.ByteString + deriving (Eq, Show) + +instance (KnownNat n, Monad m) => Serial m (GenN n) where + series = GenN . fst . uniformByteString n . mkStdGen <$> series + where + n = fromInteger (natVal (Proxy :: Proxy n)) + +instance (KnownNat n, Monad m) => Serial m (Gen64 n) where + series = + Gen64 . dropExtra . fst . uniformList n . mkStdGen <$> series + where + (n, r8) = + case fromInteger (natVal (Proxy :: Proxy n)) `quotRem` 8 of + (q, 0) -> (q, 0) + (q, r) -> (q + 1, (8 - r) * 8) + -- We need to drop extra top most bits in the last generated Word64 in order for + -- roundtrip to work, because that is exactly what SeedGen will do + dropExtra xs = + case NE.reverse (fromJust (NE.nonEmpty xs)) of + w64 :| rest -> NE.reverse ((w64 `shiftL` r8) `shiftR` r8 :| rest) + +instance (1 <= n, KnownNat n) => SeedGen (GenN n) where + type SeedSize (GenN n) = n + unseedGen (GenN bs) = fromJust . mkSeed . GHC.fromList $ BS.unpack bs + seedGen = GenN . BS.pack . GHC.toList . unSeed + +newtype Gen64 (n :: Nat) = Gen64 (NonEmpty Word64) + deriving (Eq, Show) + +instance (1 <= n, KnownNat n) => SeedGen (Gen64 n) where + type SeedSize (Gen64 n) = n + unseedGen64 (Gen64 ws) = ws + seedGen64 = Gen64 + +seedGenSpec :: + forall g. (SeedGen g, Eq g, Show g, Serial IO g) + => TestTree +seedGenSpec = + testGroup (seedGenTypeName @g) + [ testProperty "seedGen/unseedGen" $ + forAll $ \(g :: g) -> g == seedGen (unseedGen g) + , testProperty "seedGen64/unseedGen64" $ + forAll $ \(g :: g) -> g == seedGen64 (unseedGen64 g) + ] + + +spec :: TestTree +spec = + testGroup + "SeedGen" + [ seedGenSpec @StdGen + , seedGenSpec @(GenN 1) + , seedGenSpec @(GenN 2) + , seedGenSpec @(GenN 3) + , seedGenSpec @(GenN 4) + , seedGenSpec @(GenN 5) + , seedGenSpec @(GenN 6) + , seedGenSpec @(GenN 7) + , seedGenSpec @(GenN 8) + , seedGenSpec @(GenN 9) + , seedGenSpec @(GenN 10) + , seedGenSpec @(GenN 11) + , seedGenSpec @(GenN 12) + , seedGenSpec @(GenN 13) + , seedGenSpec @(GenN 14) + , seedGenSpec @(GenN 15) + , seedGenSpec @(GenN 16) + , seedGenSpec @(GenN 17) + , seedGenSpec @(Gen64 1) + , seedGenSpec @(Gen64 2) + , seedGenSpec @(Gen64 3) + , seedGenSpec @(Gen64 4) + , seedGenSpec @(Gen64 5) + , seedGenSpec @(Gen64 6) + , seedGenSpec @(Gen64 7) + , seedGenSpec @(Gen64 8) + , seedGenSpec @(Gen64 9) + , seedGenSpec @(Gen64 10) + , seedGenSpec @(Gen64 11) + , seedGenSpec @(Gen64 12) + , seedGenSpec @(Gen64 13) + , seedGenSpec @(Gen64 14) + , seedGenSpec @(Gen64 15) + , seedGenSpec @(Gen64 16) + , seedGenSpec @(Gen64 17) + ] + diff --git a/test/Spec/Stateful.hs b/test/Spec/Stateful.hs index d0a64e4d..e575f117 100644 --- a/test/Spec/Stateful.hs +++ b/test/Spec/Stateful.hs @@ -174,7 +174,7 @@ frozenGenSpecFor fromStdGen toStdGen runStatefulGen = toStdGen runStatefulGen , testProperty "uniformByteArrayM/genByteArray" $ - forAll $ \(NonNegative n', isPinned1, isPinned2) -> + forAll $ \(NonNegative n', isPinned1 :: Bool, isPinned2 :: Bool) -> let n = n' `mod` 100000 -- Ensure it is not too big in matchRandomGenSpec (uniformByteArrayM isPinned1 n) @@ -220,4 +220,3 @@ statefulGenSpec = pure (res, g') ] ] -