From 4df345f6f30eddb108cc98cff73f47424190896d Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Wed, 13 Nov 2024 17:39:59 -0500 Subject: [PATCH 01/11] Write `create_trends_ensemble()` function --- R/create_trends_ensemble.R | 154 ++++++++++++++++++++++++++++++++++ man/create_trends_ensemble.Rd | 82 ++++++++++++++++++ 2 files changed, 236 insertions(+) create mode 100644 R/create_trends_ensemble.R create mode 100644 man/create_trends_ensemble.Rd diff --git a/R/create_trends_ensemble.R b/R/create_trends_ensemble.R new file mode 100644 index 0000000..a4154dd --- /dev/null +++ b/R/create_trends_ensemble.R @@ -0,0 +1,154 @@ +#' Generate predictions for the trends ensemble, a quantile median of component +#' baseline models +#' +#' @param component_variations a `data.frame` where each row specifies a set of +#' hyperparameters to use for a single baseline model fit, with columns +#' `transformation`, `symmetrize`, `window_size`, and `temporal_horizon`. +#' See details for more information +#' @param target_ts a `data.frame` of target data in a time series format +#' (contains columns `time_index`, `location`, and `observation`) for a single +#' location +#' @param reference_date string of the reference date for the forecasts, i.e. +#' the date relative to which the targets are defined (usually Saturday for +#' weekly targets). Must be in the ymd format, with yyyy-mm-dd format recommended. +#' @param horizons numeric vector of prediction horizons relative to +#' the reference_date, e.g. 0:3 or 1:4 +#' @param target character string specifying the name of the prediction target +#' @param quantile_levels numeric vector of quantile levels to output; set to NULL +#' if quantile outputs not requested. Defaults to NULL. +#' @param n_samples integer of amount of samples to output (and predict); +#' set to NULL if sample outputs not requested (in this case 100000 samples +#' are generated from which to extract quantiles). Defaults to NULL. +#' @param round_predictions boolean specifying whether to round the output +#' predictions to the nearest whole number. Defaults to FALSE +#' @param seed integer specifying a seed to set for reproducible results. +#' Defaults to NULL, in which case no seed is set. +#' @param return_baseline_predictions boolean specifying whether to the component +#' baseline models' forecasts in addition to the trends ensemble forecasts. +#' If TRUE, a two-item list will be returned containing a labeled model_out_tbl +#' of each. Defaults to FALSE. +#' +#' @details The `component_variations` data frame has the following columns and +#' possible values for each: +#' - transformation (character): "none" or "sqrt", determines distribution shape +#' - symmetrize (boolean), determines if distribution is symmetric +#' - window_size (integer), determines how many previous observations inform +#' the forecast +#' - temporal_resolution (character): "daily" or "weekly" +#' +#' Note that it must be possible to aggregate the `target_ts` data to the +#' temporal resolution values given in `component_variations`. For example, if +#' `target_ts` contains weekly observations but `component_variations` requests +#' models with a "daily" temporal resolution, an error will be thrown +#' +#' @return `model_out_tbl` of trends ensemble forecasts with columns: +#' `model_id`, `reference_date`, `location`, `horizon`, `target`, +#' `target_end_date`, `output_type`, `output_type_id`, and `value`. +#' +#' @importFrom rlang .data +#' +#' @export +create_trends_ensemble <- function(component_variations, + target_ts, + reference_date, + horizons, + target, + quantile_levels, + n_samples = NULL, + round_predictions = FALSE, + seed = NULL, + return_baseline_predictions = FALSE) { + valid_temp_res <- c("daily", "weekly") + temp_res_variations <- component_variations |> + dplyr::distinct(name = .data[["temporal_resolution"]], .keep_all = FALSE) |> + dplyr::mutate(num_days = dplyr::case_when( + .data[["name"]] == "daily" ~ 1, + .data[["name"]] == "weekly" ~ 7, + .default = NA + )) + if (!all(temp_res_variations$name %in% valid_temp_res)) { + cli::cli_abort("{.arg component_variations} must only include temporal resolution values {.val valid_temp_res}") + } + if (nrow(temp_res_variations) > 1) { + cli::cli_abort("Currently {.arg component_variations} may only contain one unique temporal resolution value") + } + + validate_target_ts(target_ts) + ts_dates_desc <- sort(unique(target_ts$time_index), decreasing = TRUE) + ts_temp_res <- as.integer(ts_dates_desc[1] - ts_dates_desc[2]) + if (any(temp_res_variations$num_days %% ts_temp_res != 0)) { + cli::cli_abort(c( + x = "Cannot match temporal resolution of provided {.arg target_ts} + to those requested in {.arg component_variations}.", + i = "{.arg target_ts} must aggregate to all requested temporal resolutions." + )) + } + + # calculate baseline models' forecasts + split_variations <- component_variations |> + split(f = component_variations$temporal_resolution) + component_outputs <- split_variations |> + purrr::map(.f = function(model_variations) { + current_temp_res <- temp_res_variations[temp_res_variations$name == model_variations$temporal_resolution[1], ] + if (current_temp_res$num_days > ts_temp_res) { + new_horizon_min <- floor(min(horizons) / current_temp_res$num_days) + new_horizon_max <- ceiling(max(horizons) / current_temp_res$num_days) + model_variations |> + dplyr::select(-"temporal_resolution") |> + fit_baseline_models(aggregate_daily_to_weekly(target_ts), + reference_date, + current_temp_res$name, + new_horizon_min:new_horizon_max, + target, + quantile_levels, + n_samples, + round_predictions, + seed) + } else { + model_variations |> + dplyr::select(-"temporal_resolution") |> + fit_baseline_models(target_ts, + reference_date, + current_temp_res$name[1], + horizons, + target, + quantile_levels, + n_samples, + round_predictions, + seed) + } + }) |> + purrr::list_rbind() + + # build ensemble + split_components <- split(component_outputs, + f = component_outputs$output_type + ) + ensemble_outputs <- split_components |> + purrr::map(.f = function(split_outputs) { + type <- split_outputs$output_type[1] + if (type == "quantile") { + hubEnsembles::simple_ensemble( + split_outputs, + agg_fun = "median", + model_id = "UMass-trends_ensemble" + ) + } else if (type == "sample") { + hubEnsembles::linear_pool( + split_outputs, + model_id = "UMass-trends_ensemble" + ) + } + }) |> + purrr::list_rbind() |> + dplyr::mutate( + reference_date = as.Date(reference_date), + target_end_date = as.Date(.data[["target_end_date"]]) + ) + + if (return_baseline_predictions == TRUE) { + list(ensemble = ensemble_outputs, baselines = component_outputs) + } else { + ensemble_outputs + } +} diff --git a/man/create_trends_ensemble.Rd b/man/create_trends_ensemble.Rd new file mode 100644 index 0000000..4c7a207 --- /dev/null +++ b/man/create_trends_ensemble.Rd @@ -0,0 +1,82 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/create_trends_ensemble.R +\name{create_trends_ensemble} +\alias{create_trends_ensemble} +\title{Generate predictions for the trends ensemble, a quantile median of component +baseline models} +\usage{ +create_trends_ensemble( + component_variations, + target_ts, + reference_date, + horizons, + target, + quantile_levels, + n_samples = NULL, + round_predictions = FALSE, + seed = NULL, + return_baseline_predictions = FALSE +) +} +\arguments{ +\item{component_variations}{a \code{data.frame} where each row specifies a set of +hyperparameters to use for a single baseline model fit, with columns +\code{transformation}, \code{symmetrize}, \code{window_size}, and \code{temporal_horizon}. +See details for more information} + +\item{target_ts}{a \code{data.frame} of target data in a time series format +(contains columns \code{time_index}, \code{location}, and \code{observation}) for a single +location} + +\item{reference_date}{string of the reference date for the forecasts, i.e. +the date relative to which the targets are defined (usually Saturday for +weekly targets). Must be in the ymd format, with yyyy-mm-dd format recommended.} + +\item{horizons}{numeric vector of prediction horizons relative to +the reference_date, e.g. 0:3 or 1:4} + +\item{target}{character string specifying the name of the prediction target} + +\item{quantile_levels}{numeric vector of quantile levels to output; set to NULL +if quantile outputs not requested. Defaults to NULL.} + +\item{n_samples}{integer of amount of samples to output (and predict); +set to NULL if sample outputs not requested (in this case 100000 samples +are generated from which to extract quantiles). Defaults to NULL.} + +\item{round_predictions}{boolean specifying whether to round the output +predictions to the nearest whole number. Defaults to FALSE} + +\item{seed}{integer specifying a seed to set for reproducible results. +Defaults to NULL, in which case no seed is set.} + +\item{return_baseline_predictions}{boolean specifying whether to the component +baseline models' forecasts in addition to the trends ensemble forecasts. +If TRUE, a two-item list will be returned containing a labeled model_out_tbl +of each. Defaults to FALSE.} +} +\value{ +\code{model_out_tbl} of trends ensemble forecasts with columns: +\code{model_id}, \code{reference_date}, \code{location}, \code{horizon}, \code{target}, +\code{target_end_date}, \code{output_type}, \code{output_type_id}, and \code{value}. +} +\description{ +Generate predictions for the trends ensemble, a quantile median of component +baseline models +} +\details{ +The \code{component_variations} data frame has the following columns and +possible values for each: +\itemize{ +\item transformation (character): "none" or "sqrt", determines distribution shape +\item symmetrize (boolean), determines if distribution is symmetric +\item window_size (integer), determines how many previous observations inform +the forecast +\item temporal_resolution (character): "daily" or "weekly" +} + +Note that it must be possible to aggregate the \code{target_ts} data to the +temporal resolution values given in \code{component_variations}. For example, if +\code{target_ts} contains weekly observations but \code{component_variations} requests +models with a "daily" temporal resolution, an error will be thrown +} From 99382e74a2fe6bbd90f920cd1bcf3e6fd4325a47 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:13:38 -0500 Subject: [PATCH 02/11] Write helper `aggregate_daily_to_weekly()` for target data --- R/aggregate_daily_to_weekly.R | 29 +++++++++++++++++++++++++++++ man/aggregate_daily_to_weekly.Rd | 21 +++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 R/aggregate_daily_to_weekly.R create mode 100644 man/aggregate_daily_to_weekly.Rd diff --git a/R/aggregate_daily_to_weekly.R b/R/aggregate_daily_to_weekly.R new file mode 100644 index 0000000..e088960 --- /dev/null +++ b/R/aggregate_daily_to_weekly.R @@ -0,0 +1,29 @@ +#' Aggregate daily data to weekly data +#' +#' Counts weeks as beginning on Sunday and ending on Saturday. Drops observations +#' from the most recent week if not a full 7 days of data +#' +#' @param target_ts a `data.frame` of target data in a time series format +#' (contains columns `time_index`, `location`, and `observation`) for a single +#' location +#' +#' @return data.frame of time series data with the same set of input columns, with +#' weekly-aggregated data in `observation` column +aggregate_daily_to_weekly <- function(target_ts) { + validate_target_ts(target_ts) + most_recent_date <- max(target_ts$time_index) + + target_ts |> + dplyr::mutate( + sat_date = lubridate::ceiling_date( + lubridate::ymd(.data[["time_index"]]), + unit = "week" + ) - 1, + .before = "observation" + ) |> + dplyr::group_by(dplyr::across(dplyr::all_of(c("location", "sat_date")))) |> + dplyr::filter(.data[["sat_date"]] <= most_recent_date) |> + dplyr::summarize(observation = sum(.data[["observation"]], na.rm = FALSE)) |> + dplyr::rename(time_index = "sat_date") |> + dplyr::ungroup() +} diff --git a/man/aggregate_daily_to_weekly.Rd b/man/aggregate_daily_to_weekly.Rd new file mode 100644 index 0000000..2159364 --- /dev/null +++ b/man/aggregate_daily_to_weekly.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aggregate_daily_to_weekly.R +\name{aggregate_daily_to_weekly} +\alias{aggregate_daily_to_weekly} +\title{Aggregate daily data to weekly data} +\usage{ +aggregate_daily_to_weekly(target_ts) +} +\arguments{ +\item{target_ts}{a \code{data.frame} of target data in a time series format +(contains columns \code{time_index}, \code{location}, and \code{observation}) for a single +location} +} +\value{ +data.frame of time series data with the same set of input columns, with +weekly-aggregated data in \code{observation} column +} +\description{ +Counts weeks as beginning on Sunday and ending on Saturday. Drops observations +from the most recent week if not a full 7 days of data +} From a8e58ce2d611a517ea534d9b49d975b473b275a7 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:13:48 -0500 Subject: [PATCH 03/11] Update DESCRIPTION --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index 0d147e6..ac173be 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,6 +31,7 @@ RoxygenNote: 7.3.2 Imports: cli, dplyr, + hubEnsembles, hubUtils, purrr, rlang, From b8c9f3dffb9d8b3efd6204b1d397ac41cc8eff59 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:14:00 -0500 Subject: [PATCH 04/11] Update NAMESPACE --- NAMESPACE | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 857392c..ba47002 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,3 +1,5 @@ # Generated by roxygen2: do not edit by hand +export(create_trends_ensemble) +export(fit_baseline_models) importFrom(rlang,.data) From 8f6483baaa46f738da4b749ee1b0f24a14c081ab Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:14:57 -0500 Subject: [PATCH 05/11] Write `create_trends_ensemble()` tests --- tests/testthat/test-create_trends_ensemble.R | 189 +++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tests/testthat/test-create_trends_ensemble.R diff --git a/tests/testthat/test-create_trends_ensemble.R b/tests/testthat/test-create_trends_ensemble.R new file mode 100644 index 0000000..4f441b8 --- /dev/null +++ b/tests/testthat/test-create_trends_ensemble.R @@ -0,0 +1,189 @@ +#set up variations of baseline to fit +daily_variations <- tidyr::expand_grid( + transformation = "none", + symmetrize = TRUE, + window_size = c(14 - 1, 7 - 1), + temporal_resolution = "daily" +) +weekly_variations <- tidyr::expand_grid( + transformation = "none", + symmetrize = TRUE, + window_size = c(2, 1), + temporal_resolution = "weekly" +) + +daily_ts <- expand.grid( + stringsAsFactors = FALSE, + location = c("ak", "al"), + time_index = as.Date("2022-11-05") + 1:28, + observation = NA +) +daily_ts$observation[daily_ts$location == "ak"] <- + c(8, 9, 6, 7, 3, 6, 6, 5, 4, 11, 10, 3, 4, 3, + 3, 8, 6, 7, 10, 6, 4, 4, 6, 5, 2, 5, 6, 5) +daily_ts$observation[daily_ts$location == "al"] <- + c(27, 19, 20, 16, 19, 22, 20, 21, 18, 25, 17, 25, 27, 22, + 15, 32, 21, 26, 18, 14, 14, 14, 24, 23, 16, 20, 14, 44) + +test_that("unsupported temporal_resolution values in component_variations throws an error", { + daily_variations |> + dplyr::mutate(temporal_resolution = "monthly") |> + dplyr::bind_rows(daily_variations) |> + create_trends_ensemble(daily_ts, + reference_date = "2022-12-10", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + return_baseline_predictions = FALSE) |> + expect_error(regex = "`component_variations` must only include temporal resolution values", + fixed = TRUE) +}) + +test_that("multiple temporal_resolution values in component_variations throws an error", { + daily_variations |> + dplyr::bind_rows(weekly_variations) |> + create_trends_ensemble(daily_ts, + reference_date = "2022-12-10", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + return_baseline_predictions = FALSE) |> + expect_error(regex = "Currently `component_variations` may only contain one unique temporal resolution value", + fixed = TRUE) +}) + +test_that("providing target_ts that cannot be aggregated to match all requested temporal resolutions throws an error", { + daily_variations |> + create_trends_ensemble(aggregate_daily_to_weekly(daily_ts), + reference_date = "2022-12-10", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + return_baseline_predictions = FALSE) |> + expect_error(regex = "Cannot match temporal resolution of provided `target_ts`", + fixed = TRUE) +}) + +test_that("output predictions is a model_out_tbl", { + daily_variations |> + create_trends_ensemble(daily_ts, + reference_date = "2022-12-10", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + return_baseline_predictions = FALSE) |> + expect_s3_class("model_out_tbl") +}) + +test_that("component outputs are correctly calculated", { + daily_expected <- fit_baseline_models( + daily_variations[, 1:3], + daily_ts, + reference_date = "2022-12-10", + temporal_resolution = "daily", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = 100, + seed = 1234 + ) + weekly_expected <- fit_baseline_models( + weekly_variations[, 1:3], + aggregate_daily_to_weekly(daily_ts), + reference_date = "2022-12-10", + temporal_resolution = "weekly", + horizons = 0:3, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = 100, + seed = 1234 + ) + + daily_actual <- daily_variations |> + create_trends_ensemble(daily_ts, + reference_date = "2022-12-10", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = 100, + seed = 1234, + return_baseline_predictions = TRUE) |> + purrr::pluck("baselines") + weekly_actual <- weekly_variations |> + create_trends_ensemble(aggregate_daily_to_weekly(daily_ts), + reference_date = "2022-12-10", + horizons = 0:3, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = 100, + seed = 1234, + return_baseline_predictions = TRUE) |> + purrr::pluck("baselines") + + expect_equal(daily_actual, daily_expected, tolerance = 1e-3) + expect_equal(weekly_actual, weekly_expected, tolerance = 1e-3) +}) + +test_that("ensemble is correctly calculated", { + daily_outputs <- daily_variations |> + create_trends_ensemble(daily_ts, + reference_date = "2022-12-10", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = 1000, + return_baseline_predictions = TRUE) + weekly_outputs <- weekly_variations |> + create_trends_ensemble(daily_ts, + reference_date = "2022-12-10", + horizons = 0:3, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = 1000, + return_baseline_predictions = TRUE) + + daily_quantile <- daily_outputs[["baselines"]] |> + dplyr::filter(output_type == "quantile") |> + hubEnsembles::simple_ensemble( + agg_fun = median, + model_id = "UMass-trends_ensemble" + ) |> + dplyr::mutate( + reference_date = as.Date(reference_date), + target_end_date = as.Date(target_end_date) + ) + daily_sample <- daily_outputs[["baselines"]] |> + dplyr::filter(output_type == "sample") |> + dplyr::mutate( + output_type_id = as.integer(factor(paste0(model_id, output_type_id))), + model_id = "UMass-trends_ensemble", + reference_date = as.Date(reference_date), + target_end_date = as.Date(target_end_date) + ) + dplyr::bind_rows(daily_quantile, daily_sample) |> + expect_equal(daily_outputs[["ensemble"]]) + + weekly_quantile <- weekly_outputs[["baselines"]] |> + dplyr::filter(output_type == "quantile") |> + hubEnsembles::simple_ensemble( + agg_fun = median, + model_id = "UMass-trends_ensemble" + ) |> + dplyr::mutate( + reference_date = as.Date(reference_date), + ) + weekly_sample <- weekly_outputs[["baselines"]] |> + dplyr::filter(output_type == "sample") |> + dplyr::mutate( + output_type_id = as.integer(factor(paste0(model_id, output_type_id))), + model_id = "UMass-trends_ensemble", + reference_date = as.Date(reference_date), + target_end_date = as.Date(target_end_date) + ) + dplyr::bind_rows(weekly_quantile, weekly_sample) |> + expect_equal(weekly_outputs[["ensemble"]]) +}) From a7b94dfe06a1755d6f33463a5441482dd8489bdb Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 11:41:21 -0500 Subject: [PATCH 06/11] Add `component_variations` validation --- R/create_trends_ensemble.R | 6 ++++- man/create_trends_ensemble.Rd | 3 ++- tests/testthat/test-create_trends_ensemble.R | 26 ++++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/R/create_trends_ensemble.R b/R/create_trends_ensemble.R index a4154dd..b91944f 100644 --- a/R/create_trends_ensemble.R +++ b/R/create_trends_ensemble.R @@ -12,7 +12,8 @@ #' the date relative to which the targets are defined (usually Saturday for #' weekly targets). Must be in the ymd format, with yyyy-mm-dd format recommended. #' @param horizons numeric vector of prediction horizons relative to -#' the reference_date, e.g. 0:3 or 1:4 +#' the reference_date, e.g. 0:3 or 1:4, and interpreted to be in terms of the +#' same temporal resolution as the provided `target_ts`. #' @param target character string specifying the name of the prediction target #' @param quantile_levels numeric vector of quantile levels to output; set to NULL #' if quantile outputs not requested. Defaults to NULL. @@ -58,6 +59,9 @@ create_trends_ensemble <- function(component_variations, round_predictions = FALSE, seed = NULL, return_baseline_predictions = FALSE) { + cv_col <- c("transformation", "symmetrize", "window_size", "temporal_resolution") + validate_colnames(component_variations, cv_col, "component_variations") + valid_temp_res <- c("daily", "weekly") temp_res_variations <- component_variations |> dplyr::distinct(name = .data[["temporal_resolution"]], .keep_all = FALSE) |> diff --git a/man/create_trends_ensemble.Rd b/man/create_trends_ensemble.Rd index 4c7a207..2ac1db1 100644 --- a/man/create_trends_ensemble.Rd +++ b/man/create_trends_ensemble.Rd @@ -33,7 +33,8 @@ the date relative to which the targets are defined (usually Saturday for weekly targets). Must be in the ymd format, with yyyy-mm-dd format recommended.} \item{horizons}{numeric vector of prediction horizons relative to -the reference_date, e.g. 0:3 or 1:4} +the reference_date, e.g. 0:3 or 1:4, and interpreted to be in terms of the +same temporal resolution as the provided \code{target_ts}.} \item{target}{character string specifying the name of the prediction target} diff --git a/tests/testthat/test-create_trends_ensemble.R b/tests/testthat/test-create_trends_ensemble.R index 4f441b8..1f283bb 100644 --- a/tests/testthat/test-create_trends_ensemble.R +++ b/tests/testthat/test-create_trends_ensemble.R @@ -25,6 +25,32 @@ daily_ts$observation[daily_ts$location == "al"] <- c(27, 19, 20, 16, 19, 22, 20, 21, 18, 25, 17, 25, 27, 22, 15, 32, 21, 26, 18, 14, 14, 14, 24, 23, 16, 20, 14, 44) + +test_that("missing or extraneous columns in component_variations throws an error", { + daily_variations[, 1] |> + create_trends_ensemble(daily_ts, + reference_date = "2022-12-10", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + return_baseline_predictions = FALSE) |> + expect_error(regex = "`component_variations` is missing the column", + fixed = TRUE) + + daily_variations |> + dplyr::mutate(horizons = 28) |> + create_trends_ensemble(daily_ts, + reference_date = "2022-12-10", + horizons = -6:21, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + return_baseline_predictions = FALSE) |> + expect_error(regex = "`component_variations` contains the extra column", + fixed = TRUE) +}) + test_that("unsupported temporal_resolution values in component_variations throws an error", { daily_variations |> dplyr::mutate(temporal_resolution = "monthly") |> From f1ad4e7bb66a0b3a25c21b6a3dea890cf2ed0797 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 12:08:32 -0500 Subject: [PATCH 07/11] Fix quantiles validation --- R/get_baseline_predictions.R | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/R/get_baseline_predictions.R b/R/get_baseline_predictions.R index 382fe95..5ec12c0 100644 --- a/R/get_baseline_predictions.R +++ b/R/get_baseline_predictions.R @@ -59,7 +59,8 @@ get_baseline_predictions <- function(target_ts, validate_integer(n_sim, "n_sim") - if (any(quantile_levels > 1) || any(quantile_levels < 0)) { + if ((!is.numeric(quantile_levels) && !is.null(quantile_levels)) || + (any(quantile_levels > 1) || any(quantile_levels < 0))) { cli::cli_abort("{.arg quantile_levels} must only contain values between 0 and 1.") } @@ -68,10 +69,10 @@ get_baseline_predictions <- function(target_ts, if (n_samples > n_sim) { cli::cli_abort("{.arg n_samples} must be less than or equal to {.arg n_sim}") } - } - - if (is.null(quantile_levels) && is.null(n_samples)) { - cli::cli_abort("No forecasts requested: both `quantile_levels` and `n_samples` are NULL") + } else { + if (is.null(quantile_levels)) { + cli::cli_abort("No forecasts requested: both `quantile_levels` and `n_samples` are NULL") + } } if (!is.null(seed)) set.seed(seed) From 2bc003b420c30ed2f868e554f9c7419ccdbbccc6 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 12:24:38 -0500 Subject: [PATCH 08/11] Update DESCRIPTION --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index ac173be..9bc33a0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -39,6 +39,7 @@ Imports: tibble, tidyr Remotes: + hubverse-org/hubEnsembles, reichlab/simplets Suggests: car, From 4d752b191598d231c8c110d1058a1e16c09942de Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:29:12 -0500 Subject: [PATCH 09/11] Fix documentation typo --- R/create_trends_ensemble.R | 2 +- man/create_trends_ensemble.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/create_trends_ensemble.R b/R/create_trends_ensemble.R index b91944f..7fb8098 100644 --- a/R/create_trends_ensemble.R +++ b/R/create_trends_ensemble.R @@ -3,7 +3,7 @@ #' #' @param component_variations a `data.frame` where each row specifies a set of #' hyperparameters to use for a single baseline model fit, with columns -#' `transformation`, `symmetrize`, `window_size`, and `temporal_horizon`. +#' `transformation`, `symmetrize`, `window_size`, and `temporal_resolution`. #' See details for more information #' @param target_ts a `data.frame` of target data in a time series format #' (contains columns `time_index`, `location`, and `observation`) for a single diff --git a/man/create_trends_ensemble.Rd b/man/create_trends_ensemble.Rd index 2ac1db1..4ecf722 100644 --- a/man/create_trends_ensemble.Rd +++ b/man/create_trends_ensemble.Rd @@ -21,7 +21,7 @@ create_trends_ensemble( \arguments{ \item{component_variations}{a \code{data.frame} where each row specifies a set of hyperparameters to use for a single baseline model fit, with columns -\code{transformation}, \code{symmetrize}, \code{window_size}, and \code{temporal_horizon}. +\code{transformation}, \code{symmetrize}, \code{window_size}, and \code{temporal_resolution}. See details for more information} \item{target_ts}{a \code{data.frame} of target data in a time series format From d6d90944361f151645e16d4c61112855552b6ff1 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:32:34 -0500 Subject: [PATCH 10/11] Update docs --- R/fit_baselines_one_location.R | 2 +- man/fit_baseline_models.Rd | 2 +- man/fit_baselines_one_location.Rd | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/fit_baselines_one_location.R b/R/fit_baselines_one_location.R index 846c8d1..6da5795 100644 --- a/R/fit_baselines_one_location.R +++ b/R/fit_baselines_one_location.R @@ -16,7 +16,7 @@ #' the reference_date, e.g. 0:3 or 1:4 #' @param quantile_levels numeric vector of quantile levels to output; set to NULL #' if quantile outputs not requested. Defaults to NULL. -#' @param n_samples integer of amount of samples to output (and predict); +#' @param n_samples integer of amount of samples to output; #' set to NULL if sample outputs not requested (in this case 100000 samples #' are generated from which to extract quantiles). Defaults to NULL. #' @param round_predictions boolean specifying whether to round the output diff --git a/man/fit_baseline_models.Rd b/man/fit_baseline_models.Rd index 175e15d..a589986 100644 --- a/man/fit_baseline_models.Rd +++ b/man/fit_baseline_models.Rd @@ -42,7 +42,7 @@ the reference_date, e.g. 0:3 or 1:4} \item{quantile_levels}{numeric vector of quantile levels to output; set to NULL if quantile outputs not requested. Defaults to NULL.} -\item{n_samples}{integer of amount of samples to output (and predict); +\item{n_samples}{integer of amount of samples to output; set to NULL if sample outputs not requested (in this case 100000 samples are generated from which to extract quantiles). Defaults to NULL.} diff --git a/man/fit_baselines_one_location.Rd b/man/fit_baselines_one_location.Rd index 3067e1b..72acdf4 100644 --- a/man/fit_baselines_one_location.Rd +++ b/man/fit_baselines_one_location.Rd @@ -39,7 +39,7 @@ the reference_date, e.g. 0:3 or 1:4} \item{quantile_levels}{numeric vector of quantile levels to output; set to NULL if quantile outputs not requested. Defaults to NULL.} -\item{n_samples}{integer of amount of samples to output (and predict); +\item{n_samples}{integer of amount of samples to output; set to NULL if sample outputs not requested (in this case 100000 samples are generated from which to extract quantiles). Defaults to NULL.} From 47f1c95333da21ec857b7d230061dfeb69c0b280 Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:52:47 -0500 Subject: [PATCH 11/11] Make `n_sim` a parameter for all functions --- R/create_trends_ensemble.R | 4 ++++ R/fit_baseline_models.R | 2 ++ R/fit_baselines_one_location.R | 4 +++- man/create_trends_ensemble.Rd | 3 +++ man/fit_baseline_models.Rd | 3 +++ man/fit_baselines_one_location.Rd | 3 +++ tests/testthat/test-create_trends_ensemble.R | 11 +++++++++++ tests/testthat/test-fit_baseline_models.R | 6 ++++++ tests/testthat/test-fit_baselines_one_location.R | 10 ++++++++++ 9 files changed, 45 insertions(+), 1 deletion(-) diff --git a/R/create_trends_ensemble.R b/R/create_trends_ensemble.R index 7fb8098..37e3fb0 100644 --- a/R/create_trends_ensemble.R +++ b/R/create_trends_ensemble.R @@ -15,6 +15,7 @@ #' the reference_date, e.g. 0:3 or 1:4, and interpreted to be in terms of the #' same temporal resolution as the provided `target_ts`. #' @param target character string specifying the name of the prediction target +#' @param n_sim integer number of simulations to predict. Defaults to 100000. #' @param quantile_levels numeric vector of quantile levels to output; set to NULL #' if quantile outputs not requested. Defaults to NULL. #' @param n_samples integer of amount of samples to output (and predict); @@ -54,6 +55,7 @@ create_trends_ensemble <- function(component_variations, reference_date, horizons, target, + n_sim = 10000, quantile_levels, n_samples = NULL, round_predictions = FALSE, @@ -104,6 +106,7 @@ create_trends_ensemble <- function(component_variations, current_temp_res$name, new_horizon_min:new_horizon_max, target, + n_sim = n_sim, quantile_levels, n_samples, round_predictions, @@ -116,6 +119,7 @@ create_trends_ensemble <- function(component_variations, current_temp_res$name[1], horizons, target, + n_sim = n_sim, quantile_levels, n_samples, round_predictions, diff --git a/R/fit_baseline_models.R b/R/fit_baseline_models.R index 333b961..f5eb5a4 100644 --- a/R/fit_baseline_models.R +++ b/R/fit_baseline_models.R @@ -23,6 +23,7 @@ fit_baseline_models <- function(model_variations, temporal_resolution, horizons, target, + n_sim = 10000, quantile_levels, n_samples, round_predictions = FALSE, @@ -41,6 +42,7 @@ fit_baseline_models <- function(model_variations, reference_date = reference_date, temporal_resolution = temporal_resolution, horizons = horizons, + n_sim = n_sim, quantile_levels = quantile_levels, n_samples = n_samples, round_predictions = round_predictions, diff --git a/R/fit_baselines_one_location.R b/R/fit_baselines_one_location.R index 6da5795..9913c25 100644 --- a/R/fit_baselines_one_location.R +++ b/R/fit_baselines_one_location.R @@ -14,6 +14,7 @@ #' `target_ts` and `horizons` #' @param horizons numeric vector of prediction horizons relative to #' the reference_date, e.g. 0:3 or 1:4 +#' @param n_sim integer number of simulations to predict. Defaults to 100000. #' @param quantile_levels numeric vector of quantile levels to output; set to NULL #' if quantile outputs not requested. Defaults to NULL. #' @param n_samples integer of amount of samples to output; @@ -57,6 +58,7 @@ fit_baselines_one_location <- function(model_variations, reference_date, temporal_resolution, horizons, + n_sim = 10000, quantile_levels, n_samples, round_predictions = FALSE, @@ -96,7 +98,7 @@ fit_baselines_one_location <- function(model_variations, target_ts = target_ts, effective_horizons = horizons_to_forecast, origin = ifelse(temporal_resolution == "weekly", "obs", "median"), - n_sim = 100000, + n_sim = n_sim, quantile_levels = quantile_levels, n_samples = n_samples, round_predictions = round_predictions, diff --git a/man/create_trends_ensemble.Rd b/man/create_trends_ensemble.Rd index 4ecf722..9ccd703 100644 --- a/man/create_trends_ensemble.Rd +++ b/man/create_trends_ensemble.Rd @@ -11,6 +11,7 @@ create_trends_ensemble( reference_date, horizons, target, + n_sim = 10000, quantile_levels, n_samples = NULL, round_predictions = FALSE, @@ -38,6 +39,8 @@ same temporal resolution as the provided \code{target_ts}.} \item{target}{character string specifying the name of the prediction target} +\item{n_sim}{integer number of simulations to predict. Defaults to 100000.} + \item{quantile_levels}{numeric vector of quantile levels to output; set to NULL if quantile outputs not requested. Defaults to NULL.} diff --git a/man/fit_baseline_models.Rd b/man/fit_baseline_models.Rd index a589986..d14d6b1 100644 --- a/man/fit_baseline_models.Rd +++ b/man/fit_baseline_models.Rd @@ -11,6 +11,7 @@ fit_baseline_models( temporal_resolution, horizons, target, + n_sim = 10000, quantile_levels, n_samples, round_predictions = FALSE, @@ -39,6 +40,8 @@ the reference_date, e.g. 0:3 or 1:4} \item{target}{character string specifying the name of the prediction target} +\item{n_sim}{integer number of simulations to predict. Defaults to 100000.} + \item{quantile_levels}{numeric vector of quantile levels to output; set to NULL if quantile outputs not requested. Defaults to NULL.} diff --git a/man/fit_baselines_one_location.Rd b/man/fit_baselines_one_location.Rd index 72acdf4..7c9e576 100644 --- a/man/fit_baselines_one_location.Rd +++ b/man/fit_baselines_one_location.Rd @@ -10,6 +10,7 @@ fit_baselines_one_location( reference_date, temporal_resolution, horizons, + n_sim = 10000, quantile_levels, n_samples, round_predictions = FALSE, @@ -36,6 +37,8 @@ weekly targets). Must be in the ymd format, with yyyy-mm-dd format recommended.} \item{horizons}{numeric vector of prediction horizons relative to the reference_date, e.g. 0:3 or 1:4} +\item{n_sim}{integer number of simulations to predict. Defaults to 100000.} + \item{quantile_levels}{numeric vector of quantile levels to output; set to NULL if quantile outputs not requested. Defaults to NULL.} diff --git a/tests/testthat/test-create_trends_ensemble.R b/tests/testthat/test-create_trends_ensemble.R index 1f283bb..e5d03b4 100644 --- a/tests/testthat/test-create_trends_ensemble.R +++ b/tests/testthat/test-create_trends_ensemble.R @@ -32,6 +32,7 @@ test_that("missing or extraneous columns in component_variations throws an error reference_date = "2022-12-10", horizons = -6:21, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, return_baseline_predictions = FALSE) |> @@ -44,6 +45,7 @@ test_that("missing or extraneous columns in component_variations throws an error reference_date = "2022-12-10", horizons = -6:21, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, return_baseline_predictions = FALSE) |> @@ -59,6 +61,7 @@ test_that("unsupported temporal_resolution values in component_variations throws reference_date = "2022-12-10", horizons = -6:21, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, return_baseline_predictions = FALSE) |> @@ -73,6 +76,7 @@ test_that("multiple temporal_resolution values in component_variations throws an reference_date = "2022-12-10", horizons = -6:21, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, return_baseline_predictions = FALSE) |> @@ -86,6 +90,7 @@ test_that("providing target_ts that cannot be aggregated to match all requested reference_date = "2022-12-10", horizons = -6:21, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, return_baseline_predictions = FALSE) |> @@ -113,6 +118,7 @@ test_that("component outputs are correctly calculated", { temporal_resolution = "daily", horizons = -6:21, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = 100, seed = 1234 @@ -124,6 +130,7 @@ test_that("component outputs are correctly calculated", { temporal_resolution = "weekly", horizons = 0:3, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = 100, seed = 1234 @@ -134,6 +141,7 @@ test_that("component outputs are correctly calculated", { reference_date = "2022-12-10", horizons = -6:21, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = 100, seed = 1234, @@ -144,6 +152,7 @@ test_that("component outputs are correctly calculated", { reference_date = "2022-12-10", horizons = 0:3, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = 100, seed = 1234, @@ -160,6 +169,7 @@ test_that("ensemble is correctly calculated", { reference_date = "2022-12-10", horizons = -6:21, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = 1000, return_baseline_predictions = TRUE) @@ -168,6 +178,7 @@ test_that("ensemble is correctly calculated", { reference_date = "2022-12-10", horizons = 0:3, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = 1000, return_baseline_predictions = TRUE) diff --git a/tests/testthat/test-fit_baseline_models.R b/tests/testthat/test-fit_baseline_models.R index a35ab6f..f0f6445 100644 --- a/tests/testthat/test-fit_baseline_models.R +++ b/tests/testthat/test-fit_baseline_models.R @@ -17,6 +17,7 @@ test_that("missing target throws an error", { temporal_resolution = "weekly", horizons = 0:3, target = NULL, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) |> expect_error( @@ -31,6 +32,7 @@ test_that("output predictions is a model_out_tbl", { temporal_resolution = "weekly", horizons = 0:3, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) expect_s3_class(baseline_outputs, "model_out_tbl") @@ -44,6 +46,7 @@ test_that("model IDs are as expected", { temporal_resolution = "weekly", horizons = 0:3, target = "inc hosp", + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) |> dplyr::pull(.data[["model_id"]]) |> @@ -60,6 +63,7 @@ test_that("mapping over locations works as expected", { reference_date = "2023-01-14", temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, seed = 1234 @@ -70,6 +74,7 @@ test_that("mapping over locations works as expected", { reference_date = "2023-01-14", temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, seed = 1234 @@ -98,6 +103,7 @@ test_that("mapping over locations works as expected", { reference_date = "2023-01-14", temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, target = "inc hosp", quantile_levels = c(.1, .5, .9), n_samples = NULL, diff --git a/tests/testthat/test-fit_baselines_one_location.R b/tests/testthat/test-fit_baselines_one_location.R index bd0e7a9..1f2afa1 100644 --- a/tests/testthat/test-fit_baselines_one_location.R +++ b/tests/testthat/test-fit_baselines_one_location.R @@ -17,6 +17,7 @@ test_that("multiple locations in target data throws an error", { reference_date = "2023-01-14", temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) |> expect_error( @@ -30,6 +31,7 @@ test_that("null reference date throws an error", { reference_date = NULL, temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) |> expect_error( @@ -43,6 +45,7 @@ test_that("invalid reference date format throws an error", { reference_date = "2023", temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) |> expect_error( @@ -56,6 +59,7 @@ test_that("multiple reference dates throws an error", { reference_date = c("2023-01-07", "2023-01-14"), temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) |> expect_error( @@ -69,6 +73,7 @@ test_that("invalid temporal resolution throws an error", { reference_date = "2023-01-14", temporal_resolution = "monthly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) |> expect_error( @@ -82,6 +87,7 @@ test_that("provided temporal_resolution not matching that of target_ts throws an reference_date = "2023-01-14", temporal_resolution = "daily", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL) |> expect_error( @@ -100,6 +106,7 @@ test_that( reference_date = "2023-01-14", temporal_resolution = "weekly", horizons = 0:4, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, seed = 1234 @@ -115,6 +122,7 @@ test_that( reference_date = "2023-01-21", temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, seed = 1234 @@ -139,6 +147,7 @@ test_that("overlapping forecasts are replaced with observed values and throws a reference_date = "2023-01-14", temporal_resolution = "weekly", horizons = 0:2, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, seed = 1234 @@ -158,6 +167,7 @@ test_that("overlapping forecasts are replaced with observed values and throws a reference_date = "2023-01-07", temporal_resolution = "weekly", horizons = 0:3, + n_sim = 10000, quantile_levels = c(.1, .5, .9), n_samples = NULL, seed = 1234