Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev ppmmh #244

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions rhine-bayes/app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,15 @@ posteriorTemperatureProcess = proc sensor -> do
arrM score -< sensorLikelihood latent sensor
returnA -< (temperature, latent)

{- | Given sensor data and temperature, sample a latent position, and weight them according to the likelihood of the observed sensor position.
Used to infer position when temperature is given.
-}
posterior :: (MonadMeasure m, Diff td ~ Double) => BehaviourF m td (Sensor, Temperature) Pos
posterior = proc (sensor, temperature) -> do
latent <- prior -< temperature
arrM score -< sensorLikelihood latent sensor
returnA -< latent

-- | A collection of all displayable inference results
data Result = Result
{ temperature :: Temperature
Expand Down Expand Up @@ -269,6 +278,7 @@ mains =
[ ("single rate", mainSingleRate)
, ("single rate, parameter collapse", mainSingleRateCollapse)
, ("multi rate, temperature process", mainMultiRate)
, ("multi rate, PPMMH", mainMultiRatePPMMH)
]

main :: IO ()
Expand Down Expand Up @@ -393,6 +403,12 @@ userTemperature = tagS >>> arr (selector >>> fmap Product) >>> mappendS >>> arr
selector (EventKey (SpecialKey KeyDown) Down _ _) = Just (1 / 1.2)
selector _ = Nothing

-- | Visualize the current 'Result' at a rate controlled by the @gloss@ backend, usually 30 FPS.
visualisationRhine :: Rhine (GlossConcT IO) (GlossClockUTC GlossSimClockIO) Result ()
visualisationRhine = hoistClSF sampleIOGloss visualisation @@ glossClockUTC GlossSimClockIO

-- *** Joint inference on state and parameter

{- | This part performs the inference (and passes along temperature, sensor and position simulations).
It runs as fast as possible, so this will potentially drain the CPU.
-}
Expand Down Expand Up @@ -435,6 +451,39 @@ mainMultiRate =
launchInGlossThread glossSettings $
flow mainRhineMultiRate

-- *** Separate inference on state and parameter (PPMMH)

{- | This part performs the inference (and passes along temperature, sensor and position simulations).
It runs as fast as possible, so this will potentially drain the CPU.
-}
inferencePPMMH :: Rhine (GlossConcT IO) (LiftClock IO GlossConcT Busy) (Temperature, (Sensor, Pos)) Result
inferencePPMMH = hoistClSF sampleIOGloss inferenceBehaviour @@ liftClock Busy
where
inferenceBehaviour :: (MonadDistribution m, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
inferenceBehaviour = proc (temperature, (measured, latent)) -> do
(particles, temperatures) <- ppmmh 20 20 resampleSystematic resampleSystematic (temperatureProcess <<< arr (const ())) posterior -< measured
-- particles <- runPopulationCl nParticles resampleSystematic posteriorTemperatureProcess -< measured
returnA -< Result {temperature, measured, latent, particles = (, 1/20) <$> particles, particlesTemperature = (, 1/20) <$> temperatures}

{- FOURMOLU_DISABLE -}
-- | Compose all four asynchronous components to a single 'Rhine'.
mainRhineMultiRatePPMMH =
userTemperature
@@ glossClockUTC GlossEventClockIO
>-- keepLast initialTemperature -->
modelRhine
>-- keepLast (initialTemperature, (zeroVector, zeroVector)) -->
inferencePPMMH
>-- keepLast Result {temperature = initialTemperature, measured = zeroVector, latent = zeroVector, particles = [], particlesTemperature = []} -->
visualisationRhine
{- FOURMOLU_ENABLE -}

mainMultiRatePPMMH :: IO ()
mainMultiRatePPMMH =
void $
launchInGlossThread glossSettings $
flow mainRhineMultiRatePPMMH

-- * Utilities

instance (MonadDistribution m) => MonadDistribution (GlossConcT m) where
Expand Down
1 change: 1 addition & 0 deletions rhine-bayes/rhine-bayes.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ library
, dunai ^>= 0.9
, log-domain >= 0.12
, monad-bayes >= 1.1.0
, vector ^>= 0.12
hs-source-dirs: src
default-language: Haskell2010
default-extensions:
Expand Down
63 changes: 63 additions & 0 deletions rhine-bayes/src/Data/MonadicStreamFunction/Bayes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@ module Data.MonadicStreamFunction.Bayes where

-- base
import Control.Arrow
import Control.Monad (forM)
import Data.Functor (($>))
import Data.List (transpose)
import Data.Tuple (swap)

-- vector
import Data.Vector (fromList)

-- transformers

-- log-domain
import Numeric.Log hiding (sum)

-- monad-bayes
import Control.Monad.Bayes.Class (MonadDistribution (logCategorical))
import Control.Monad.Bayes.Population
( fromWeightedList, runPopulation, spawn, Population)

-- dunai
import Data.MonadicStreamFunction
Expand Down Expand Up @@ -48,6 +55,62 @@ runPopulationsS resampler = go
unzip $
(swap . fmap fst &&& swap . fmap snd) . swap <$> bAndMSFs

-- | "Particle parameter marginalized Metropolis-Hastings" - adaptation of PMMH
ppmmh ::
-- | Number of particles parameter
MonadDistribution m => Int ->
-- | Number of particles state
Int ->
-- | Resampler for parameter
(forall x . Population m x -> Population m x) ->
-- | Resampler for state
(forall x . Population m x -> Population m x) ->
MSF m a p ->
MSF (Population m) (a, p) b ->
MSF m a ([b], [p])
ppmmh nPar nState resPar resState par state = ppmmhS resPar resState (replicate nPar par) (replicate nState state)

ppmmhS ::
forall m a p b .
(MonadDistribution m) =>
-- | Resampler for parameter
(forall x . Population m x -> Population m x) ->
-- | Resampler for state
(forall x . Population m x -> Population m x) ->
[MSF m a p] ->
[MSF (Population m) (a, p) b] ->
MSF m a ([b], [p])
ppmmhS resPar resState = go
where
go ::
MonadDistribution m =>
[MSF m a p] ->
[MSF (Population m) (a, p) b] ->
MSF m a ([b], [p])
go parMSFs stateMSFs = MSF $ \a -> do
pars <- forM parMSFs $ flip unMSF a
bAndStateMSFs <- forM pars $ \(p, _) -> runPopulation $ flip unMSF (a, p) =<< fromWeightedList (pure $ (, 1) <$> stateMSFs)
let parWeights = sum . fmap snd <$> bAndStateMSFs
-- FIXME it's not so nice that the next step is in m, but the side effects should all be pure
parMSFs' <- runPopulation $ resPar $ fromWeightedList $ pure $ zip (snd <$> pars) parWeights
bAndStateMSFsT <- transpose <$> mapM (runPopulation . normalize . resState . fromWeightedList . pure) bAndStateMSFs
bAndStateMSFs <- forM bAndStateMSFsT $ \forallParameters -> do
choice <- logCategorical $ fromList $ snd <$> forallParameters
pure $ fst $ forallParameters !! choice

pure ((fst <$> bAndStateMSFs, fst <$> pars), go (fst <$> parMSFs') (snd <$> bAndStateMSFs))

kernelToProcess ::
Monad m =>
-- | Initial parameter
p ->
-- | Proposition kernel
(a -> p -> m p) ->
MSF m a p
kernelToProcess p0 f = feedback p0 $ arrM $ \(a, p) -> dup <$> f a p
where
dup p = (p, p)

-- FIXME see PR re-adding this to monad-bayes
normalize :: (Monad m) => Population m a -> Population m a
normalize = fromWeightedList . fmap (\particles -> second (/ (sum $ snd <$> particles)) <$> particles) . runPopulation
16 changes: 16 additions & 0 deletions rhine-bayes/src/FRP/Rhine/Bayes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ runPopulationCl ::
ClSF m cl a [(b, Log Double)]
runPopulationCl nParticles resampler = DunaiReader.readerS . DunaiBayes.runPopulationS nParticles resampler . DunaiReader.runReaderS

-- | "Particle parameter marginalized Metropolis-Hastings" - adaptation of PMMH
ppmmh ::
MonadDistribution m =>
-- | Number of particles parameter
Int ->
-- | Number of particles state
Int ->
-- | Resampler for parameter
(forall x . Population m x -> Population m x) ->
-- | Resampler for state
(forall x . Population m x -> Population m x) ->
ClSF m cl a p ->
ClSF (Population m) cl (a, p) b ->
ClSF m cl a ([b], [p])
ppmmh nPar nState resPar resState par state = DunaiReader.readerS $ DunaiBayes.ppmmh nPar nState resPar resState (DunaiReader.runReaderS par) (DunaiReader.runReaderS state <<< arr (\((ti, a), p) -> (ti, (a, p))))

-- * Short standard library of stochastic processes

-- | A stochastic process is a behaviour that uses, as only effect, random sampling.
Expand Down