-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from reichlab/1-write-get_baseline_predictions-…
…function 1 write get baseline predictions function
- Loading branch information
Showing
13 changed files
with
1,532 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
^trendsEnsemble\.Rproj$ | ||
^\.Rproj\.user$ | ||
^\.github$ | ||
^\.lintr$ | ||
^README\.Rmd$ | ||
^LICENSE\.md$ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples | ||
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help | ||
on: | ||
push: | ||
branches: [main, master] | ||
pull_request: | ||
branches: [main, master] | ||
|
||
name: lint | ||
|
||
permissions: read-all | ||
|
||
jobs: | ||
lint: | ||
runs-on: ubuntu-latest | ||
env: | ||
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} | ||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- uses: r-lib/actions/setup-r@v2 | ||
with: | ||
use-public-rspm: true | ||
|
||
- uses: r-lib/actions/setup-r-dependencies@v2 | ||
with: | ||
extra-packages: any::lintr, local::. | ||
needs: lint | ||
|
||
- name: Lint | ||
run: lintr::lint_package() | ||
shell: Rscript {0} | ||
env: | ||
LINTR_ERROR_ON_LINT: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
linters: linters_with_defaults( | ||
line_length_linter = line_length_linter(120L), | ||
commented_code_linter = NULL | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,9 +23,24 @@ Authors@R: c( | |
email = "[email protected]", | ||
comment = c(ORCID = "0000-0003-3503-9899"))) | ||
Description: A collection of functions used to create the UMass-trends_ensemble, | ||
an equally weighted ensemble of simple time series baseline models | ||
License: `use_mit_license()`, `use_gpl3_license()` or friends to pick a | ||
license | ||
an equally weighted ensemble of simple time series baseline models. | ||
License: GPL (>= 3) | ||
Encoding: UTF-8 | ||
Roxygen: list(markdown = TRUE) | ||
RoxygenNote: 7.3.1 | ||
Imports: | ||
cli, | ||
dplyr, | ||
purrr, | ||
rlang, | ||
simplets, | ||
tibble | ||
Remotes: | ||
reichlab/simplets | ||
Suggests: | ||
car, | ||
fabletools, | ||
feasts, | ||
testthat (>= 3.2.1), | ||
tidyr | ||
Config/testthat/edition: 3 |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
importFrom(rlang,.data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
#' Get sample and/or quantile type predictions for a single baseline model
#'
#' Validates the inputs, fits a `simplets` quantile baseline model to the
#' observations of a single location, simulates `n_sim` predictive trajectories
#' and converts them into the requested sample and/or quantile outputs.
#'
#' @param target_ts a `data.frame` of target data in a time series format
#'   (contains columns `time_index`, `location`, and `observation`) for a single
#'   location
#' @param transformation string specifying the transformation used on the
#'   distribution which determines its shape; can be one of "none" or "sqrt".
#' @param symmetrize boolean specifying whether to make the distribution symmetric;
#'   can be one of `TRUE` or `FALSE`.
#' @param window_size integer specifying how many previous observations in the
#'   target data should be used to inform the forecasts
#' @param effective_horizons numeric vector of prediction horizons relative to
#'   the last observed date in `target_ts`
#' @param origin string specifying the origin to use when making predictions;
#'   recommended to be "median" if the temporal resolution is daily and "obs"
#'   if weekly or otherwise. Defaults to "obs".
#' @param n_sim integer number of simulations to predict. Defaults to 100000.
#' @param quantile_levels numeric vector of quantile levels (each between 0 and
#'   1) to output; set to NULL if quantile outputs not requested. Defaults to
#'   NULL.
#' @param n_samples integer number of samples to output, which are drawn from the
#'   simulated (sample) predictions (hence n_samples <= n_sim). Set to NULL
#'   if sample outputs not requested. Defaults to NULL.
#' @param round_predictions boolean specifying whether to round the output
#'   predictions to the nearest whole number. Defaults to FALSE.
#' @param seed integer specifying a seed to set for reproducible results.
#'   Defaults to NULL, in which case no seed is set. When supplied, the seed is
#'   set with `set.seed()`, which changes the session's RNG state.
#'
#' @return a 1x1 tibble whose single list-column `forecasts` holds the data
#'   frame of baseline forecasts for one location and one model, with columns
#'   `horizon`, `output_type`, `output_type_id`, and `value`
#'
#' @details An error is raised when `target_ts` does not contain exactly one
#'   location, when any argument fails validation, or when both
#'   `quantile_levels` and `n_samples` are NULL (no output requested).
#'
#' @importFrom rlang .data
get_baseline_predictions <- function(target_ts,
                                     transformation,
                                     symmetrize,
                                     window_size,
                                     effective_horizons,
                                     origin = "obs",
                                     n_sim = 100000,
                                     quantile_levels = NULL,
                                     n_samples = NULL,
                                     round_predictions = FALSE,
                                     seed = NULL) {
  # validate arguments
  validate_target_ts(target_ts)

  num_locs <- length(unique(target_ts[["location"]]))
  if (num_locs != 1) {
    # the inner braces are required for cli to interpolate the value; without
    # them the literal text "num_locs" is printed
    cli::cli_abort("{.arg target_ts} contains {.val {num_locs}} locations but only one may be provided.")
  }

  validate_variation_inputs(transformation, symmetrize, window_size)

  valid_origins <- c("median", "obs")
  # the length guard keeps a vector-valued `origin` from reaching `if` with a
  # condition of length > 1
  if (length(origin) != 1 || !origin %in% valid_origins) {
    cli::cli_abort("{.arg origin} must be only one of {.val {valid_origins}}")
  }

  validate_integer(n_sim, "n_sim")

  if (!is.null(quantile_levels)) {
    # reject non-numeric and NA levels up front so the range comparison below
    # can never evaluate to NA inside `if`
    levels_ok <- is.numeric(quantile_levels) &&
      !anyNA(quantile_levels) &&
      all(quantile_levels >= 0) &&
      all(quantile_levels <= 1)
    if (!levels_ok) {
      cli::cli_abort("{.arg quantile_levels} must only contain values between 0 and 1.")
    }
  }

  if (!is.null(n_samples)) {
    validate_integer(n_samples, "n_samples")
    if (n_samples > n_sim) {
      cli::cli_abort("{.arg n_samples} must be less than or equal to {.arg n_sim}")
    }
  }

  if (is.null(quantile_levels) && is.null(n_samples)) {
    cli::cli_abort("No forecasts requested: both `quantile_levels` and `n_samples` are NULL")
  }

  if (!is.null(seed)) set.seed(seed)

  # an offset of 1 is used for every transformation other than "none"
  transform_offset <- if (transformation == "none") 0 else 1

  # fit
  baseline_fit <- simplets::fit_simple_ts(
    y = target_ts[["observation"]],
    ts_frequency = 1,
    model = "quantile_baseline",
    transformation = transformation,
    transform_offset = transform_offset,
    d = 0,
    D = 0,
    symmetrize = symmetrize,
    window_size = window_size
  )

  # predict: simulate out to the furthest requested horizon; the requested
  # horizons are picked out of the result by extract_predictions()
  predictions <- baseline_fit |>
    stats::predict(
      nsim = n_sim,
      horizon = max(effective_horizons),
      origin = origin,
      force_nonneg = TRUE
    )

  forecasts_df <- extract_predictions(predictions, effective_horizons, quantile_levels, n_samples)

  if (round_predictions) {
    forecasts_df[["value"]] <- round(forecasts_df[["value"]], 0)
  }

  dplyr::tibble(forecasts = list(forecasts_df))
}
|
||
|
||
#' Extract and compute sample and/or quantile forecasts from a prediction matrix
#'
#' @param predictions a `matrix` of `n_sim` runs of sample predictions with
#'   dimensions n_sim x max(effective_horizons)
#' @param effective_horizons numeric vector of prediction horizons relative to
#'   the last observed date in `target_ts`; each value is used as a column index
#'   into `predictions`
#' @param quantile_levels numeric vector of quantile levels to output; set to NULL
#'   if quantile outputs not requested. Defaults to NULL.
#' @param n_samples integer number of samples to output, which are taken as the
#'   first `n_samples` rows of `predictions` (hence n_samples <= n_sim). Set to
#'   NULL if sample outputs not requested. Defaults to NULL.
#'
#' @return data frame of extracted sample and/or quantile forecasts with
#'   columns `horizon`, `output_type`, `output_type_id`, and `value`; sample rows
#'   come first, followed by quantile rows. `NULL` when both `quantile_levels`
#'   and `n_samples` are NULL.
#'
extract_predictions <- function(predictions,
                                effective_horizons,
                                quantile_levels = NULL,
                                n_samples = NULL) {
  samples_df <- NULL
  if (!is.null(n_samples)) {
    # seq_len() instead of 1:n_samples: for n_samples = 0, `1:0` expands to
    # c(1, 0) and would select the first simulation rather than no rows
    sample_rows <- seq_len(n_samples)
    samples_df <- effective_horizons |>
      purrr::map(
        function(h) {
          data.frame(
            horizon = rep(h, n_samples),
            value = predictions[sample_rows, h]
          ) |>
            tibble::rownames_to_column(var = "output_type_id") |>
            dplyr::mutate(
              output_type = "sample",
              # samples are identified by their position, 1..n_samples
              output_type_id = as.numeric(dplyr::row_number()),
              .before = 2
            ) |>
            dplyr::select("horizon", "output_type", "output_type_id", "value")
        }
      ) |>
      purrr::list_rbind()
  }

  quantiles_df <- NULL
  if (!is.null(quantile_levels)) {
    n <- length(quantile_levels)
    quantiles_df <- effective_horizons |>
      purrr::map(
        function(h) {
          # quantiles are computed over all simulations for horizon h
          data.frame(
            horizon = rep(h, n),
            output_type = "quantile",
            output_type_id = quantile_levels,
            value = stats::quantile(predictions[, h], probs = quantile_levels)
          )
        }
      ) |>
      purrr::list_rbind()
  }

  rbind(samples_df, quantiles_df)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
#' Check the structure of a model variations data frame
#'
#' Confirms that `model_variations` was supplied, that its columns are exactly
#' the expected hyperparameter columns, and that no row is repeated.
#'
#' @param model_variations a `data.frame` where each row specifies a set of
#'   hyperparameters to use for a single baseline model fit, with the following
#'   columns: `transformation`, `symmetrize`, and `window_size`. See the details
#'   for more information.
#' @details The types, and possible values for each of the columns in
#'   `model_variations` are as follows:
#'   - transformation (character): "none" or "sqrt"
#'   - symmetrize (boolean)
#'   - window_size (numeric): a non-negative integer
#'
#'   The values inside each column are not inspected here; value-level checks
#'   on a single set of hyperparameters are done by
#'   `validate_variation_inputs()`.
#'
#' @return no return value
#'
#' @noRd
validate_model_variations <- function(model_variations) {
  if (is.null(model_variations)) {
    cli::cli_abort("{.arg model_variations} is missing")
  }

  # column names must match the hyperparameter set exactly
  validate_colnames(
    model_variations,
    c("transformation", "symmetrize", "window_size"),
    "model_variations"
  )

  # each hyperparameter combination may appear only once
  has_repeated_rows <- anyDuplicated(model_variations) > 0
  if (has_repeated_rows) {
    cli::cli_abort("{.arg model_variations} contains duplicate rows.")
  }
}
|
||
|
||
#' Perform simple validations on the individual variables defining a single
#' baseline model
#'
#' Each argument must have length 1. `transformation` must be one of the
#' supported values, `symmetrize` must be logical, and `window_size` must be a
#' non-missing, non-negative whole number.
#'
#' @param transformation string specifying the transformation used on the
#'   distribution which determines its shape; can be one of "none" or "sqrt".
#' @param symmetrize boolean specifying whether to make the distribution symmetric;
#'   can be one of `TRUE` or `FALSE`.
#' @param window_size integer specifying how many previous observations in the
#'   target data should be used to inform the forecasts
#'
#' @return no return value; an error is raised on the first failed check
#'
#' @noRd
validate_variation_inputs <- function(transformation, symmetrize, window_size) {
  # check variation inputs have length 1
  if (length(transformation) != 1) {
    cli::cli_abort("{.arg transformation} must be length 1")
  }

  if (length(symmetrize) != 1) {
    cli::cli_abort("{.arg symmetrize} must be length 1")
  }

  if (length(window_size) != 1) {
    cli::cli_abort("{.arg window_size} must be length 1")
  }

  # check variation inputs contain only valid values
  valid_transformations <- c("none", "sqrt")
  if (!transformation %in% valid_transformations) {
    cli::cli_abort("{.arg transformation} must only contain values {.val {valid_transformations}}")
  }

  if (!inherits(symmetrize, "logical")) {
    cli::cli_abort("{.arg symmetrize} must only contain logical values, e.g. TRUE or FALSE.")
  }

  # the type and NA guards come first so that trunc() is never called on a
  # non-numeric value and the condition can never evaluate to NA
  window_size_ok <- is.numeric(window_size) &&
    !is.na(window_size) &&
    window_size == trunc(window_size) &&
    window_size >= 0
  if (!window_size_ok) {
    cli::cli_abort("{.arg window_size} must only contain non-negative integer values.")
  }

  invisible(NULL)
}
|
||
|
||
#' Check the structure of a target time series data frame
#'
#' Confirms that the columns of `target_ts` are exactly `time_index`,
#' `location`, and `observation`, and that no row is repeated.
#'
#' @param target_ts a `data.frame` of target data in a time series format
#'   (contains columns `time_index`, `location`, and `observation`)
#'
#' @return no return value
#'
#' @noRd
validate_target_ts <- function(target_ts) {
  # column names must match the time series format exactly
  validate_colnames(
    target_ts,
    c("time_index", "location", "observation"),
    "target_ts"
  )

  has_repeated_rows <- anyDuplicated(target_ts) > 0
  if (has_repeated_rows) {
    cli::cli_abort("{.arg target_ts} contains duplicate rows.")
  }
}
|
||
|
||
#' Check that a data frame has exactly the expected column names
#'
#' Column order is not checked. Missing columns are reported before extra
#' columns.
#'
#' @param df a `data.frame` whose columns are to be validated
#' @param expected_col a character vector of expected column names
#' @param arg_name character string name of the argument being validated to be
#'   printed in the error message (generally the name of the `df` object)
#'
#' @return no return value
#'
#' @noRd
validate_colnames <- function(df, expected_col, arg_name) {
  # `actual_col` is referenced by name inside the error messages below
  actual_col <- colnames(df)

  any_missing <- length(setdiff(expected_col, actual_col)) > 0
  if (any_missing) {
    cli::cli_abort("{.arg {arg_name}} is missing the column{?s}: {.val {setdiff(expected_col, actual_col)}}.")
  }

  any_extra <- length(setdiff(actual_col, expected_col)) > 0
  if (any_extra) {
    cli::cli_abort(c(
      x = "{.arg {arg_name}} contains the extra column{?s}: {.val {setdiff(actual_col, expected_col)}}."
    ))
  }
}
|
||
|
||
#' Validate that a value is a single, non-negative whole number
#'
#' @param int a single integer to be validated
#' @param arg_name character string name of the argument being validated to be
#'   printed in the error message (generally the name of the `int` object)
#'
#' @return no return value; an error is raised when `int` is not numeric, is
#'   not of length 1, is `NA`, is negative, or has a fractional part
#'
#' @noRd
validate_integer <- function(int, arg_name) {
  # type and length are checked first so the scalar comparisons that follow
  # never see a non-numeric, zero-length or multi-element value; `&&`
  # short-circuits, so later terms are skipped once one check fails
  is_valid <- is.numeric(int) &&
    length(int) == 1 &&
    !is.na(int) &&
    int >= 0 &&
    int == trunc(int)

  if (!is_valid) {
    cli::cli_abort("{.arg {arg_name}} must be a single, non-negative integer value.")
  }

  invisible(NULL)
}
Oops, something went wrong.