Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow simultaneous sampling and boundaries for nested random walk sampler #49

Merged
merged 13 commits into from
Aug 21, 2024
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate2
Title: Next Generation mcstate
Version: 0.1.8
Version: 0.1.9
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Imperial College of Science, Technology and Medicine",
Expand Down
69 changes: 67 additions & 2 deletions R/sampler-nested-adaptive.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ mcstate_sampler_nested_adaptive <- function(initial_vcv,
list(base = autocorrelation_base, groups = autocorrelation_groups)
internal$vcv <- list(base = vcv_base, groups = vcv_groups)
proposal_vcv <- list(base = proposal_vcv_base, groups = proposal_vcv_groups)
internal$proposal <- nested_proposal(proposal_vcv, model$parameter_groups)
internal$proposal <- nested_proposal_adaptive(proposal_vcv,
model$parameter_groups)

internal$history_pars <- numeric()
internal$included <- integer()
Expand Down Expand Up @@ -311,7 +312,8 @@ mcstate_sampler_nested_adaptive <- function(initial_vcv,

## Update proposal
proposal_vcv <- list(base = proposal_vcv_base, groups = proposal_vcv_groups)
internal$proposal <- nested_proposal(proposal_vcv, model$parameter_groups)
internal$proposal <- nested_proposal_adaptive(proposal_vcv,
model$parameter_groups)

state
}
Expand Down Expand Up @@ -425,3 +427,66 @@ check_nested_adaptive <- function(x, n_groups, has_base, null_allowed = FALSE,

ret
}


## TODO: this is a simpler version of nested_proposal that does not
## cope with boundaries etc - that's being looked at in #46 for now.
## nocov start
nested_proposal_adaptive <- function(vcv, parameter_groups, call = NULL) {
i_base <- parameter_groups == 0
n_base <- sum(i_base)
n_groups <- max(parameter_groups)
i_group <- lapply(seq_len(n_groups), function(i) which(parameter_groups == i))
if (NROW(vcv$base) != n_base) {
cli::cli_abort(
c("Incompatible number of base parameters in your model and sampler",
i = paste("Your model has {n_base} base parameters, but 'vcv$base'",
"implies {NROW(vcv$base)} parameters")),
call = call)
}
if (length(vcv$groups) != n_groups) {
cli::cli_abort(
c("Incompatible number of parameter groups in your model and sampler",
i = paste("Your model has {n_groups} parameter groups, but",
"'vcv$groups' has {length(vcv$groups)} groups")),
call = call)
}
n_pars_by_group <- lengths(i_group)
n_pars_by_group_vcv <- vnapply(vcv$groups, nrow)
err <- n_pars_by_group_vcv != n_pars_by_group
if (any(err)) {
detail <- sprintf(
"Group %d has %d parameters but 'vcv$groups[[%d]]' has %d",
which(err), n_pars_by_group[err],
which(err), n_pars_by_group_vcv[err])
cli::cli_abort(
c("Incompatible number of parameters within parameter group",
set_names(detail, "i")),
call = call)
}

has_base <- n_base > 0
if (has_base) {
mvn_base <- make_rmvnorm(vcv$base)
proposal_base <- function(x, rng) {
## This approach is likely to be a bit fragile, so we'll
## probably want some naming related verification here soon too.
x[i_base] <- mvn_base(x[i_base], rng)
x
}
} else {
proposal_base <- NULL
}

mvn_groups <- lapply(vcv$groups, make_rmvnorm)
proposal_groups <- function(x, rng) {
for (i in seq_len(n_groups)) {
x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng)
}
x
}

list(base = proposal_base,
groups = proposal_groups)
}
## nocov end
170 changes: 147 additions & 23 deletions R/sampler-nested-random-walk.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@
##' @param vcv A list of variance covariance matrices. We expect this
##' to be a list with elements `base` and `groups` corresponding to
##' the covariance matrix for base parameters (if any) and groups.
##'
##' @inheritParams mcstate_sampler_random_walk
##'
##' @return A `mcstate_sampler` object, which can be used with
##' [mcstate_sample]
##'
##' @export
mcstate_sampler_nested_random_walk <- function(vcv) {
mcstate_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") {
if (!is.list(vcv)) {
cli::cli_abort(
"Expected a list for 'vcv'",
Expand All @@ -59,7 +61,7 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
arg = "vcv")
}
if (!is.null(vcv$base)) {
check_vcv(vcv$base, call = environment())
check_vcv(vcv$base, allow_3d = TRUE, call = environment())
}
if (!is.list(vcv$groups)) {
cli::cli_abort("Expected 'vcv$groups' to be a list")
Expand All @@ -68,17 +70,27 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
cli::cli_abort("Expected 'vcv$groups' to have at least one element")
}
for (i in seq_along(vcv$groups)) {
check_vcv(vcv$groups[[i]], name = sprintf("vcv$groups[%d]", i),
call = environment())
check_vcv(vcv$groups[[i]], allow_3d = TRUE,
name = sprintf("vcv$groups[%d]", i), call = environment())
}

internal <- new.env(parent = emptyenv())

boundaries <- match_value(boundaries, c("reflect", "reject", "ignore"))

initialise <- function(pars, model, observer, rng) {
if (!model$properties$has_parameter_groups) {
cli::cli_abort("Your model does not have parameter groupings")
}
internal$proposal <- nested_proposal(vcv, model$parameter_groups)

internal$multiple_parameters <- length(dim2(pars)) > 1
if (internal$multiple_parameters) {
## this is enforced elsewhere
stopifnot(model$properties$allow_multiple_parameters)
}

internal$proposal <- nested_proposal(vcv, model$parameter_groups, pars,
model$domain, boundaries)

initialise_rng_state(model, rng)
density <- model$density(pars, by_group = TRUE)
Expand All @@ -93,7 +105,7 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
"elements corresponding to parameter groups to be",
"included with your density")))
}
if (length(density_by_group) != n_groups) {
if (dim2(density_by_group)[1] != n_groups) {
cli::cli_abort(
paste("model$density(x, by_group = TRUE) produced a 'by_group'",
"attribute with incorrect length {length(density_by_group)}",
Expand All @@ -120,10 +132,32 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
step <- function(state, model, observer, rng) {
if (!is.null(internal$proposal$base)) {
pars_next <- internal$proposal$base(state$pars, rng)
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")

reject_some <- boundaries == "reject" &&
!all(i <- is_parameters_in_domain(pars_next, model$domain))
if (reject_some) {
density_next <- rep(-Inf, length(state$density))
density_by_group_next <- array(-Inf, dim2(internal$density_by_group))
if (any(i)) {
density_next_i <- model$density(pars_next[, i, drop = FALSE],
by_group = TRUE)
density_next[i] <- density_next_i
density_by_group_next[, i] <- attr(density_next_i, "by_group")
}
} else {
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")
}

accept <- density_next - state$density > log(rng$random_real(1))
if (accept) {
if (any(accept)) {
if (!all(accept)) {
## Retain some older parameters
i <- which(!accept)
pars_next[, i] <- state$pars[, i]
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")
}
state$pars <- pars_next
state$density <- density_next
internal$density_by_group <- density_by_group_next
Expand All @@ -134,16 +168,59 @@ mcstate_sampler_nested_random_walk <- function(vcv) {
}

pars_next <- internal$proposal$groups(state$pars, rng)
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")

reject_some <- boundaries == "reject" &&
!all(i <- is_parameters_in_domain_groups(pars_next, model$domain,
model$parameter_groups))

## This bit is potentially inefficient - for any proposed parameters out of
## bounds I substitute in the current parameters, so that we can run the
## density on all groups. Ideally we would want to only run the density on
## groups with all parameters in bounds. A bit fiddly to do that in a nice
## way when doing simultaneous sampling
if (reject_some) {
density_next <- rep(-Inf, length(state$density))
density_by_group_next <- array(-Inf, dim2(internal$density_by_group))
if (any(i)) {
if (internal$multiple_parameters) {
for (j in seq_len(ncol(i))) {
if (!all(i[, j])) {
i_group <- model$parameter_groups %in% which(!i[, j])
pars_next[i_group, j] <- state$pars[i_group, j]
}
}
} else {
i_group <- model$parameter_groups %in% which(!i)
pars_next[i_group] <- state$pars[i_group]
}
density_next <- model$density(pars_next, by_group = TRUE)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can now use index_group to run a subset here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually this needs support from dust too (mrc-5701, mrc-5700)

density_by_group_next <- attr(density_next, "by_group")
}
} else {
density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")
}

accept <- density_by_group_next - internal$density_by_group >
log(rng$random_real(length(density_by_group_next)))
log(rng$random_real(dim2(density_by_group_next)[1]))

if (any(accept)) {
if (!all(accept)) {
## Retain some older parameters
i <- model$parameter_groups %in% which(!accept)
pars_next[i] <- state$pars[i]
if (internal$multiple_parameters) {
for (j in seq_len(ncol(accept))) {
if (!all(accept[, j])) {
i <- model$parameter_groups %in% which(!accept[, j])
pars_next[i, j] <- state$pars[i, j]
}
}
} else {
i <- model$parameter_groups %in% which(!accept)
pars_next[i] <- state$pars[i]
}
## If e.g. density is provided by a particle filter, would this bit
## mean rerunning it? Increases time cost if so, and would result in
## new value of density (different to that which was accepted)
Comment on lines +221 to +223
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does really, which leaves us in a bit of a pickle. Options here:

  1. forbid this when the model is stochastic
  2. allow a model to advertise itself as additive (though some subtleties with the prior here)
  3. add a new method that nested models must provide that allows them to replace new groups (in effect moving this part into the model)

I think the latter is the best option, and something we can discuss when we're both back?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've put this down as mrc-5699

density_next <- model$density(pars_next, by_group = TRUE)
density_by_group_next <- attr(density_next, "by_group")
}
Expand Down Expand Up @@ -206,7 +283,8 @@ check_parameter_groups <- function(x, n_pars, name = deparse(substitute(x)),
}


nested_proposal <- function(vcv, parameter_groups, call = NULL) {
nested_proposal <- function(vcv, parameter_groups, pars, domain,
boundaries = "reflect", call = NULL) {
i_base <- parameter_groups == 0
n_base <- sum(i_base)
n_groups <- max(parameter_groups)
Expand Down Expand Up @@ -241,25 +319,71 @@ nested_proposal <- function(vcv, parameter_groups, call = NULL) {

has_base <- n_base > 0
if (has_base) {
mvn_base <- make_rmvnorm(vcv$base)
proposal_base <- function(x, rng) {
## This approach is likely to be a bit fragile, so we'll
## probably want some naming related verification here soon too.
x[i_base] <- mvn_base(x[i_base], rng)
x
if (is.matrix(pars)) {
vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base, , drop = FALSE])
mvn_base <- make_random_walk_proposal(
vcv$base, domain[i_base, , drop = FALSE], boundaries)
proposal_base <- function(x, rng) {
## This approach is likely to be a bit fragile, so we'll
## probably want some naming related verification here soon too.
x[i_base] <- mvn_base(x[i_base, ], rng)
x
}
} else {
vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base])
mvn_base <- make_random_walk_proposal(
vcv$base, domain[i_base, , drop = FALSE], boundaries)
proposal_base <- function(x, rng) {
## This approach is likely to be a bit fragile, so we'll
## probably want some naming related verification here soon too.
x[i_base] <- mvn_base(x[i_base], rng)
x
}
}

} else {
proposal_base <- NULL
}

mvn_groups <- lapply(vcv$groups, make_rmvnorm)
for (i in seq_len(n_groups)) {
if (is.matrix(pars)) {
vcv$groups[[i]] <-
sampler_validate_vcv(vcv$groups[[i]],
pars[i_group[[i]], , drop = FALSE])
} else {
vcv$groups[[i]] <-
sampler_validate_vcv(vcv$groups[[i]], pars[i_group[[i]]])
}
}
mvn_groups <- lapply(seq_len(n_groups), function (i)
make_random_walk_proposal(vcv$groups[[i]],
domain[i_group[[i]], , drop = FALSE],
boundaries))
proposal_groups <- function(x, rng) {
for (i in seq_len(n_groups)) {
x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng)
if (is.matrix(x)) {
x[i_group[[i]], ] <- mvn_groups[[i]](x[i_group[[i]], ], rng)
} else {
x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng)
}
}
x
}

list(base = proposal_base,
groups = proposal_groups)
}

is_parameters_in_domain_groups <- function(x, domain, parameter_groups) {
x_min <- domain[, 1]
x_max <- domain[, 2]
i <- x > x_min & x < x_max
n_groups <- max(parameter_groups)
i_group <- lapply(seq_len(n_groups), function(i) which(parameter_groups == i))
if (is.matrix(x)) {
t(vapply(i_group, function(j) apply(i[j, , drop = FALSE], 2, all),
logical(ncol(x))))
} else {
vapply(i_group, function(j) all(i[j]), logical(1L))
}
}
17 changes: 16 additions & 1 deletion man/mcstate_sampler_nested_random_walk.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading