-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow simultaneous sampling and boundaries for nested random walk sampler #49
Changes from all commits
7df41f2
19e3d24
49a2225
0d63f3c
9590a2b
1f99f85
361aad9
0e35ce1
2173718
a49ecd0
aa864f9
bbdbbdf
7f2d7fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: mcstate2 | ||
Title: Next Generation mcstate | ||
Version: 0.1.8 | ||
Version: 0.1.9 | ||
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), | ||
email = "[email protected]"), | ||
person("Imperial College of Science, Technology and Medicine", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,12 +42,14 @@ | |
##' @param vcv A list of variance covariance matrices. We expect this | ||
##' to be a list with elements `base` and `groups` corresponding to | ||
##' the covariance matrix for base parameters (if any) and groups. | ||
##' | ||
##' @inheritParams mcstate_sampler_random_walk | ||
##' | ||
##' @return A `mcstate_sampler` object, which can be used with | ||
##' [mcstate_sample] | ||
##' | ||
##' @export | ||
mcstate_sampler_nested_random_walk <- function(vcv) { | ||
mcstate_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { | ||
if (!is.list(vcv)) { | ||
cli::cli_abort( | ||
"Expected a list for 'vcv'", | ||
|
@@ -59,7 +61,7 @@ mcstate_sampler_nested_random_walk <- function(vcv) { | |
arg = "vcv") | ||
} | ||
if (!is.null(vcv$base)) { | ||
check_vcv(vcv$base, call = environment()) | ||
check_vcv(vcv$base, allow_3d = TRUE, call = environment()) | ||
} | ||
if (!is.list(vcv$groups)) { | ||
cli::cli_abort("Expected 'vcv$groups' to be a list") | ||
|
@@ -68,17 +70,27 @@ mcstate_sampler_nested_random_walk <- function(vcv) { | |
cli::cli_abort("Expected 'vcv$groups' to have at least one element") | ||
} | ||
for (i in seq_along(vcv$groups)) { | ||
check_vcv(vcv$groups[[i]], name = sprintf("vcv$groups[%d]", i), | ||
call = environment()) | ||
check_vcv(vcv$groups[[i]], allow_3d = TRUE, | ||
name = sprintf("vcv$groups[%d]", i), call = environment()) | ||
} | ||
|
||
internal <- new.env(parent = emptyenv()) | ||
|
||
boundaries <- match_value(boundaries, c("reflect", "reject", "ignore")) | ||
|
||
initialise <- function(pars, model, observer, rng) { | ||
if (!model$properties$has_parameter_groups) { | ||
cli::cli_abort("Your model does not have parameter groupings") | ||
} | ||
internal$proposal <- nested_proposal(vcv, model$parameter_groups) | ||
|
||
internal$multiple_parameters <- length(dim2(pars)) > 1 | ||
if (internal$multiple_parameters) { | ||
## this is enforced elsewhere | ||
stopifnot(model$properties$allow_multiple_parameters) | ||
} | ||
|
||
internal$proposal <- nested_proposal(vcv, model$parameter_groups, pars, | ||
model$domain, boundaries) | ||
|
||
initialise_rng_state(model, rng) | ||
density <- model$density(pars, by_group = TRUE) | ||
|
@@ -93,7 +105,7 @@ mcstate_sampler_nested_random_walk <- function(vcv) { | |
"elements corresponding to parameter groups to be", | ||
"included with your density"))) | ||
} | ||
if (length(density_by_group) != n_groups) { | ||
if (dim2(density_by_group)[1] != n_groups) { | ||
cli::cli_abort( | ||
paste("model$density(x, by_group = TRUE) produced a 'by_group'", | ||
"attribute with incorrect length {length(density_by_group)}", | ||
|
@@ -120,10 +132,32 @@ mcstate_sampler_nested_random_walk <- function(vcv) { | |
step <- function(state, model, observer, rng) { | ||
if (!is.null(internal$proposal$base)) { | ||
pars_next <- internal$proposal$base(state$pars, rng) | ||
density_next <- model$density(pars_next, by_group = TRUE) | ||
density_by_group_next <- attr(density_next, "by_group") | ||
|
||
reject_some <- boundaries == "reject" && | ||
!all(i <- is_parameters_in_domain(pars_next, model$domain)) | ||
if (reject_some) { | ||
density_next <- rep(-Inf, length(state$density)) | ||
density_by_group_next <- array(-Inf, dim2(internal$density_by_group)) | ||
if (any(i)) { | ||
density_next_i <- model$density(pars_next[, i, drop = FALSE], | ||
by_group = TRUE) | ||
density_next[i] <- density_next_i | ||
density_by_group_next[, i] <- attr(density_next_i, "by_group") | ||
} | ||
} else { | ||
density_next <- model$density(pars_next, by_group = TRUE) | ||
density_by_group_next <- attr(density_next, "by_group") | ||
} | ||
|
||
accept <- density_next - state$density > log(rng$random_real(1)) | ||
if (accept) { | ||
if (any(accept)) { | ||
if (!all(accept)) { | ||
## Retain some older parameters | ||
i <- which(!accept) | ||
pars_next[, i] <- state$pars[, i] | ||
density_next <- model$density(pars_next, by_group = TRUE) | ||
density_by_group_next <- attr(density_next, "by_group") | ||
} | ||
state$pars <- pars_next | ||
state$density <- density_next | ||
internal$density_by_group <- density_by_group_next | ||
|
@@ -134,16 +168,59 @@ mcstate_sampler_nested_random_walk <- function(vcv) { | |
} | ||
|
||
pars_next <- internal$proposal$groups(state$pars, rng) | ||
density_next <- model$density(pars_next, by_group = TRUE) | ||
density_by_group_next <- attr(density_next, "by_group") | ||
|
||
reject_some <- boundaries == "reject" && | ||
!all(i <- is_parameters_in_domain_groups(pars_next, model$domain, | ||
model$parameter_groups)) | ||
|
||
## This bit is potentially inefficient - for any proposed parameters out of | ||
## bounds I substitute in the current parameters, so that we can run the | ||
## density on all groups. Ideally we would want to only run the density on | ||
## groups with all parameters in bounds. A bit fiddly to do that in a nice | ||
## way when doing simultaneous sampling | ||
if (reject_some) { | ||
density_next <- rep(-Inf, length(state$density)) | ||
density_by_group_next <- array(-Inf, dim2(internal$density_by_group)) | ||
if (any(i)) { | ||
if (internal$multiple_parameters) { | ||
for (j in seq_len(ncol(i))) { | ||
if (!all(i[, j])) { | ||
i_group <- model$parameter_groups %in% which(!i[, j]) | ||
pars_next[i_group, j] <- state$pars[i_group, j] | ||
} | ||
} | ||
} else { | ||
i_group <- model$parameter_groups %in% which(!i) | ||
pars_next[i_group] <- state$pars[i_group] | ||
} | ||
density_next <- model$density(pars_next, by_group = TRUE) | ||
density_by_group_next <- attr(density_next, "by_group") | ||
} | ||
} else { | ||
density_next <- model$density(pars_next, by_group = TRUE) | ||
density_by_group_next <- attr(density_next, "by_group") | ||
} | ||
|
||
accept <- density_by_group_next - internal$density_by_group > | ||
log(rng$random_real(length(density_by_group_next))) | ||
log(rng$random_real(dim2(density_by_group_next)[1])) | ||
|
||
if (any(accept)) { | ||
if (!all(accept)) { | ||
## Retain some older parameters | ||
i <- model$parameter_groups %in% which(!accept) | ||
pars_next[i] <- state$pars[i] | ||
if (internal$multiple_parameters) { | ||
for (j in seq_len(ncol(accept))) { | ||
if (!all(accept[, j])) { | ||
i <- model$parameter_groups %in% which(!accept[, j]) | ||
pars_next[i, j] <- state$pars[i, j] | ||
} | ||
} | ||
} else { | ||
i <- model$parameter_groups %in% which(!accept) | ||
pars_next[i] <- state$pars[i] | ||
} | ||
## If e.g. density is provided by a particle filter, would this bit | ||
## mean rerunning it? Increases time cost if so, and would result in | ||
## new value of density (different to that which was accepted) | ||
Comment on lines
+221
to
+223
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it does really, which leaves us in a bit of a pickle. Options here:
I think the latter is the best option, and something we can discuss when we're both back? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've put this down as mrc-5699. |
||
density_next <- model$density(pars_next, by_group = TRUE) | ||
density_by_group_next <- attr(density_next, "by_group") | ||
} | ||
|
@@ -206,7 +283,8 @@ check_parameter_groups <- function(x, n_pars, name = deparse(substitute(x)), | |
} | ||
|
||
|
||
nested_proposal <- function(vcv, parameter_groups, call = NULL) { | ||
nested_proposal <- function(vcv, parameter_groups, pars, domain, | ||
boundaries = "reflect", call = NULL) { | ||
i_base <- parameter_groups == 0 | ||
n_base <- sum(i_base) | ||
n_groups <- max(parameter_groups) | ||
|
@@ -241,25 +319,71 @@ nested_proposal <- function(vcv, parameter_groups, call = NULL) { | |
|
||
has_base <- n_base > 0 | ||
if (has_base) { | ||
mvn_base <- make_rmvnorm(vcv$base) | ||
proposal_base <- function(x, rng) { | ||
## This approach is likely to be a bit fragile, so we'll | ||
## probably want some naming related verification here soon too. | ||
x[i_base] <- mvn_base(x[i_base], rng) | ||
x | ||
if (is.matrix(pars)) { | ||
vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base, , drop = FALSE]) | ||
mvn_base <- make_random_walk_proposal( | ||
vcv$base, domain[i_base, , drop = FALSE], boundaries) | ||
proposal_base <- function(x, rng) { | ||
## This approach is likely to be a bit fragile, so we'll | ||
## probably want some naming related verification here soon too. | ||
x[i_base] <- mvn_base(x[i_base, ], rng) | ||
x | ||
} | ||
} else { | ||
vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base]) | ||
mvn_base <- make_random_walk_proposal( | ||
vcv$base, domain[i_base, , drop = FALSE], boundaries) | ||
proposal_base <- function(x, rng) { | ||
## This approach is likely to be a bit fragile, so we'll | ||
## probably want some naming related verification here soon too. | ||
x[i_base] <- mvn_base(x[i_base], rng) | ||
x | ||
} | ||
} | ||
|
||
} else { | ||
proposal_base <- NULL | ||
} | ||
|
||
mvn_groups <- lapply(vcv$groups, make_rmvnorm) | ||
for (i in seq_len(n_groups)) { | ||
if (is.matrix(pars)) { | ||
vcv$groups[[i]] <- | ||
sampler_validate_vcv(vcv$groups[[i]], | ||
pars[i_group[[i]], , drop = FALSE]) | ||
} else { | ||
vcv$groups[[i]] <- | ||
sampler_validate_vcv(vcv$groups[[i]], pars[i_group[[i]]]) | ||
} | ||
} | ||
mvn_groups <- lapply(seq_len(n_groups), function (i) | ||
make_random_walk_proposal(vcv$groups[[i]], | ||
domain[i_group[[i]], , drop = FALSE], | ||
boundaries)) | ||
proposal_groups <- function(x, rng) { | ||
for (i in seq_len(n_groups)) { | ||
x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng) | ||
if (is.matrix(x)) { | ||
x[i_group[[i]], ] <- mvn_groups[[i]](x[i_group[[i]], ], rng) | ||
} else { | ||
x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng) | ||
} | ||
} | ||
x | ||
} | ||
|
||
list(base = proposal_base, | ||
groups = proposal_groups) | ||
} | ||
|
||
## Test, group by group, whether parameters lie strictly inside their
## domain.
##
## @param x Parameters: either a vector with one element per
##   parameter, or a matrix with one row per parameter, in which case
##   each column is checked independently.
##
## @param domain A two-column matrix with one row per parameter; the
##   first column is used as the lower bound and the second as the
##   upper bound. The comparison is strict, so a value equal to a
##   bound counts as outside the domain.
##
## @param parameter_groups A vector giving the group number of each
##   parameter. Only groups `1..max(parameter_groups)` are reported;
##   parameters in group 0 are not checked by this function.
##
## @return If `x` is a vector, a logical vector with one element per
##   group. If `x` is a matrix, a logical matrix with one row per
##   group and one column per column of `x`. An element is `TRUE`
##   when every parameter in that group is inside its domain.
is_parameters_in_domain_groups <- function(x, domain, parameter_groups) {
  x_min <- domain[, 1]
  x_max <- domain[, 2]
  ## The bounds have one element per parameter (row), so for a matrix
  ## 'x' they recycle down each column.
  in_domain <- x > x_min & x < x_max
  n_groups <- max(parameter_groups)
  i_group <- lapply(seq_len(n_groups),
                    function(g) which(parameter_groups == g))
  if (is.matrix(x)) {
    n_sets <- ncol(x)
    ret <- vapply(i_group,
                  function(j) apply(in_domain[j, , drop = FALSE], 2, all),
                  logical(n_sets))
    ## vapply() returns an (n_sets x n_groups) matrix, but collapses to
    ## a plain vector when n_sets is 1; t() would then give a
    ## (1 x n_groups) result. Build the (n_groups x n_sets) matrix
    ## explicitly so that the one-column case keeps the right shape.
    matrix(ret, nrow = n_groups, ncol = n_sets, byrow = TRUE)
  } else {
    vapply(i_group, function(j) all(in_domain[j]), logical(1L))
  }
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can now use index_group to run a subset here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually this needs support from dust too (mrc-5701, mrc-5700)