Skip to content

Commit

Permalink
Avoid redundant transformation for forecasts with bootstrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchelloharawild committed Sep 15, 2024
1 parent 445d65f commit ff30434
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 70 deletions.
96 changes: 47 additions & 49 deletions R/forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,14 @@ forecast.mdl_ts <- function(object, new_data = NULL, h = NULL, bias_adjust = NUL
# Compute forecasts
if(simulate || bootstrap) {
fc <- generate(object, new_data, bootstrap = bootstrap, times = times, ...)
fc <- unname(split(object$transformation[[1]](fc[[".sim"]]), fc[[index_var(fc)]]))
fc <- distributional::dist_sample(fc)
fc_idx <- fc[[index_var(fc)]]
fc <- if (length(resp_vars) > 1) {
do.call(cbind, fc[resp_vars])
} else {
fc[[".sim"]]
}

fc <- distributional::dist_sample(vctrs::vec_split(fc, fc_idx)$val)
} else {
# Compute specials with new_data
object$model$stage <- "forecast"
Expand All @@ -180,54 +186,46 @@ forecast.mdl_ts <- function(object, new_data = NULL, h = NULL, bias_adjust = NUL
object$model$remove_data()
object$model$stage <- NULL
fc <- forecast(object$fit, new_data, specials = specials, times = times, ...)
}

# Back-transform forecast distributions
bt <- map(object$transformation, function(x){
trans <- x%@%"inverse"
inv_trans <- `attributes<-`(x, NULL)
req_vars <- setdiff(all.vars(body(trans)), names(formals(trans)))
if(any(req_vars %in% names(new_data))) {
trans <- lapply(
vec_chop(new_data[req_vars]),
function(transform_data) {
set_env(trans, new_environment(transform_data, get_env(trans)))
}
)
attr(trans, "inverse") <- lapply(
vec_chop(new_data[req_vars]),
function(transform_data) {
set_env(inv_trans, new_environment(transform_data, get_env(inv_trans)))
}
)
trans
} else {
structure(list(trans), inverse = list(inv_trans))

# Back-transform forecast distributions
bt <- map(object$transformation, function(x){
trans <- x%@%"inverse"
inv_trans <- `attributes<-`(x, NULL)
req_vars <- setdiff(all.vars(body(trans)), names(formals(trans)))
if(any(req_vars %in% names(new_data))) {
trans <- lapply(
vec_chop(new_data[req_vars]),
function(transform_data) {
set_env(trans, new_environment(transform_data, get_env(trans)))
}
)
attr(trans, "inverse") <- lapply(
vec_chop(new_data[req_vars]),
function(transform_data) {
set_env(inv_trans, new_environment(transform_data, get_env(inv_trans)))
}
)
trans
} else {
structure(list(trans), inverse = list(inv_trans))
}
})

is_transformed <- vapply(bt, function(x) !is_symbol(body(x[[1]])), logical(1L))
if(length(bt) > 1) {
if(any(is_transformed)){
abort("Transformations of multivariate forecasts distributions are not supported, use simulate = TRUE or bootstrap = TRUE.")
}
}
# exists_vars <- map_lgl(req_vars, exists, env)
# if(any(!exists_vars)){
# bt <- custom_error(bt, sprintf(
# "Unable to find all required variables to back-transform the forecasts (missing %s).
# These required variables can be provided by specifying `new_data`.",
# paste0("`", req_vars[!exists_vars], "`", collapse = ", ")
# ))
# }
})

is_transformed <- vapply(bt, function(x) !is_symbol(body(x[[1]])), logical(1L))
if(length(bt) > 1) {
if(any(is_transformed)){
abort("Transformations of multivariate forecasts are not yet supported")
}
}
if(any(is_transformed)) {
if (identical(unique(dist_types(fc)), "dist_sample")) {
fc <- distributional::dist_sample(
.mapply(exec, list(bt[[1]], distributional::parameters(fc)$x), MoreArgs = NULL)
)
} else {
bt <- bt[[1]]
fc <- distributional::dist_transformed(fc, `attributes<-`(bt, NULL), bt%@%"inverse")
if(any(is_transformed)) {
if (identical(unique(dist_types(fc)), "dist_sample")) {
fc <- distributional::dist_sample(
.mapply(exec, list(bt[[1]], distributional::parameters(fc)$x), MoreArgs = NULL)
)
} else {
bt <- bt[[1]]
fc <- distributional::dist_transformed(fc, `attributes<-`(bt, NULL), bt%@%"inverse")
}
}
}

Expand Down
26 changes: 12 additions & 14 deletions R/generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' Innovations are sampled by the model's assumed error distribution.
#' If `bootstrap` is `TRUE`, innovations will be sampled from the model's
#' residuals. If `new_data` contains the `.innov` column, those values will be
#' treated as innovations for the simulated paths..
#' treated as innovations for the simulated paths.
#'
#' @param x A mable.
#' @param new_data The data to be generated (time index and exogenous regressors)
Expand Down Expand Up @@ -77,20 +77,18 @@ generate.mdl_ts <- function(x, new_data = NULL, h = NULL, times = 1, seed = NULL
}

if(bootstrap) {
if(length(x$response) > 1) abort("Generating bootstrap paths from multivariate models is not yet supported.")
res <- residuals(x$fit)
res <- stats::na.omit(res) - mean(res, na.rm = TRUE)
new_data$.innov <- if(bootstrap_block_size == 1) {
sample(res, nrow(new_data), replace = TRUE)
f_mean <- if(length(x$response) == 1) mean else colMeans
res <- stats::na.omit(res) - f_mean(res, na.rm = TRUE)
i <- if(bootstrap_block_size == 1) {
sample.int(NROW(res), nrow(new_data), replace = TRUE)
} else {
if(any(has_gaps(x$data)$.gaps)) abort("Residuals must be regularly spaced without gaps to use a block bootstrap method.")
kr <- tsibble::key_rows(new_data)
# idx <- x$data[[index_var(x$data)]]
# new_idx <- new_data[[index_var(new_data)]]
# block_pos <- ((new_idx - min(idx))%%bootstrap_block_size)+1
innov <- lapply(lengths(kr), function(n) block_bootstrap(res, bootstrap_block_size, size = n))
vec_c(!!!innov)
ki <- lapply(lengths(kr), function(n) block_bootstrap(NROW(res), bootstrap_block_size, size = n))
vec_c(!!!ki)
}
new_data$.innov <- if(length(x$response) == 1) res[i] else res[i,]
}

# Compute specials with new_data
Expand Down Expand Up @@ -124,13 +122,13 @@ Does your model require extra variables to produce simulations?", e$message))
.sim
}

# Draw row positions for a moving-block bootstrap.
#
# Instead of resampling values directly, this returns *indices* into a series
# of length `n`, so the caller can subset either a vector (`res[i]`) or the
# rows of a matrix (`res[i, ]`) with the same draw.
#
# n            Number of observations available to sample from.
# window_size  Length of each contiguous block.
# size         Number of indices to return (defaults to `n`).
#
# Returns a numeric vector of `size` positions in `1:n`, formed by
# concatenating randomly positioned blocks of `window_size` consecutive
# indices and then starting at a random offset within the first block.
block_bootstrap <- function (n, window_size, size = n) {
  if (window_size < 1 || window_size > n) {
    stop(
      "`window_size` must be between 1 and the number of observations.",
      call. = FALSE
    )
  }
  # Two extra blocks guarantee enough indices remain after the random
  # starting offset (at most `window_size - 1`) is discarded.
  n_blocks <- size %/% window_size + 2
  bx <- numeric(n_blocks * window_size)
  # Number of valid starting positions for a full-length block.
  n_starts <- n - window_size + 1
  for (i in seq_len(n_blocks)) {
    block_pos <- sample.int(n_starts, 1)
    bx[((i - 1) * window_size + 1):(i * window_size)] <-
      block_pos:(block_pos + window_size - 1)
  }
  # Random offset so the block boundaries are not always aligned with the
  # start of the generated path.
  start_from <- sample.int(window_size, 1)
  bx[seq(start_from, length.out = size)]
}
4 changes: 3 additions & 1 deletion R/irf.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
#' This function calculates the impulse response function (IRF) of a time series model.
#' The IRF describes how a model's variables react to external shocks over time.
#'
#' If `new_data` contains the `.impulse` column, those values will be
#' treated as impulses for the calculated impulse responses.
#'
#' @param x A fitted model object, such as from a VAR or ARIMA model. This model is used to compute the impulse response.
#' @param impulse A character string specifying the name of the variable that is shocked (the impulse variable).
#' @param ... Additional arguments to be passed to lower-level functions.
#'
#' @details
Expand Down
4 changes: 3 additions & 1 deletion inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@ Blaskowitz
CRPS
Heeyoung
Herwartz
IRF
JRSS
MAAPE
MASE
MatrixM
MinT
ORCID
Sungil
VAR
Wickramasuriya
backtransform
dable
dables
doi
dplyr
emperical
env
erroring
etc
forecast's
Expand All @@ -38,6 +39,7 @@ seasonalities
superceded
tibble
tidyr
tidyselect
tidyverse
tidyverts
tsibble
Expand Down
5 changes: 0 additions & 5 deletions tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ test_that("generate", {
expect_equal(gen_complex$index, yearmonth("1979 Jan") + rep(0:23, 2*2*3))
expect_equal(unique(gen_complex$key), c("fdeaths", "mdeaths"))
expect_equal(unique(gen_complex$.model), c("ets", "lm"))

expect_error(
mbl_mv %>% generate(),
"Generating paths from multivariate models is not yet supported"
)
})

test_that("generate seed setting", {
Expand Down

0 comments on commit ff30434

Please sign in to comment.