Skip to content

Commit

Permalink
split n_samples param into n_sim and itself
Browse files Browse the repository at this point in the history
  • Loading branch information
lshandross committed Sep 6, 2024
1 parent 6f0e8bd commit 1401c96
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
20 changes: 16 additions & 4 deletions R/get_baseline_predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
#' @param origin string specifying the origin to use when making predictions;
#' recommended to be "median" if the temporal resolution is daily and "obs"
#' if weekly or otherwise. Defaults to "obs".
#' @param n_sim integer number of simulations to predict. Defaults to 100000.
#' @param quantile_levels numeric vector of quantile levels to output; set to NULL
#' if quantile outputs not requested. Defaults to NULL.
#' @param n_samples integer number of samples to output (and predict);
#' Defaults to NULL, in which case 100000 samples are generated.
#' @param n_samples integer number of samples to output, which are drawn from the
#' simulated (sample) predictions (hence n_samples <= n_sim). Set to NULL
#' if sample outputs not requested. Defaults to NULL.
#' @param round_predictions boolean specifying whether to round the output
#' predictions to the nearest whole number. Defaults to FALSE
#' @param seed integer specifying a seed to set for reproducible results.
Expand All @@ -35,6 +37,7 @@ get_baseline_predictions <- function(target_ts,
window_size,
effective_horizons,
origin = "obs",
n_sim = 100000,
quantile_levels = NULL,
n_samples = NULL,
round_predictions = FALSE,
Expand All @@ -54,10 +57,19 @@ get_baseline_predictions <- function(target_ts,
cli::cli_abort("{.arg origin} must be only one of {.val valid_origins}")
}

if (!is.numeric(n_sim) || n_sim < 0 || n_sim != trunc(n_sim) || length(n_sim) != 1) {
cli::cli_abort("{.arg n_sim} must be a single, non-negative integer value.")
}

if (any(quantile_levels > 1) || any(quantile_levels < 0)) {
cli::cli_abort("{.arg quantile_levels} must only contain values between 0 and 1.")
}

if (!is.null(n_samples) &&
(n_samples > n_sim || n_samples < 0 || n_samples != trunc(n_samples) || length(n_samples) != 1)) {
cli::cli_abort("{.arg n_samples} must be a single, non-negative integer value.")
}

if (is.null(quantile_levels) && is.null(n_samples)) {
cli::cli_abort("No forecasts requested: both `quantile_levels` and `n_samples` are NULL")
}
Expand All @@ -80,7 +92,7 @@ get_baseline_predictions <- function(target_ts,
# predict
predictions <- baseline_fit |>
stats::predict(
nsim = ifelse(is.null(n_samples), 100000, n_samples),
nsim = n_sim,
horizon = max(effective_horizons),
origin = origin,
force_nonneg = TRUE
Expand All @@ -94,7 +106,7 @@ get_baseline_predictions <- function(target_ts,
function(h) {
data.frame(
horizon = rep(h, n_samples),
value = predictions[, h]
value = predictions[1:n_samples, h]
) |>
tibble::rownames_to_column(var = "output_type_id") |>
dplyr::mutate(
Expand Down
9 changes: 6 additions & 3 deletions man/get_baseline_predictions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions tests/testthat/test-get_baseline_predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ test_that("invalid origin value throws an error", {
)
})

test_that("invalid n_sim value throws an error", {
get_baseline_predictions(target_ts,
transformation = "none",
symmetrize = TRUE,
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = "test",
quantile_levels = NULL,
n_samples = 100) |>
expect_error(
regexp = "`n_sim` must be a single, non-negative integer value.", fixed = TRUE
)
})

test_that("invalid quantile_levels value throws an error", {
get_baseline_predictions(target_ts,
transformation = "none",
Expand All @@ -49,6 +64,21 @@ test_that("invalid quantile_levels value throws an error", {
)
})

test_that("invalid n_samples value throws an error", {
get_baseline_predictions(target_ts,
transformation = "none",
symmetrize = TRUE,
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 100000,
quantile_levels = NULL,
n_samples = 10.5) |>
expect_error(
regexp = "`n_samples` must be a single, non-negative integer value.", fixed = TRUE
)
})

test_that("not requesting any forecasts throws an error", {
get_baseline_predictions(target_ts,
transformation = "none",
Expand All @@ -72,6 +102,7 @@ test_that("only non-negative forecast values are returned", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = NULL,
n_samples = 10000,
round_predictions = FALSE,
Expand All @@ -88,6 +119,7 @@ test_that("results are reproducible", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = NULL,
n_samples = 10000,
round_predictions = FALSE,
Expand All @@ -99,6 +131,7 @@ test_that("results are reproducible", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = NULL,
n_samples = 10000,
round_predictions = FALSE,
Expand All @@ -111,6 +144,7 @@ test_that("results are reproducible", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = c(.1, .5, .9),
n_samples = NULL,
round_predictions = FALSE,
Expand All @@ -122,6 +156,7 @@ test_that("results are reproducible", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = c(.1, .5, .9),
n_samples = NULL,
round_predictions = FALSE,
Expand All @@ -134,6 +169,7 @@ test_that("results are reproducible", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = c(.1, .5, .9),
n_samples = 10000,
round_predictions = FALSE,
Expand All @@ -145,6 +181,7 @@ test_that("results are reproducible", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = c(.1, .5, .9),
n_samples = 10000,
round_predictions = FALSE,
Expand Down Expand Up @@ -187,6 +224,7 @@ test_that("the correct combination of horizon, output types, and output type IDs
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = NULL,
n_samples = 10000,
round_predictions = FALSE,
Expand All @@ -200,6 +238,7 @@ test_that("the correct combination of horizon, output types, and output type IDs
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = c(.1, .5, .9),
n_samples = NULL,
round_predictions = FALSE,
Expand All @@ -213,6 +252,7 @@ test_that("the correct combination of horizon, output types, and output type IDs
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = c(.1, .5, .9),
n_samples = 10000,
round_predictions = FALSE,
Expand Down Expand Up @@ -254,6 +294,7 @@ test_that("forecasts are correctly rounded when requested", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = c(0.1, 0.5, 0.9),
n_samples = 10000,
round_predictions = TRUE,
Expand Down Expand Up @@ -285,6 +326,7 @@ test_that("quantile forecasts are correctly calculated", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = NULL,
n_samples = 10000,
round_predictions = FALSE,
Expand All @@ -307,6 +349,7 @@ test_that("quantile forecasts are correctly calculated", {
window_size = 3,
effective_horizons = 1:4,
origin = "obs",
n_sim = 10000,
quantile_levels = quantile_levels,
n_samples = NULL,
round_predictions = FALSE,
Expand Down

0 comments on commit 1401c96

Please sign in to comment.