Skip to content

Commit

Permalink
Merge branch 'main' into implement-rerun
Browse files Browse the repository at this point in the history
  • Loading branch information
edknock committed Oct 2, 2024
2 parents e3e44ac + 3c0ffa2 commit 7314c00
Show file tree
Hide file tree
Showing 8 changed files with 854 additions and 320 deletions.
272 changes: 209 additions & 63 deletions R/sampler-adaptive.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,21 @@
##' better, but while the chains find the location and scale of the posterior
##' mode it might be useful to explore with it switched off.
##'
##' @param boundaries Control the behaviour of proposals that are
##' outside the model domain. The supported options are:
##'
##' * "reflect" (the default): we reflect proposed parameters that
##' lie outside the domain back into the domain (as many times as
##' needed)
##'
##' * "reject": we do not evaluate the density function, and return
##' `-Inf` for its density instead.
##'
##' * "ignore": evaluate the point anyway, even if it lies outside
##' the domain.
##'
##' The initial point selected will lie within the domain, as this is
##' enforced by [monty_sample].
##'
##' @return A `monty_sampler` object, which can be used with
##' [monty_sample]
Expand Down Expand Up @@ -127,99 +142,136 @@ monty_sampler_adaptive <- function(initial_vcv,
forget_rate = 0.2,
forget_end = Inf,
adapt_end = Inf,
pre_diminish = 0) {
pre_diminish = 0,
boundaries = "reflect") {
## This sampler is stateful; we will be updating our estimate of the
## mean and vcv of the target distribution, along with the our
## scaling factor, weight and autocorrelations.
check_vcv(initial_vcv, allow_3d = TRUE, call = environment())
internal <- new.env()

boundaries <- match_value(boundaries, c("reflect", "reject", "ignore"))

initialise <- function(pars, model, observer, rng) {
require_deterministic(model,
"Can't use adaptive sampler with stochastic models")
if (is.matrix(pars)) {
cli::cli_abort(
"Can't use 'monty_sampler_adaptive' with simultaneous chains")

internal$multiple_parameters <- length(dim2(pars)) > 1
if (internal$multiple_parameters) {
## this is enforced elsewhere
stopifnot(model$properties$allow_multiple_parameters)
}

initial_vcv <- sampler_validate_vcv(initial_vcv, pars)

if (internal$multiple_parameters) {
internal$adaptive <-
Map(initialise_adaptive,
lapply(asplit(pars, 2), c),
asplit(initial_vcv, 3),
MoreArgs = list(initial_vcv_weight = initial_vcv_weight,
initial_scaling = initial_scaling,
initial_scaling_weight = initial_scaling_weight,
min_scaling = min_scaling,
scaling_increment = scaling_increment,
log_scaling_update = log_scaling_update,
acceptance_target = acceptance_target,
forget_rate = forget_rate,
forget_end = forget_end,
adapt_end = adapt_end,
pre_diminish = pre_diminish)
)
} else {
internal$adaptive <-
initialise_adaptive(pars,
initial_vcv,
initial_vcv_weight,
initial_scaling,
initial_scaling_weight,
min_scaling,
scaling_increment,
log_scaling_update,
acceptance_target,
forget_rate,
forget_end,
adapt_end,
pre_diminish)
}
internal$weight <- 0
internal$iteration <- 0

internal$mean <- unname(pars)
n_pars <- length(model$parameters)
internal$autocorrelation <- matrix(0, n_pars, n_pars)
internal$vcv <- update_vcv(internal$mean, internal$autocorrelation,
internal$weight)

internal$scaling <- initial_scaling
internal$scaling_increment <- scaling_increment %||%
calc_scaling_increment(n_pars, acceptance_target,
log_scaling_update)
internal$scaling_weight <- initial_scaling_weight %||%
5 / (acceptance_target * (1 - acceptance_target))

internal$history_pars <- numeric()
internal$included <- integer()
internal$scaling_history <- internal$scaling

initialise_state(pars, model, observer, rng)
}

step <- function(state, model, observer, rng) {
proposal_vcv <-
calc_proposal_vcv(internal$scaling, internal$vcv, internal$weight,
initial_vcv, initial_vcv_weight)
if (internal$multiple_parameters) {
d <- dim(state$pars)
proposal_vcv <-
vapply(seq_len(d[2]),
function (i)
calc_proposal_vcv(internal$adaptive[[i]]$scaling,
internal$adaptive[[i]]$vcv,
internal$adaptive[[i]]$weight,
internal$adaptive[[i]]$initial_vcv,
internal$adaptive[[i]]$initial_vcv_weight),
array(0, c(d[1], d[1])))
proposal_vcv <- array(proposal_vcv, c(d[1], d[1], d[2]))
} else {
proposal_vcv <-
calc_proposal_vcv(internal$adaptive$scaling,
internal$adaptive$vcv,
internal$adaptive$weight,
internal$adaptive$initial_vcv,
internal$adaptive$initial_vcv_weight)
}

pars_next <- rmvnorm(state$pars, proposal_vcv, rng)
proposal <-
make_random_walk_proposal(proposal_vcv, model$domain, boundaries)
pars_next <- proposal(state$pars, rng)

u <- rng$random_real(1)
density_next <- model$density(pars_next)
reject_some <- boundaries == "reject" &&
!all(i <- is_parameters_in_domain(pars_next, model$domain))
if (reject_some) {
density_next <- rep(-Inf, length(state$density))
if (any(i)) {
density_next[i] <- model$density(pars_next[, i, drop = FALSE])
}
} else {
density_next <- model$density(pars_next)
}

accept_prob <- min(1, exp(density_next - state$density))
accept_prob <- pmin(1, exp(density_next - state$density))

accept <- u < accept_prob
state <- update_state(state, pars_next, density_next, accept,
model, observer, rng)

internal$iteration <- internal$iteration + 1
internal$history_pars <- rbind(internal$history_pars, state$pars)
if (internal$iteration > adapt_end) {
internal$scaling_history <- c(internal$scaling_history, internal$scaling)
return(state)
}

if (internal$iteration > pre_diminish) {
internal$scaling_weight <- internal$scaling_weight + 1
}

is_replacement <-
check_replacement(internal$iteration, forget_rate, forget_end)
if (is_replacement) {
pars_remove <- internal$history_pars[internal$included[1L], ]
internal$included <- c(internal$included[-1L], internal$iteration)
if (internal$multiple_parameters) {
internal$adaptive <-
lapply(seq_len(dim(state$pars)[2]),
function (i) update_adaptive(internal$adaptive[[i]],
state$pars[, i],
accept_prob[i]))
} else {
pars_remove <- NULL
internal$included <- c(internal$included, internal$iteration)
internal$weight <- internal$weight + 1
internal$adaptive <-
update_adaptive(internal$adaptive, state$pars, accept_prob)
}

internal$scaling <-
update_scaling(internal$scaling, internal$scaling_weight, accept_prob,
internal$scaling_increment, min_scaling, acceptance_target,
log_scaling_update)
internal$scaling_history <- c(internal$scaling_history, internal$scaling)
internal$autocorrelation <- update_autocorrelation(
state$pars, internal$weight, internal$autocorrelation, pars_remove)
internal$mean <- update_mean(state$pars, internal$weight, internal$mean,
pars_remove)
internal$vcv <- update_vcv(internal$mean, internal$autocorrelation,
internal$weight)

state
}

finalise <- function(state, model, rng) {
out <- as.list(internal)
out[c("autocorrelation", "mean", "vcv", "weight", "included",
"scaling_history", "scaling_weight", "scaling_increment")]
out <- internal$adaptive

keep_adaptive <- c("autocorrelation", "mean", "vcv", "weight", "included",
"scaling_history", "scaling_weight", "scaling_increment")

if (internal$multiple_parameters) {
out <- lapply(out, function(x) x[keep_adaptive])
} else {
out <- out[keep_adaptive]
}

out
}

get_internal_state <- function() {
Expand All @@ -240,6 +292,100 @@ monty_sampler_adaptive <- function(initial_vcv,
}


initialise_adaptive <- function(pars,
initial_vcv,
initial_vcv_weight,
initial_scaling,
initial_scaling_weight,
min_scaling,
scaling_increment,
log_scaling_update,
acceptance_target,
forget_rate,
forget_end,
adapt_end,
pre_diminish) {
weight <- 0
iteration <- 0

mean <- unname(pars)
n_pars <- length(pars)
autocorrelation <- array(0, dim(initial_vcv))
vcv <- update_vcv(mean, autocorrelation, weight)

scaling <- initial_scaling

scaling_increment <- scaling_increment %||%
calc_scaling_increment(n_pars, acceptance_target, log_scaling_update)
scaling_weight <- initial_scaling_weight %||%
5 / (acceptance_target * (1 - acceptance_target))

history_pars <- NULL
included <- integer()
scaling_history <- scaling

list(initial_vcv = initial_vcv,
initial_vcv_weight = initial_vcv_weight,
weight = weight,
iteration = iteration,
mean = mean,
autocorrelation = autocorrelation,
vcv = vcv,
scaling = scaling,
scaling_increment = scaling_increment,
scaling_weight = scaling_weight,
min_scaling = min_scaling,
log_scaling_update = log_scaling_update,
acceptance_target = acceptance_target,
forget_rate = forget_rate,
forget_end = forget_end,
adapt_end = adapt_end,
pre_diminish = pre_diminish,
history_pars = history_pars,
included = included,
scaling_history = scaling_history
)
}

update_adaptive <- function(adaptive, pars, accept_prob) {
adaptive$iteration <- adaptive$iteration + 1
adaptive$history_pars <- rbind(adaptive$history_pars, pars)
if (adaptive$iteration > adaptive$adapt_end) {
adaptive$scaling_history <- c(adaptive$scaling_history, adaptive$scaling)
return(adaptive)
}

if (adaptive$iteration > adaptive$pre_diminish) {
adaptive$scaling_weight <- adaptive$scaling_weight + 1
}

is_replacement <- check_replacement(adaptive$iteration, adaptive$forget_rate,
adaptive$forget_end)
if (is_replacement) {
pars_remove <- adaptive$history_pars[adaptive$included[1L], ]
adaptive$included <- c(adaptive$included[-1L], adaptive$iteration)
} else {
pars_remove <- NULL
adaptive$included <- c(adaptive$included, adaptive$iteration)
adaptive$weight <- adaptive$weight + 1
}

adaptive$scaling <-
update_scaling(adaptive$scaling, adaptive$scaling_weight, accept_prob,
adaptive$scaling_increment, adaptive$min_scaling,
adaptive$acceptance_target, adaptive$log_scaling_update)
adaptive$scaling_history <- c(adaptive$scaling_history, adaptive$scaling)
adaptive$autocorrelation <- update_autocorrelation(
pars, adaptive$weight, adaptive$autocorrelation, pars_remove)
adaptive$mean <- update_mean(pars, adaptive$weight, adaptive$mean,
pars_remove)
adaptive$vcv <- update_vcv(adaptive$mean, adaptive$autocorrelation,
adaptive$weight)

adaptive
}


calc_scaling_increment <- function(n_pars, acceptance_target,
log_scaling_update) {
if (log_scaling_update) {
Expand Down
Loading

0 comments on commit 7314c00

Please sign in to comment.