Skip to content

Commit

Permalink
Merge pull request #4 from reichlab/1-write-get_baseline_predictions-…
Browse files Browse the repository at this point in the history
…function

1 write get baseline predictions function
  • Loading branch information
lshandross authored Sep 10, 2024
2 parents 0bbe7d9 + 340e61d commit 1c86ee0
Show file tree
Hide file tree
Showing 13 changed files with 1,532 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
^trendsEnsemble\.Rproj$
^\.Rproj\.user$
^\.github$
^\.lintr$
^README\.Rmd$
^LICENSE\.md$
34 changes: 34 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

name: lint

permissions: read-all

jobs:
lint:
runs-on: ubuntu-latest
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v4

- uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::lintr, local::.
needs: lint

- name: Lint
run: lintr::lint_package()
shell: Rscript {0}
env:
LINTR_ERROR_ON_LINT: true
4 changes: 4 additions & 0 deletions .lintr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
linters: linters_with_defaults(
line_length_linter = line_length_linter(120L),
commented_code_linter = NULL
)
21 changes: 18 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,24 @@ Authors@R: c(
email = "[email protected]",
comment = c(ORCID = "0000-0003-3503-9899")))
Description: A collection of functions used to create the UMass-trends_ensemble,
an equally weighted ensemble of simple time series baseline models
License: `use_mit_license()`, `use_gpl3_license()` or friends to pick a
license
an equally weighted ensemble of simple time series baseline models.
License: GPL (>= 3)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Imports:
cli,
dplyr,
purrr,
rlang,
simplets,
tibble
Remotes:
reichlab/simplets
Suggests:
car,
fabletools,
feasts,
testthat (>= 3.2.1),
tidyr
Config/testthat/edition: 3
595 changes: 595 additions & 0 deletions LICENSE.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Generated by roxygen2: do not edit by hand

importFrom(rlang,.data)
166 changes: 166 additions & 0 deletions R/get_baseline_predictions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#' Get sample and/or quantile type predictions for a single baseline model
#'
#' @param target_ts a `data.frame` of target data in a time series format
#' (contains columns `time_index`, `location`, and `observation`) for a single
#' location
#' @param transformation string specifying the transformation used on the
#' distribution which determines its shape; can be one of "none" or "sqrt".
#' @param symmetrize boolean specifying whether to make the distribution symmetric;
#' can be one of `TRUE` or `FALSE`.
#' @param window_size integer specifying how many previous observations in the
#' target data should be used to inform the forecasts
#' @param effective_horizons numeric vector of prediction horizons relative to
#' the last observed date in `target_ts`
#' @param origin string specifying the origin to use when making predictions;
#' recommended to be "median" if the temporal resolution is daily and "obs"
#' if weekly or otherwise. Defaults to "obs".
#' @param n_sim integer number of simulations to predict. Defaults to 100000.
#' @param quantile_levels numeric vector of quantile levels to output; set to NULL
#' if quantile outputs not requested. Defaults to NULL.
#' @param n_samples integer number of samples to output, which are drawn from the
#' simulated (sample) predictions (hence n_samples <= n_sim). Set to NULL
#' if sample outputs not requested. Defaults to NULL.
#' @param round_predictions boolean specifying whether to round the output
#' predictions to the nearest whole number. Defaults to FALSE
#' @param seed integer specifying a seed to set for reproducible results.
#' Defaults to NULL, in which case no seed is set.
#'
#' @return data frame of a baseline forecast for one location, one model with
#' columns `horizon` , `output_type`, `output_type_id`, and `value`,
#' but which are stored as a nested list in a 1x1 data frame
#'
#' @importFrom rlang .data

get_baseline_predictions <- function(target_ts,
transformation,
symmetrize,
window_size,
effective_horizons,
origin = "obs",
n_sim = 100000,
quantile_levels = NULL,
n_samples = NULL,
round_predictions = FALSE,
seed = NULL) {
# validate arguments
validate_target_ts(target_ts)

num_locs <- length(unique(target_ts[["location"]]))
if (num_locs != 1) {
cli::cli_abort("{.arg target_ts} contains {.val num_locs} but only one may be provided.")
}

validate_variation_inputs(transformation, symmetrize, window_size)

valid_origins <- c("median", "obs")
if (!origin %in% valid_origins) {
cli::cli_abort("{.arg origin} must be only one of {.val valid_origins}")
}

validate_integer(n_sim, "n_sim")

if (any(quantile_levels > 1) || any(quantile_levels < 0)) {
cli::cli_abort("{.arg quantile_levels} must only contain values between 0 and 1.")
}

if (!is.null(n_samples)) {
validate_integer(n_samples, "n_samples")
if (n_samples > n_sim) {
cli::cli_abort("{.arg n_samples} must be less than or equal to {.arg n_sim}")
}
}

if (is.null(quantile_levels) && is.null(n_samples)) {
cli::cli_abort("No forecasts requested: both `quantile_levels` and `n_samples` are NULL")
}

if (!is.null(seed)) set.seed(seed)

# fit
baseline_fit <- simplets::fit_simple_ts(
y = target_ts[["observation"]],
ts_frequency = 1,
model = "quantile_baseline",
transformation = transformation,
transform_offset = ifelse(transformation == "none", 0, 1),
d = 0,
D = 0,
symmetrize = symmetrize,
window_size = window_size
)

# predict
predictions <- baseline_fit |>
stats::predict(
nsim = n_sim,
horizon = max(effective_horizons),
origin = origin,
force_nonneg = TRUE
)

forecasts_df <- extract_predictions(predictions, effective_horizons, quantile_levels, n_samples)

if (round_predictions) forecasts_df[["value"]] <- round(forecasts_df[["value"]], 0)
return(dplyr::tibble(forecasts = list(forecasts_df)))
}


#' Extract and compute sample and/or quantile forecasts from a prediction matrix
#'
#' @param predictions a `matrix` of `n_sim` runs of sample predictions with
#' dimensions n_sim x max(effective_horizons)
#' @param effective_horizons numeric vector of prediction horizons relative to
#' the last observed date in `target_ts`
#' @param quantile_levels numeric vector of quantile levels to output; set to NULL
#' if quantile outputs not requested. Defaults to NULL.
#' @param n_samples integer number of samples to output, which are drawn from the
#' simulated (sample) predictions (hence n_samples <= n_sim). Set to NULL
#' if sample outputs not requested. Defaults to NULL.
#'
#' @return data frame of extracted sample and/or quantile forecasts with
#' columns `horizon` , `output_type`, `output_type_id`, and `value`
#'
extract_predictions <- function(predictions,
effective_horizons,
quantile_levels = NULL,
n_samples = NULL) {
samples_df <- NULL
if (!is.null(n_samples)) {
samples_df <- effective_horizons |>
purrr::map(
function(h) {
data.frame(
horizon = rep(h, n_samples),
value = predictions[1:n_samples, h]
) |>
tibble::rownames_to_column(var = "output_type_id") |>
dplyr::mutate(
output_type = "sample",
output_type_id = as.numeric(dplyr::row_number()),
.before = 2
) |>
dplyr::select("horizon", "output_type", "output_type_id", "value")
}
) |>
purrr::list_rbind()
}
quantiles_df <- NULL
if (!is.null(quantile_levels)) {
n <- length(quantile_levels)
quantiles_df <- effective_horizons |>
purrr::map(
function(h) {
data.frame(
horizon = rep(h, n),
output_type = "quantile",
output_type_id = quantile_levels,
value = stats::quantile(predictions[, h], probs = quantile_levels)
)
}
) |>
purrr::list_rbind()
}

combined_df <- rbind(samples_df, quantiles_df)
return(combined_df)
}
132 changes: 132 additions & 0 deletions R/validate_model_inputs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#' Perform simple validations on the model variations dataframe used to define a
#' given baseline model
#'
#' @param model_variations a `data.frame` where each row specifies a set of
#' hyperparameters to use for a single baseline model fit, with the following
#' columns: `transformation`, `symmetrize`, and `window_size`. See the details
#' for more information.
#' @details The types, and possible values for each of the columns in
#' `model_variations` are as follows:
#' - transformation (character): "none" or "sqrt"
#' - symmetric (boolean)
#' - window_size (numeric): a non-negative integer
#' Additional validations that check each column for the correct type are performed
#' by `validate_variation_inputs()`, which will always be called following a call to
#' `validate_model_inputs()`.
#'
#' @return no return value
#'
#' @noRd

validate_model_variations <- function(model_variations) {
if (is.null(model_variations)) {
cli::cli_abort("{.arg model_variations} is missing")
}

variation_col <- c("transformation", "symmetrize", "window_size")
validate_colnames(model_variations, variation_col, "model_variations")

if (any(duplicated(model_variations))) {
cli::cli_abort("{.arg model_variations} contains duplicate rows.")
}
}


#' Perform simple validations on the individual variables defining a single
#' baseline model
#'
#' @param transformation string specifying the transformation used on the
#' distribution which determines its shape; can be one of "none" or "sqrt".
#' @param symmetrize boolean specifying whether to make the distribution symmetric;
#' can be one of `TRUE` or `FALSE`.
#' @param window_size integer specifying how many previous observations in the
#' target data should be used to inform the forecasts
#'
#' @return no return value
#'
#' @noRd
validate_variation_inputs <- function(transformation, symmetrize, window_size) {
# check variation inputs have length 1
if (length(transformation) != 1) {
cli::cli_abort("{.arg transformation} must be length 1")
}

if (length(symmetrize) != 1) {
cli::cli_abort("{.arg symmetrize} must be length 1")
}

if (length(window_size) != 1) {
cli::cli_abort("{.arg window_size} must be length 1")
}

# check variation inputs contain only valid values
valid_transformations <- c("none", "sqrt")
if (!transformation %in% valid_transformations) {
cli::cli_abort("{.arg transformation} must only contain values {.val {valid_transformations}}")
}

if (!inherits(symmetrize, "logical")) {
cli::cli_abort("{.arg symmetrize} must only contain logical values, e.g. TRUE or FALSE.")
}

if (window_size != trunc(window_size) || window_size < 0) {
cli::cli_abort("{.arg window_size} must only contain non-negative integer values.")
}
}


#' Perform simple validations on the target data (time series) dataframe
#'
#' @param target_ts a `data.frame` of target data in a time series format
#' (contains columns `time_index`, `location`, and `observation`)
#'
#' @return no return value
#'
#' @noRd
validate_target_ts <- function(target_ts) {
target_col <- c("time_index", "location", "observation")
validate_colnames(target_ts, target_col, "target_ts")

if (any(duplicated(target_ts))) {
cli::cli_abort("{.arg target_ts} contains duplicate rows.")
}
}


#' Validate that a dataframe's columns are (named) as expected
#'
#' @param df a `data.frame` whose columns are to be validated
#' @param expected_col a character vector of expected column names
#' @param arg_name character string name of the argument being validated to be
#' printed in the error message(generally the name of the `df` object)
#'
#' @return no return value
#'
#' @noRd
validate_colnames <- function(df, expected_col, arg_name) {
actual_col <- colnames(df)
if (!all(expected_col %in% actual_col)) {
cli::cli_abort("{.arg {arg_name}} is missing the column{?s}: {.val {setdiff(expected_col, actual_col)}}.")
}
if (!all(actual_col %in% expected_col)) {
cli::cli_abort(c(
x = "{.arg {arg_name}} contains the extra column{?s}: {.val {setdiff(actual_col, expected_col)}}."
))
}
}


#' Validate that an integer is as expected and non-negative
#'
#' @param int a single integer to be validated
#' @param arg_name character string name of the argument being validated to be
#' printed in the error message(generally the name of the `int` object)
#'
#' @return no return value
#'
#' @noRd
validate_integer <- function(int, arg_name) {
if (!is.numeric(int) || int < 0 || int != trunc(int) || length(int) != 1) {
cli::cli_abort("{.arg {arg_name}} must be a single, non-negative integer value.")
}
}
Loading

0 comments on commit 1c86ee0

Please sign in to comment.