Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow simultaneous sampling and boundaries for nested random walk sampler #49

Merged
merged 13 commits into from
Aug 21, 2024
184 changes: 161 additions & 23 deletions R/sampler-nested-random-walk.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,28 @@
##' @param vcv A list of variance covariance matrices. We expect this
##' to be a list with elements `base` and `groups` corresponding to
##' the covariance matrix for base parameters (if any) and groups.
##'
##' @param boundaries Control the behaviour of proposals that are
##' outside the model domain. The supported options are:
##'
##' * "reflect" (the default): we reflect proposed parameters that
##' lie outside the domain back into the domain (as many times as
##' needed)
##'
##' * "reject": we do not evaluate the density function, and return
##' `-Inf` for its density instead.
##'
##' * "ignore": evaluate the point anyway, even if it lies outside
##' the domain.
##'
##' The initial point selected will lie within the domain, as this is
##' enforced by [mcstate_sample].
richfitz marked this conversation as resolved.
Show resolved Hide resolved
##'
##' @return A `mcstate_sampler` object, which can be used with
##' [mcstate_sample]
##'
##' @export
mcstate_sampler_nested_random_walk <- function(vcv) {
mcstate_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") {
if (!is.list(vcv)) {
cli::cli_abort(
"Expected a list for 'vcv'",
Expand All @@ -59,7 +75,7 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
arg = "vcv")
}
if (!is.null(vcv$base)) {
check_vcv(vcv$base, call = environment())
check_vcv(vcv$base, allow_3d = TRUE, call = environment())
}
if (!is.list(vcv$groups)) {
cli::cli_abort("Expected 'vcv$groups' to be a list")
Expand All @@ -68,17 +84,27 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
cli::cli_abort("Expected 'vcv$groups' to have at least one element")
}
for (i in seq_along(vcv$groups)) {
check_vcv(vcv$groups[[i]], name = sprintf("vcv$groups[%d]", i),
call = environment())
check_vcv(vcv$groups[[i]], allow_3d = TRUE,
name = sprintf("vcv$groups[%d]", i), call = environment())
}

internal <- new.env(parent = emptyenv())

boundaries <- match_value(boundaries, c("reflect", "reject", "ignore"))

initialise <- function(pars, model, observer, rng) {
if (!model$properties$has_parameter_groups) {
cli::cli_abort("Your model does not have parameter groupings")
}
internal$proposal <- nested_proposal(vcv, model$parameter_groups)

internal$multiple_parameters <- length(dim2(pars)) > 1
if (internal$multiple_parameters) {
## this is enforced elsewhere
stopifnot(model$properties$allow_multiple_parameters)
}

internal$proposal <- nested_proposal(vcv, model$parameter_groups, pars,
model$domain, boundaries)

initialise_rng_state(model, rng)
density <- model$density(pars, by_group = TRUE)
Expand All @@ -93,7 +119,7 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
"elements corresponding to parameter groups to be",
"included with your density")))
}
if (length(density_by_group) != n_groups) {
if (dim2(density_by_group)[1] != n_groups) {
cli::cli_abort(
paste("model$density(x, by_group = TRUE) produced a 'by_group'",
"attribute with incorrect length {length(density_by_group)}",
Expand All @@ -120,10 +146,32 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
step <- function(state, model, observer, rng) {
if (!is.null(internal$proposal$base)) {
pars_next <- internal$proposal$base(state$pars, rng)
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")

reject_some <- boundaries == "reject" &&
!all(i <- is_parameters_in_domain(pars_next, model$domain))
if (reject_some) {
density_next <- rep(-Inf, length(state$density))
density_by_group_next <- array(-Inf, dim2(internal$density_by_group))
if (any(i)) {
density_next_i <- model$density(pars_next[, i, drop = FALSE],
by_group = TRUE)
density_next[i] <- density_next_i
density_by_group_next[, i] <- attr(density_next_i, "by_group")
}
} else {
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")
}

accept <- density_next - state$density > log(rng$random_real(1))
if (accept) {
if (any(accept)) {
if (!all(accept)) {
## Retain some older parameters
i <- which(!accept)
pars_next[, i] <- state$pars[, i]
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")
}
state$pars <- pars_next
state$density <- density_next
internal$density_by_group <- density_by_group_next
Expand All @@ -134,16 +182,59 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
}

pars_next <- internal$proposal$groups(state$pars, rng)
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")

reject_some <- boundaries == "reject" &&
!all(i <- is_parameters_in_domain_groups(pars_next, model$domain,
model$parameter_groups))

## This bit is potentially inefficient - for any proposed parameters out of
## bounds I substitute in the current parameters, so that we can run the
## density on all groups. Ideally we would want to only run the density on
## groups with all parameters in bounds. A bit fiddly to do that in a nice
## way when doing simultaneous sampling
if (reject_some) {
density_next <- rep(-Inf, length(state$density))
density_by_group_next <- array(-Inf, dim2(internal$density_by_group))
if (any(i)) {
if (internal$multiple_parameters) {
for (j in seq_len(ncol(i))) {
if (!all(i[, j])) {
i_group <- model$parameter_groups %in% which(!i[, j])
pars_next[i_group, j] <- state$pars[i_group, j]
}
}
} else {
i_group <- model$parameter_groups %in% which(!i)
pars_next[i_group] <- state$pars[i_group]
}
density_next <- model$density(pars_next, by_group = TRUE)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can now use index_group to run a subset here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually this needs support from dust too (mrc-5701, mrc-5700)

density_by_group_next <- attr(density_next, "by_group")
}
} else {
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")
}

accept <- density_by_group_next - internal$density_by_group >
log(rng$random_real(length(density_by_group_next)))
log(rng$random_real(dim2(density_by_group_next)[1]))

if (any(accept)) {
if (!all(accept)) {
## Retain some older parameters
i <- model$parameter_groups %in% which(!accept)
pars_next[i] <- state$pars[i]
if (internal$multiple_parameters) {
for (j in seq_len(ncol(accept))) {
if (!all(accept[, j])) {
i <- model$parameter_groups %in% which(!accept[, j])
pars_next[i, j] <- state$pars[i, j]
}
}
} else {
i <- model$parameter_groups %in% which(!accept)
pars_next[i] <- state$pars[i]
}
## If e.g. density is provided by a particle filter, would this bit
## mean rerunning it? Increases time cost if so, and would result in
## new value of density (different to that which was accepted)
Comment on lines +221 to +223
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does really, which leaves us in a bit of a pickle. Options here:

  1. forbid this when the model is stochastic
  2. allow a model to advertise itself as additive (though some subtleties with the prior here)
  3. add a new method that nested models must provide that allows them to replace new groups (in effect moving this part into the model)

I think the last of these is the best option, and something we can discuss when we're both back?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've put this down as mrc-5699

density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")
}
Expand Down Expand Up @@ -206,7 +297,8 @@ check_parameter_groups <- function(x, n_pars, name = deparse(substitute(x)),
}


nested_proposal <- function(vcv, parameter_groups, call = NULL) {
nested_proposal <- function(vcv, parameter_groups, pars, domain,
boundaries = "reflect", call = NULL) {
i_base <- parameter_groups == 0
n_base <- sum(i_base)
n_groups <- max(parameter_groups)
Expand Down Expand Up @@ -241,25 +333,71 @@ nested_proposal <- function(vcv, parameter_groups, call = NULL) {

has_base <- n_base > 0
if (has_base) {
mvn_base <- make_rmvnorm(vcv$base)
proposal_base <- function(x, rng) {
## This approach is likely to be a bit fragile, so we'll
## probably want some naming related verification here soon too.
x[i_base] <- mvn_base(x[i_base], rng)
x
if (is.matrix(pars)) {
vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base, , drop = FALSE])
mvn_base <- make_random_walk_proposal(
vcv$base, domain[i_base, , drop = FALSE], boundaries)
proposal_base <- function(x, rng) {
## This approach is likely to be a bit fragile, so we'll
## probably want some naming related verification here soon too.
x[i_base] <- mvn_base(x[i_base, ], rng)
x
}
} else {
vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base])
mvn_base <- make_random_walk_proposal(
vcv$base, domain[i_base, , drop = FALSE], boundaries)
proposal_base <- function(x, rng) {
## This approach is likely to be a bit fragile, so we'll
## probably want some naming related verification here soon too.
x[i_base] <- mvn_base(x[i_base], rng)
x
}
}

} else {
proposal_base <- NULL
}

mvn_groups <- lapply(vcv$groups, make_rmvnorm)
for (i in seq_len(n_groups)) {
if (is.matrix(pars)) {
vcv$groups[[i]] <-
sampler_validate_vcv(vcv$groups[[i]],
pars[i_group[[i]], , drop = FALSE])
} else {
vcv$groups[[i]] <-
sampler_validate_vcv(vcv$groups[[i]], pars[i_group[[i]]])
}
}
mvn_groups <- lapply(seq_len(n_groups), function (i)
make_random_walk_proposal(vcv$groups[[i]],
domain[i_group[[i]], , drop = FALSE],
boundaries))
proposal_groups <- function(x, rng) {
for (i in seq_len(n_groups)) {
x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng)
if (is.matrix(x)) {
x[i_group[[i]], ] <- mvn_groups[[i]](x[i_group[[i]], ], rng)
} else {
x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng)
}
}
x
}

list(base = proposal_base,
groups = proposal_groups)
}

## Test, group by group, whether parameters lie inside the domain.
##
## @param x Parameters: either a vector with one element per parameter, or
##   a matrix with one row per parameter and one column per parameter set.
## @param domain Two-column matrix with one row per parameter; column 1
##   holds the lower bounds and column 2 the upper bounds.
## @param parameter_groups Integer vector giving the group of each
##   parameter. Only groups `1..max(parameter_groups)` are reported, so
##   parameters in group 0 are not checked here.
##
## @return For vector `x`, a logical vector with one element per group. For
##   matrix `x`, a logical matrix with one row per group and one column per
##   column of `x` (including when `x` has a single column). An entry is
##   `TRUE` only if every parameter in that group lies strictly inside its
##   bounds.
is_parameters_in_domain_groups <- function(x, domain, parameter_groups) {
  x_min <- domain[, 1]
  x_max <- domain[, 2]
  ## The bounds recycle down the rows of 'x' (assuming 'domain' has one row
  ## per parameter), so the same comparison serves vectors and matrices.
  ## Points lying exactly on a bound count as outside.
  in_domain <- x > x_min & x < x_max
  n_groups <- max(parameter_groups)
  i_group <- lapply(seq_len(n_groups),
                    function(g) which(parameter_groups == g))
  if (is.matrix(x)) {
    ret <- vapply(i_group,
                  function(j) apply(in_domain[j, , drop = FALSE], 2, all),
                  logical(ncol(x)))
    ## vapply() returns a plain vector rather than a matrix when 'x' has a
    ## single column; restore the (columns of x) by (groups) shape so that
    ## the transpose always has one row per group.
    if (!is.matrix(ret)) {
      ret <- matrix(ret, nrow = ncol(x))
    }
    t(ret)
  } else {
    vapply(i_group, function(j) all(in_domain[j]), logical(1L))
  }
}
17 changes: 16 additions & 1 deletion man/mcstate_sampler_nested_random_walk.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 18 additions & 7 deletions tests/testthat/helper-mcstate2.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,28 @@ ex_simple_nested_with_base <- function(n_groups) {
c(sigma, rng$normal(n_groups, 0, sigma))
},
density = function(x, by_group = FALSE) {
sigma <- x[[1]]
if (sigma <= 0) {
z <- rep(-Inf, length(x) - 1)
density1 <- function(y) {
sigma <- y[[1]]
if (sigma <= 0) {
rep(-Inf, length(y) - 1)
} else {
dnorm(y[-1], 0, y[[1]], log = TRUE)
}
}
if (is.matrix(x)) {
z <- vapply(seq_len(ncol(x)), function(i) density1(x[, i]),
numeric(nrow(x) - 1))
value <- colSums(z) + dunif(x[1, ], 0, 10, log = TRUE)
if (by_group) structure(value, "by_group" = z) else value
} else {
z <- dnorm(x[-1], 0, x[[1]], log = TRUE)
z <- density1(x)
value <- sum(z) + dunif(x[[1]], 0, 10, log = TRUE)
if (by_group) structure(value, "by_group" = z) else value
}
value <- sum(z) + dunif(x[[1]], 0, 10, log = TRUE)
if (by_group) structure(value, "by_group" = z) else value
},
parameter_groups = c(0, seq_len(n_groups)),
mu = mu)))
mu = mu),
mcstate_model_properties(allow_multiple_parameters = TRUE)))
}


Expand Down
Loading
Loading