From 7df41f25e45c65eda19ca9c551f695339ef6b513 Mon Sep 17 00:00:00 2001 From: edknock Date: Wed, 24 Jul 2024 13:59:17 +0100 Subject: [PATCH 01/10] base update working for simultaneous sampling --- R/sampler-nested-random-walk.R | 39 +++++++++++++++---- tests/testthat/helper-mcstate2.R | 25 ++++++++---- .../test-sampler-nested-random-walk.R | 17 ++++++++ 3 files changed, 67 insertions(+), 14 deletions(-) diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index 252c0121..db5fd383 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -59,7 +59,7 @@ mcstate_sampler_nested_random_walk <- function(vcv) { arg = "vcv") } if (!is.null(vcv$base)) { - check_vcv(vcv$base, call = environment()) + check_vcv(vcv$base, allow_3d = TRUE, call = environment()) } if (!is.list(vcv$groups)) { cli::cli_abort("Expected 'vcv$groups' to be a list") @@ -68,8 +68,8 @@ mcstate_sampler_nested_random_walk <- function(vcv) { cli::cli_abort("Expected 'vcv$groups' to have at least one element") } for (i in seq_along(vcv$groups)) { - check_vcv(vcv$groups[[i]], name = sprintf("vcv$groups[%d]", i), - call = environment()) + check_vcv(vcv$groups[[i]], allow_3d = TRUE, + name = sprintf("vcv$groups[%d]", i), call = environment()) } internal <- new.env(parent = emptyenv()) @@ -78,7 +78,8 @@ mcstate_sampler_nested_random_walk <- function(vcv) { if (!model$properties$has_parameter_groups) { cli::cli_abort("Your model does not have parameter groupings") } - internal$proposal <- nested_proposal(vcv, model$parameter_groups) + + internal$proposal <- nested_proposal(vcv, model$parameter_groups, pars) initialise_rng_state(model, rng) density <- model$density(pars, by_group = TRUE) @@ -93,7 +94,7 @@ mcstate_sampler_nested_random_walk <- function(vcv) { "elements corresponding to parameter groups to be", "included with your density"))) } - if (length(density_by_group) != n_groups) { + if (dim2(density_by_group)[1] != n_groups) { cli::cli_abort( paste("model$density(x, by_group = TRUE) produced a 'by_group'", "attribute with incorrect length {length(density_by_group)}", @@ -118,12 +119,20 @@ mcstate_sampler_nested_random_walk <- function(vcv) { ## either changing the behaviour of the step function or swapping in ## a different version. step <- function(state, model, observer, rng) { + browser() if (!is.null(internal$proposal$base)) { pars_next <- internal$proposal$base(state$pars, rng) density_next <- model$density(pars_next, by_group = TRUE) density_by_group_next <- attr(density_next, "by_group") accept <- density_next - state$density > log(rng$random_real(1)) - if (accept) { + if (any(accept)) { + if (!all(accept)) { + ## Retain some older parameters + i <- which(!accept) + pars_next[, i] <- state$pars[, i] + density_next <- model$density(pars_next, by_group = TRUE) + density_by_group_next <- attr(density_next, "by_group") + } state$pars <- pars_next state$density <- density_next internal$density_by_group <- density_by_group_next @@ -205,7 +214,7 @@ check_parameter_groups <- function(x, n_pars, name = deparse(substitute(x)), } -nested_proposal <- function(vcv, parameter_groups, call = NULL) { +nested_proposal <- function(vcv, parameter_groups, pars, call = NULL) { i_base <- parameter_groups == 0 n_base <- sum(i_base) n_groups <- max(parameter_groups) @@ -240,6 +249,11 @@ nested_proposal <- function(vcv, parameter_groups, call = NULL) { has_base <- n_base > 0 if (has_base) { + if (is.matrix(pars)) { + vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base, , drop = FALSE]) + } else { + vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base]) + } mvn_base <- make_rmvnorm(vcv$base) proposal_base <- function(x, rng) { ## This approach is likely to be a bit fragile, so we'll @@ -251,6 +265,17 @@ nested_proposal <- function(vcv, parameter_groups, call = NULL) { proposal_base <- NULL } + for (i in seq_len(n_groups)) { + if (is.matrix(pars)) { + vcv$groups[[i]] <- + sampler_validate_vcv(vcv$groups[[i]], + pars[i_group[[i]], , drop = FALSE]) + } else { + vcv$groups[[i]] <- + sampler_validate_vcv(vcv$groups[[i]], pars[i_group[[i]]]) + } + } + browser() mvn_groups <- lapply(vcv$groups, make_rmvnorm) proposal_groups <- function(x, rng) { for (i in seq_len(n_groups)) { diff --git a/tests/testthat/helper-mcstate2.R b/tests/testthat/helper-mcstate2.R index c9557b7c..3ccb6294 100644 --- a/tests/testthat/helper-mcstate2.R +++ b/tests/testthat/helper-mcstate2.R @@ -56,17 +56,28 @@ ex_simple_nested_with_base <- function(n_groups) { c(sigma, rng$normal(n_groups, 0, sigma)) }, density = function(x, by_group = FALSE) { - sigma <- x[[1]] - if (sigma <= 0) { - z <- rep(-Inf, length(x) - 1) + density1 <- function(y) { + sigma <- y[[1]] + if (sigma <= 0) { + rep(-Inf, length(y) - 1) + } else { + dnorm(y[-1], 0, y[[1]], log = TRUE) + } + } + if (is.matrix(x)) { + z <- vapply(seq_len(ncol(x)), function(i) density1(x[, i]), + numeric(nrow(x) - 1)) + value <- colSums(z) + dunif(x[1, ], 0, 10, log = TRUE) + if (by_group) structure(value, "by_group" = z) else value } else { - z <- dnorm(x[-1], 0, x[[1]], log = TRUE) + z <- density1(x) + value <- sum(z) + dunif(x[[1]], 0, 10, log = TRUE) + if (by_group) structure(value, "by_group" = z) else value } - value <- sum(z) + dunif(x[[1]], 0, 10, log = TRUE) - if (by_group) structure(value, "by_group" = z) else value }, parameter_groups = c(0, seq_len(n_groups)), - mu = mu))) + mu = mu), + mcstate_model_properties(allow_multiple_parameters = TRUE))) } diff --git a/tests/testthat/test-sampler-nested-random-walk.R b/tests/testthat/test-sampler-nested-random-walk.R index ee238683..17fc3457 100644 --- a/tests/testthat/test-sampler-nested-random-walk.R +++ b/tests/testthat/test-sampler-nested-random-walk.R @@ -197,3 +197,20 @@ test_that("can run an observer during a nested fit", { c(1, 100, 1)) expect_gt(max(res$observations$n), 120) # called way more than once per step }) + + +test_that("can run nested random walk sampler simultaneously", { + set.seed(1) + ng <- 5 + m <- ex_simple_nested_with_base(ng) + sampler <- mcstate_sampler_nested_random_walk( + list(base = diag(1), groups = rep(list(diag(1)), ng))) + + set.seed(1) + res1 <- mcstate_sample(m, sampler, 100, n_chains = 3) + + set.seed(1) + runner <- mcstate_runner_simultaneous() + res2 <- mcstate_sample(m, sampler, 100, n_chains = 3, runner = runner) + expect_equal(res1, res2) +}) \ No newline at end of file From 19e3d243ee2e52e6827613bd179a187038b8104b Mon Sep 17 00:00:00 2001 From: edknock Date: Wed, 24 Jul 2024 15:43:00 +0100 Subject: [PATCH 02/10] working simultaneous nested --- R/sampler-nested-random-walk.R | 52 ++++++++++++++----- .../test-sampler-nested-random-walk.R | 12 +++-- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index db5fd383..f79b2ef5 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -79,6 +79,12 @@ mcstate_sampler_nested_random_walk <- function(vcv) { cli::cli_abort("Your model does not have parameter groupings") } + internal$multiple_parameters <- length(dim2(pars)) > 1 + if (internal$multiple_parameters) { + ## this is enforced elsewhere + stopifnot(model$properties$allow_multiple_parameters) + } + internal$proposal <- nested_proposal(vcv, model$parameter_groups, pars) initialise_rng_state(model, rng) @@ -119,7 +125,6 @@ mcstate_sampler_nested_random_walk <- function(vcv) { ## either changing the behaviour of the step function or swapping in ## a different version. step <- function(state, model, observer, rng) { - browser() if (!is.null(internal$proposal$base)) { pars_next <- internal$proposal$base(state$pars, rng) density_next <- model$density(pars_next, by_group = TRUE) @@ -146,13 +151,23 @@ mcstate_sampler_nested_random_walk <- function(vcv) { density_next <- model$density(pars_next, by_group = TRUE) density_by_group_next <- attr(density_next, "by_group") accept <- density_by_group_next - internal$density_by_group > - log(rng$random_real(length(density_by_group_next))) + log(rng$random_real(dim2(density_by_group_next)[1])) if (any(accept)) { if (!all(accept)) { ## Retain some older parameters - i <- model$parameter_groups %in% which(!accept) - pars_next[i] <- state$pars[i] + if (internal$multiple_parameters) { + for (j in seq_len(ncol(accept))) { + if (!all(accept[, j])) { + i <- model$parameter_groups %in% which(!accept[, j]) + pars_next[i, j] <- state$pars[i, j] + } + } + } else { + i <- model$parameter_groups %in% which(!accept) + pars_next[i] <- state$pars[i] + } + density_next <- model$density(pars_next, by_group = TRUE) density_by_group_next <- attr(density_next, "by_group") } @@ -251,16 +266,24 @@ nested_proposal <- function(vcv, parameter_groups, pars, call = NULL) { if (has_base) { if (is.matrix(pars)) { vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base, , drop = FALSE]) + mvn_base <- make_rmvnorm(vcv$base) + proposal_base <- function(x, rng) { + ## This approach is likely to be a bit fragile, so we'll + ## probably want some naming related verification here soon too. + x[i_base] <- mvn_base(x[i_base, ], rng) + x + } } else { vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base]) + mvn_base <- make_rmvnorm(vcv$base) + proposal_base <- function(x, rng) { + ## This approach is likely to be a bit fragile, so we'll + ## probably want some naming related verification here soon too. + x[i_base] <- mvn_base(x[i_base], rng) + x + } } - mvn_base <- make_rmvnorm(vcv$base) - proposal_base <- function(x, rng) { - ## This approach is likely to be a bit fragile, so we'll - ## probably want some naming related verification here soon too. - x[i_base] <- mvn_base(x[i_base], rng) - x - } + } else { proposal_base <- NULL } @@ -275,11 +298,14 @@ nested_proposal <- function(vcv, parameter_groups, pars, call = NULL) { sampler_validate_vcv(vcv$groups[[i]], pars[i_group[[i]]]) } } - browser() mvn_groups <- lapply(vcv$groups, make_rmvnorm) proposal_groups <- function(x, rng) { for (i in seq_len(n_groups)) { - x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng) + if (is.matrix(x)) { + x[i_group[[i]], ] <- mvn_groups[[i]](x[i_group[[i]], ], rng) + } else { + x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng) + } } x } diff --git a/tests/testthat/test-sampler-nested-random-walk.R b/tests/testthat/test-sampler-nested-random-walk.R index 17fc3457..39bcece6 100644 --- a/tests/testthat/test-sampler-nested-random-walk.R +++ b/tests/testthat/test-sampler-nested-random-walk.R @@ -25,7 +25,7 @@ test_that("validate vcv inputs on construction of sampler", { "Expected 'vcv' to have elements 'base' and 'groups'") expect_error( mcstate_sampler_nested_random_walk(list(base = TRUE, groups = TRUE)), - "Expected a matrix for 'vcv$base'", + "Expected a matrix or 3d array for 'vcv$base'", fixed = TRUE) expect_error( mcstate_sampler_nested_random_walk(list(base = NULL, groups = TRUE)), @@ -37,7 +37,7 @@ test_that("validate vcv inputs on construction of sampler", { fixed = TRUE) expect_error( mcstate_sampler_nested_random_walk(list(base = NULL, groups = list(TRUE))), - "Expected a matrix for 'vcv$groups[1]'", + "Expected a matrix or 3d array for 'vcv$groups[1]'", fixed = TRUE) vcv <- list(base = diag(1), groups = list(diag(2), diag(3))) @@ -92,7 +92,8 @@ test_that("can build nested proposal functions", { vcv <- list(base = NULL, groups = list(v, v / 100)) g <- c(1, 1, 2, 2) - f <- nested_proposal(vcv, g) + + f <- nested_proposal(vcv, g, rep(0, 4)) expect_null(f$base) expect_true(is.function(f$groups)) @@ -112,8 +113,9 @@ test_that("can build nested proposal functions with base components", { v <- matrix(c(1, .5, .5, 1), 2, 2) vcv <- list(base = v / 10, groups = list(v, v / 100)) - f <- nested_proposal(vcv, c(0, 0, 1, 1, 2, 2)) - g <- nested_proposal(list(base = NULL, groups = vcv$groups), c(1, 1, 2, 2)) + f <- nested_proposal(vcv, c(0, 0, 1, 1, 2, 2), 1:6) + g <- + nested_proposal(list(base = NULL, groups = vcv$groups),c(1, 1, 2, 2), 3:6) expect_true(is.function(f$base)) expect_true(is.function(f$groups)) From 49a22254e32370835c07fffcb8787fe6c7768765 Mon Sep 17 00:00:00 2001 From: edknock Date: Thu, 25 Jul 2024 15:51:21 +0100 Subject: [PATCH 03/10] add boundaries to nested random walk sampler --- R/sampler-nested-random-walk.R | 103 ++++++++++++++++-- man/mcstate_sampler_nested_random_walk.Rd | 19 +++- .../test-sampler-nested-random-walk.R | 34 +++++- 3 files changed, 138 insertions(+), 18 deletions(-) diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index f79b2ef5..6995d5d1 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -42,12 +42,28 @@ ##' @param vcv A list of variance covariance matrices. We expect this ##' to be a list with elements `base` and `groups` corresponding to ##' the covariance matrix for base parameters (if any) and groups. +##' +##' ##' @param boundaries Control the behaviour of proposals that are +##' outside the model domain. The supported options are: +##' +##' * "reflect" (the default): we reflect proposed parameters that +##' lie outside the domain back into the domain (as many times as +##' needed) +##' +##' * "reject": we do not evaluate the density function, and return +##' `-Inf` for its density instead. +##' +##' * "ignore": evaluate the point anyway, even if it lies outside +##' the domain. +##' +##' The initial point selected will lie within the domain, as this is +##' enforced by [mcstate_sample]. ##' ##' @return A `mcstate_sampler` object, which can be used with ##' [mcstate_sample] ##' ##' @export -mcstate_sampler_nested_random_walk <- function(vcv) { +mcstate_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { if (!is.list(vcv)) { cli::cli_abort( "Expected a list for 'vcv'", @@ -74,6 +90,8 @@ mcstate_sampler_nested_random_walk <- function(vcv) { internal <- new.env(parent = emptyenv()) + boundaries <- match_value(boundaries, c("reflect", "reject", "ignore")) + initialise <- function(pars, model, observer, rng) { if (!model$properties$has_parameter_groups) { cli::cli_abort("Your model does not have parameter groupings") @@ -85,7 +103,8 @@ mcstate_sampler_nested_random_walk <- function(vcv) { stopifnot(model$properties$allow_multiple_parameters) } - internal$proposal <- nested_proposal(vcv, model$parameter_groups, pars) + internal$proposal <- nested_proposal(vcv, model$parameter_groups, pars, + model$domain, boundaries) initialise_rng_state(model, rng) density <- model$density(pars, by_group = TRUE) @@ -127,13 +146,28 @@ mcstate_sampler_nested_random_walk <- function(vcv) { step <- function(state, model, observer, rng) { if (!is.null(internal$proposal$base)) { pars_next <- internal$proposal$base(state$pars, rng) - density_next <- model$density(pars_next, by_group = TRUE) - density_by_group_next <- attr(density_next, "by_group") + + reject_some <- boundaries == "reject" && + !all(i <- is_parameters_in_domain(pars_next, model$domain)) + if (reject_some) { + density_next <- rep(-Inf, length(state$density)) + density_by_group_next <- array(-Inf, dim2(internal$density_by_group)) + if (any(i)) { + density_next_i <- model$density(pars_next[, i, drop = FALSE], + by_group = TRUE) + density_next[i] <- density_next_i + density_by_group_next[, i] <- attr(density_next_i, "by_group") + } + } else { + density_next <- model$density(pars_next, by_group = TRUE) + density_by_group_next <- attr(density_next, "by_group") + } + accept <- density_next - state$density > log(rng$random_real(1)) if (any(accept)) { if (!all(accept)) { ## Retain some older parameters - i <- which(!accept) + i <- which(!accept) pars_next[, i] <- state$pars[, i] density_next <- model$density(pars_next, by_group = TRUE) density_by_group_next <- attr(density_next, "by_group") @@ -148,8 +182,35 @@ mcstate_sampler_nested_random_walk <- function(vcv) { } pars_next <- internal$proposal$groups(state$pars, rng) - density_next <- model$density(pars_next, by_group = TRUE) - density_by_group_next <- attr(density_next, "by_group") + + reject_some <- boundaries == "reject" && + !all(i <- is_parameters_in_domain_groups(pars_next, model$domain, + model$parameter_groups)) + + if (reject_some) { + if (any(i)) { + if (internal$multiple_parameters) { + for (j in seq_len(ncol(i))) { + if (!all(i[, j])) { + i_group <- model$parameter_groups %in% which(!i[, j]) + pars_next[i_group, j] <- state$pars[i_group, j] + } + } + } else { + i_group <- model$parameter_groups %in% which(!i) + pars_next[i_group] <- state$pars[i_group] + } + density_next <- model$density(pars_next, by_group = TRUE) + density_by_group_next <- attr(density_next, "by_group") + } else { + density_next <- rep(-Inf, length(state$density)) + density_by_group_next <- array(-Inf, dim2(internal$density_by_group)) + } + } else { + density_next <- model$density(pars_next, by_group = TRUE) + density_by_group_next <- attr(density_next, "by_group") + } + accept <- density_by_group_next - internal$density_by_group > log(rng$random_real(dim2(density_by_group_next)[1])) @@ -229,7 +290,8 @@ check_parameter_groups <- function(x, n_pars, name = deparse(substitute(x)), } -nested_proposal <- function(vcv, parameter_groups, pars, call = NULL) { +nested_proposal <- function(vcv, parameter_groups, pars, domain, + boundaries = "reflect", call = NULL) { i_base <- parameter_groups == 0 n_base <- sum(i_base) n_groups <- max(parameter_groups) @@ -266,7 +328,8 @@ nested_proposal <- function(vcv, parameter_groups, pars, call = NULL) { if (has_base) { if (is.matrix(pars)) { vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base, , drop = FALSE]) - mvn_base <- make_rmvnorm(vcv$base) + mvn_base <- make_random_walk_proposal( + vcv$base, domain[i_base, , drop = FALSE], boundaries) proposal_base <- function(x, rng) { ## This approach is likely to be a bit fragile, so we'll ## probably want some naming related verification here soon too. @@ -275,7 +338,8 @@ nested_proposal <- function(vcv, parameter_groups, pars, call = NULL) { } } else { vcv$base <- sampler_validate_vcv(vcv$base, pars[i_base]) - mvn_base <- make_rmvnorm(vcv$base) + mvn_base <- make_random_walk_proposal( + vcv$base, domain[i_base, , drop = FALSE], boundaries) proposal_base <- function(x, rng) { ## This approach is likely to be a bit fragile, so we'll ## probably want some naming related verification here soon too. @@ -298,7 +362,10 @@ nested_proposal <- function(vcv, parameter_groups, pars, call = NULL) { sampler_validate_vcv(vcv$groups[[i]], pars[i_group[[i]]]) } } - mvn_groups <- lapply(vcv$groups, make_rmvnorm) + mvn_groups <- lapply(seq_len(n_groups), function (i) + make_random_walk_proposal(vcv$groups[[i]], + domain[i_group[[i]], , drop = FALSE], + boundaries)) proposal_groups <- function(x, rng) { for (i in seq_len(n_groups)) { if (is.matrix(x)) { @@ -313,3 +380,17 @@ nested_proposal <- function(vcv, parameter_groups, pars, call = NULL) { list(base = proposal_base, groups = proposal_groups) } + +is_parameters_in_domain_groups <- function(x, domain, parameter_groups) { + x_min <- domain[, 1] + x_max <- domain[, 2] + i <- x > x_min & x < x_max + n_groups <- max(parameter_groups) + i_group <- lapply(seq_len(n_groups), function(i) which(parameter_groups == i)) + if (is.matrix(x)) { + t(vapply(i_group, function(j) apply(i[j, , drop = FALSE], 2, all), + logical(ncol(x)))) + } else { + vapply(i_group, function(j) all(i[j]), logical(1L)) + } +} diff --git a/man/mcstate_sampler_nested_random_walk.Rd b/man/mcstate_sampler_nested_random_walk.Rd index 265409f3..f174573f 100644 --- a/man/mcstate_sampler_nested_random_walk.Rd +++ b/man/mcstate_sampler_nested_random_walk.Rd @@ -4,12 +4,27 @@ \alias{mcstate_sampler_nested_random_walk} \title{Nested Random Walk Sampler} \usage{ -mcstate_sampler_nested_random_walk(vcv) +mcstate_sampler_nested_random_walk(vcv, boundaries = "reflect") } \arguments{ \item{vcv}{A list of variance covariance matrices. We expect this to be a list with elements \code{base} and \code{groups} corresponding to -the covariance matrix for base parameters (if any) and groups.} +the covariance matrix for base parameters (if any) and groups. + +##' @param boundaries Control the behaviour of proposals that are +outside the model domain. The supported options are: +\itemize{ +\item "reflect" (the default): we reflect proposed parameters that +lie outside the domain back into the domain (as many times as +needed) +\item "reject": we do not evaluate the density function, and return +\code{-Inf} for its density instead. +\item "ignore": evaluate the point anyway, even if it lies outside +the domain. +} + +The initial point selected will lie within the domain, as this is +enforced by \link{mcstate_sample}.} } \value{ A \code{mcstate_sampler} object, which can be used with diff --git a/tests/testthat/test-sampler-nested-random-walk.R b/tests/testthat/test-sampler-nested-random-walk.R index 39bcece6..a3709473 100644 --- a/tests/testthat/test-sampler-nested-random-walk.R +++ b/tests/testthat/test-sampler-nested-random-walk.R @@ -92,8 +92,9 @@ test_that("can build nested proposal functions", { vcv <- list(base = NULL, groups = list(v, v / 100)) g <- c(1, 1, 2, 2) + domain <- t(array(c(-Inf, Inf), c(2, 4))) - f <- nested_proposal(vcv, g, rep(0, 4)) + f <- nested_proposal(vcv, g, rep(0, 4), domain) expect_null(f$base) expect_true(is.function(f$groups)) @@ -113,9 +114,11 @@ test_that("can build nested proposal functions with base components", { v <- matrix(c(1, .5, .5, 1), 2, 2) vcv <- list(base = v / 10, groups = list(v, v / 100)) - f <- nested_proposal(vcv, c(0, 0, 1, 1, 2, 2), 1:6) - g <- - nested_proposal(list(base = NULL, groups = vcv$groups),c(1, 1, 2, 2), 3:6) + domain <- t(array(c(-Inf, Inf), c(2, 6))) + f <- nested_proposal(vcv, c(0, 0, 1, 1, 2, 2), 1:6, domain) + g <- nested_proposal(list(base = NULL, groups = vcv$groups), c(1, 1, 2, 2), + 3:6, domain[3:6, ]) + expect_true(is.function(f$base)) expect_true(is.function(f$groups)) @@ -215,4 +218,25 @@ test_that("can run nested random walk sampler simultaneously", { runner <- mcstate_runner_simultaneous() res2 <- mcstate_sample(m, sampler, 100, n_chains = 3, runner = runner) expect_equal(res1, res2) -}) \ No newline at end of file +}) + + +test_that("can run nested random walk sampler with rejecting boundaries + simultaneously", { + set.seed(1) + ng <- 5 + m <- ex_simple_nested_with_base(ng) + m$domain[, 1] <- -3 + m$domain[, 2] <- 3 + sampler <- mcstate_sampler_nested_random_walk( + list(base = diag(1), groups = rep(list(diag(1)), ng)), + boundaries = "reject") + + set.seed(1) + res1 <- mcstate_sample(m, sampler, 100, n_chains = 3) + + set.seed(1) + runner <- mcstate_runner_simultaneous() + res2 <- mcstate_sample(m, sampler, 100, n_chains = 3, runner = runner) + expect_equal(res1, res2) +}) From 0d63f3cc33afb48de90fc8bef1a2f41b8484dbb3 Mon Sep 17 00:00:00 2001 From: edknock Date: Thu, 25 Jul 2024 16:11:55 +0100 Subject: [PATCH 04/10] code coverage --- R/sampler-nested-random-walk.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index 6995d5d1..197966e1 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -188,6 +188,8 @@ mcstate_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { model$parameter_groups)) if (reject_some) { + density_next <- rep(-Inf, length(state$density)) + density_by_group_next <- array(-Inf, dim2(internal$density_by_group)) if (any(i)) { if (internal$multiple_parameters) { for (j in seq_len(ncol(i))) { @@ -202,9 +204,6 @@ mcstate_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { } density_next <- model$density(pars_next, by_group = TRUE) density_by_group_next <- attr(density_next, "by_group") - } else { - density_next <- rep(-Inf, length(state$density)) - density_by_group_next <- array(-Inf, dim2(internal$density_by_group)) } } else { density_next <- model$density(pars_next, by_group = TRUE) From 9590a2bcaf095c1955d2805cdc47fcf91579884c Mon Sep 17 00:00:00 2001 From: edknock Date: Thu, 25 Jul 2024 16:19:43 +0100 Subject: [PATCH 05/10] add some comments --- R/sampler-nested-random-walk.R | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index 197966e1..d0d28c12 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -187,6 +187,11 @@ mcstate_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { !all(i <- is_parameters_in_domain_groups(pars_next, model$domain, model$parameter_groups)) + ## This bit is potentially inefficient - for any proposed parameters out of + ## bounds I substitute in the current parameters, so that we can run the + ## density on all groups. Ideally we would want to only run the density on + ## groups with all parameters in bounds. A bit fiddly to do that in a nice + ## way when doing simultaneous sampling if (reject_some) { density_next <- rep(-Inf, length(state$density)) density_by_group_next <- array(-Inf, dim2(internal$density_by_group)) @@ -227,7 +232,9 @@ mcstate_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { i <- model$parameter_groups %in% which(!accept) pars_next[i] <- state$pars[i] } - + ## If e.g. density is provided by a particle filter, would this bit + ## mean rerunning it? Increases time cost if so, and would result in + ## new value of density (different to that which was accepted) density_next <- model$density(pars_next, by_group = TRUE) density_by_group_next <- attr(density_next, "by_group") } From 1f99f8503c5a4de111fc1951fc9045be797e7561 Mon Sep 17 00:00:00 2001 From: edknock Date: Fri, 26 Jul 2024 12:07:36 +0100 Subject: [PATCH 06/10] fix docs --- R/sampler-nested-random-walk.R | 2 +- man/mcstate_sampler_nested_random_walk.Rd | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index d0d28c12..92063316 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -43,7 +43,7 @@ ##' to be a list with elements `base` and `groups` corresponding to ##' the covariance matrix for base parameters (if any) and groups. ##' -##' ##' @param boundaries Control the behaviour of proposals that are +##' @param boundaries Control the behaviour of proposals that are ##' outside the model domain. The supported options are: ##' ##' * "reflect" (the default): we reflect proposed parameters that diff --git a/man/mcstate_sampler_nested_random_walk.Rd b/man/mcstate_sampler_nested_random_walk.Rd index f174573f..7c876462 100644 --- a/man/mcstate_sampler_nested_random_walk.Rd +++ b/man/mcstate_sampler_nested_random_walk.Rd @@ -9,9 +9,9 @@ mcstate_sampler_nested_random_walk(vcv, boundaries = "reflect") \arguments{ \item{vcv}{A list of variance covariance matrices. We expect this to be a list with elements \code{base} and \code{groups} corresponding to -the covariance matrix for base parameters (if any) and groups. +the covariance matrix for base parameters (if any) and groups.} -##' @param boundaries Control the behaviour of proposals that are +\item{boundaries}{Control the behaviour of proposals that are outside the model domain. The supported options are: \itemize{ \item "reflect" (the default): we reflect proposed parameters that From a49ecd0ebcd8da37f4eb9fc84238bf43d6649861 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 21 Aug 2024 13:48:35 +0100 Subject: [PATCH 07/10] Bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index fb262cc2..14f4f2a6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mcstate2 Title: Next Generation mcstate -Version: 0.1.8 +Version: 0.1.9 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Imperial College of Science, Technology and Medicine", From aa864f975277220f4cbc1609d48f11b28a0d2f4e Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 21 Aug 2024 13:50:01 +0100 Subject: [PATCH 08/10] Update R/sampler-nested-random-walk.R --- R/sampler-nested-random-walk.R | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index 2e4ff8f1..1188e305 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -43,21 +43,7 @@ ##' to be a list with elements `base` and `groups` corresponding to ##' the covariance matrix for base parameters (if any) and groups. ##' -##' @param boundaries Control the behaviour of proposals that are -##' outside the model domain. The supported options are: -##' -##' * "reflect" (the default): we reflect proposed parameters that -##' lie outside the domain back into the domain (as many times as -##' needed) -##' -##' * "reject": we do not evaluate the density function, and return -##' `-Inf` for its density instead. -##' -##' * "ignore": evaluate the point anyway, even if it lies outside -##' the domain. -##' -##' The initial point selected will lie within the domain, as this is -##' enforced by [mcstate_sample]. +##" @inheritParams mcstate_sampler_random_walk ##' ##' @return A `mcstate_sampler` object, which can be used with ##' [mcstate_sample] From bbdbbdf393973f8d66e0b5186ad2263f8ce51591 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 21 Aug 2024 13:51:10 +0100 Subject: [PATCH 09/10] Fix and redocument --- R/sampler-nested-random-walk.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index 1188e305..b295e874 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -43,7 +43,7 @@ ##' to be a list with elements `base` and `groups` corresponding to ##' the covariance matrix for base parameters (if any) and groups. ##' -##" @inheritParams mcstate_sampler_random_walk +##' @inheritParams mcstate_sampler_random_walk ##' ##' @return A `mcstate_sampler` object, which can be used with ##' [mcstate_sample] From 7f2d7fede69c51444f185e6bf8345a63ebd60f6f Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 21 Aug 2024 14:12:48 +0100 Subject: [PATCH 10/10] Add integration helper until #46 is merged --- R/sampler-nested-adaptive.R | 69 +++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/R/sampler-nested-adaptive.R b/R/sampler-nested-adaptive.R index 7f919921..9f19417a 100644 --- a/R/sampler-nested-adaptive.R +++ b/R/sampler-nested-adaptive.R @@ -159,7 +159,8 @@ mcstate_sampler_nested_adaptive <- function(initial_vcv, list(base = autocorrelation_base, groups = autocorrelation_groups) internal$vcv <- list(base = vcv_base, groups = vcv_groups) proposal_vcv <- list(base = proposal_vcv_base, groups = proposal_vcv_groups) - internal$proposal <- nested_proposal(proposal_vcv, model$parameter_groups) + internal$proposal <- nested_proposal_adaptive(proposal_vcv, + model$parameter_groups) internal$history_pars <- numeric() internal$included <- integer() @@ -311,7 +312,8 @@ mcstate_sampler_nested_adaptive <- function(initial_vcv, ## Update proposal proposal_vcv <- list(base = proposal_vcv_base, groups = proposal_vcv_groups) - internal$proposal <- nested_proposal(proposal_vcv, model$parameter_groups) + internal$proposal <- nested_proposal_adaptive(proposal_vcv, + model$parameter_groups) state } @@ -425,3 +427,66 @@ check_nested_adaptive <- function(x, n_groups, has_base, null_allowed = FALSE, ret } + + +## TODO: this is a simpler version of nested_proposal that does not +## cope with boundaries etc - that's being looked at in #46 for now. +## nocov start +nested_proposal_adaptive <- function(vcv, parameter_groups, call = NULL) { + i_base <- parameter_groups == 0 + n_base <- sum(i_base) + n_groups <- max(parameter_groups) + i_group <- lapply(seq_len(n_groups), function(i) which(parameter_groups == i)) + if (NROW(vcv$base) != n_base) { + cli::cli_abort( + c("Incompatible number of base parameters in your model and sampler", + i = paste("Your model has {n_base} base parameters, but 'vcv$base'", + "implies {NROW(vcv$base)} parameters")), + call = call) + } + if (length(vcv$groups) != n_groups) { + cli::cli_abort( + c("Incompatible number of parameter groups in your model and sampler", + i = paste("Your model has {n_groups} parameter groups, but", + "'vcv$groups' has {length(vcv$groups)} groups")), + call = call) + } + n_pars_by_group <- lengths(i_group) + n_pars_by_group_vcv <- vnapply(vcv$groups, nrow) + err <- n_pars_by_group_vcv != n_pars_by_group + if (any(err)) { + detail <- sprintf( + "Group %d has %d parameters but 'vcv$groups[[%d]]' has %d", + which(err), n_pars_by_group[err], + which(err), n_pars_by_group_vcv[err]) + cli::cli_abort( + c("Incompatible number of parameters within parameter group", + set_names(detail, "i")), + call = call) + } + + has_base <- n_base > 0 + if (has_base) { + mvn_base <- make_rmvnorm(vcv$base) + proposal_base <- function(x, rng) { + ## This approach is likely to be a bit fragile, so we'll + ## probably want some naming related verification here soon too. + x[i_base] <- mvn_base(x[i_base], rng) + x + } + } else { + proposal_base <- NULL + } + + mvn_groups <- lapply(vcv$groups, make_rmvnorm) + proposal_groups <- function(x, rng) { + for (i in seq_len(n_groups)) { + x[i_group[[i]]] <- mvn_groups[[i]](x[i_group[[i]]], rng) + } + x + } + + list(base = proposal_base, + groups = proposal_groups) +} +## nocov end