Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3 write fit baseline models function #10

Merged
merged 13 commits into from
Nov 4, 2024
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ RoxygenNote: 7.3.2
Imports:
cli,
dplyr,
hubUtils,
purrr,
rlang,
simplets,
Expand Down
69 changes: 69 additions & 0 deletions R/fit_baseline_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#' Generate predictions for all baseline models for the given reference date
#'
#' Fits every baseline model variation separately for each unique `location`
#' found in `target_ts` by calling `fit_baselines_one_location()`, row-binds
#' the per-location results, and labels every row with a `model_id` and the
#' requested `target`.
#'
#' @inheritParams fit_baselines_one_location
#' @param target character string specifying the name of the prediction target
#'
#' @details The `model_variations` data frame has the following columns and
#'   possible values for each:
#'   - transformation (character): "none" or "sqrt", determines distribution shape
#'   - symmetrize (boolean), determines if distribution is symmetric
#'   - window_size (integer), determines how many previous observations inform
#'     the forecast
#'
#'   Model IDs are assembled by joining `"UMass-baseline"`, the transformation,
#'   `"sym"` or `"nonsym"`, the window size and `temporal_resolution` with
#'   underscores. The same `seed` is passed to each per-location fit.
#'
#' @return `model_out_tbl` of forecasts for all baseline models with columns:
#'   `model_id`, `location`, `reference_date`, `horizon`, `target`,
#'   `target_end_date`, `output_type`, `output_type_id`, and `value`
#'
#' @importFrom rlang .data
#'
#' @export
fit_baseline_models <- function(model_variations,
                                target_ts,
                                reference_date,
                                temporal_resolution,
                                horizons,
                                target,
                                quantile_levels,
                                n_samples,
                                round_predictions = FALSE,
                                seed = NULL) {
  # `target` has no default, so an omitted argument must be caught with
  # missing() before is.null() forces the promise and fails with a base error
  if (missing(target) || is.null(target)) {
    cli::cli_abort("{.arg target} is missing; please provide one")
  }
  # the target labels every output row, so it has to be one usable string
  if (!is.character(target) || length(target) != 1 || is.na(target)) {
    cli::cli_abort("{.arg target} must be a single character string")
  }

  # fit baseline models, one location at a time
  purrr::map(
    unique(target_ts$location),
    function(fips_code) {
      fit_baselines_one_location(
        model_variations = model_variations,
        target_ts = dplyr::filter(target_ts, .data[["location"]] == fips_code),
        reference_date = reference_date,
        temporal_resolution = temporal_resolution,
        horizons = horizons,
        quantile_levels = quantile_levels,
        n_samples = n_samples,
        round_predictions = round_predictions,
        seed = seed
      )
    }
  ) |>
    purrr::list_rbind() |>
    dplyr::mutate(
      # relies on the per-location output carrying the model variation columns
      model_id = paste(
        "UMass-baseline",
        .data[["transformation"]],
        ifelse(.data[["symmetrize"]], "sym", "nonsym"),
        .data[["window_size"]],
        temporal_resolution,
        sep = "_"
      ),
      target = target,
      .before = 1
    ) |>
    # keep only the columns that make up the returned model output table
    dplyr::select(dplyr::all_of(c(
      "model_id", "location", "reference_date", "horizon", "target",
      "target_end_date", "output_type", "output_type_id", "value"
    ))) |>
    hubUtils::as_model_out_tbl()
}
52 changes: 27 additions & 25 deletions R/fit_baselines_one_location.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' location
#' @param reference_date string of the reference date for the forecasts, i.e.
#' the date relative to which the targets are defined (usually Saturday for
#' weekly targets)
#' weekly targets). Must be in the ymd format, with yyyy-mm-dd format recommended.
#' @param temporal_resolution 'daily' or 'weekly'; specifies timescale of
#' `target_ts` and `horizons`
#' @param horizons numeric vector of prediction horizons relative to
Expand Down Expand Up @@ -62,6 +62,13 @@ fit_baselines_one_location <- function(model_variations,
round_predictions = FALSE,
seed = NULL) {

if (length(reference_date) > 1) {
cli::cli_abort("only one {.arg reference_date} may be provided")
} else {
# date to which horizons are relative
reference_date <- validate_ymd_date(reference_date, arg_name = "reference_date")
}

valid_temp_res <- c("daily", "weekly")
if (!(temporal_resolution %in% valid_temp_res && length(temporal_resolution)) == 1) {
cli::cli_abort("{.arg temporal_resolution} must be only one of {.val valid_temp_res}")
Expand All @@ -76,26 +83,26 @@ fit_baselines_one_location <- function(model_variations,
}

# figure out horizons to forecast
reference_date <- lubridate::ymd(reference_date) # date to which horizons are relative
last_data_date <- max(target_ts$time_index) # last day of target data
actual_target_dates <- reference_date + ts_temp_res * horizons
effective_horizons <- as.integer(actual_target_dates - last_data_date) / ts_temp_res
horizons_to_forecast <- 1:max(effective_horizons)
h_adjustments <- min(effective_horizons) - 1

# get predictions for all model_variations
predictions <- purrr::pmap_dfr( #tibble, each 1x1 row contains predictions for 1 model
model_variations,
get_baseline_predictions,
target_ts = target_ts,
effective_horizons = horizons_to_forecast,
origin = ifelse(temporal_resolution == "weekly", "obs", "median"),
n_sim = 100000,
quantile_levels = quantile_levels,
n_samples = n_samples,
round_predictions = round_predictions,
seed = seed
)
predictions <- model_variations |>
purrr::pmap( #tibble, each 1x1 row contains predictions for 1 model
get_baseline_predictions,
target_ts = target_ts,
effective_horizons = horizons_to_forecast,
origin = ifelse(temporal_resolution == "weekly", "obs", "median"),
n_sim = 100000,
quantile_levels = quantile_levels,
n_samples = n_samples,
round_predictions = round_predictions,
seed = seed
) |>
purrr::list_rbind()

# extract forecasts
extracted_outputs <-
Expand Down Expand Up @@ -145,17 +152,12 @@ fit_baselines_one_location <- function(model_variations,
dplyr::mutate(
value = ifelse(is.na(.data[["value"]]), .data[["observation"]], .data[["value"]])
)
if (max(actual_target_dates) <= last_data_date) {
cli::cli_warn(
"all requested forecasts are for a time index within the provided {.arg target_ts},
replacing overlapping forecasts with {.val {length(horizons)}} target observations"
)
} else {
cli::cli_warn(
"forecasts requested for a time index within the provided {.arg target_ts},
replacing overlapping forecasts with {.val {abs(h_adjustments)}} target observations"
)
}

cli::cli_warn(
"forecasts requested for a time index within the provided {.arg target_ts},
replacing overlapping forecasts for {.val {abs(h_adjustments)}} horizons
with observed values"
)
}

model_outputs |>
Expand Down
23 changes: 23 additions & 0 deletions R/validate_model_inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,26 @@ validate_integer <- function(int, arg_name) {
cli::cli_abort("{.arg {arg_name}} must be a single, non-negative integer value.")
}
}


#' Validate value to be converted into a ymd date
#'
#' Parses `date` with `lubridate::ymd()` and aborts when nothing was supplied
#' or when any element cannot be parsed.
#'
#' @param date value (or vector of values) to be converted into a ymd date
#' @param arg_name character string name of the argument being validated to be
#'   printed in the error message (generally the name of the `date` object)
#'
#' @return a validated Date object (or vector) parsed from `date`
#'
#' @noRd
validate_ymd_date <- function(date, arg_name) {
  # a zero-length input carries no date at all, so report it like NULL
  if (is.null(date) || length(date) == 0) {
    cli::cli_abort("{.arg {arg_name}} is missing")
  }

  # quiet = TRUE silences the parse warning; failed elements come back as NA
  ymd_date <- lubridate::ymd(date, quiet = TRUE)

  # anyNA() keeps the condition a single TRUE/FALSE even for vector input
  if (anyNA(ymd_date)) {
    cli::cli_abort("{.arg {arg_name}} could not be correctly parsed. Please use the ymd format")
  }

  ymd_date
}
72 changes: 72 additions & 0 deletions man/fit_baseline_models.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/fit_baselines_one_location.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

106 changes: 106 additions & 0 deletions tests/testthat/test-fit_baseline_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Shared fixtures for the tests below.
# Two symmetric, untransformed model variations that differ only in window size.
model_variations <- data.frame(
  transformation = c("none", "none"),
  symmetrize = c(TRUE, TRUE),
  window_size = c(3, 4),
  stringsAsFactors = FALSE
)

# Ten weekly time points per location; observations are filled in afterwards.
target_ts <- expand.grid(
  location = c("ak", "al"),
  time_index = as.Date("2022-11-05") + seq(0, 63, 7),
  observation = NA,
  stringsAsFactors = FALSE
)
target_ts$observation[target_ts$location == "ak"] <-
  c(15, 14, 38, 69, 53, 73, 51, 43, 43, 32)
target_ts$observation[target_ts$location == "al"] <-
  c(350, 312, 236, 237, 360, 234, 153, 153, 148, 125)

test_that("missing target throws an error", {
  # a NULL target must be rejected with the dedicated message
  expect_error(
    fit_baseline_models(
      model_variations,
      target_ts,
      reference_date = "2023-01-14",
      temporal_resolution = "weekly",
      horizons = 0:3,
      target = NULL,
      quantile_levels = c(.1, .5, .9),
      n_samples = NULL
    ),
    regexp = "`target` is missing; please provide one",
    fixed = TRUE
  )
})

test_that("output predictions is a model_out_tbl", {
  # the returned forecasts must carry the hubUtils model output class
  fit_baseline_models(
    model_variations,
    target_ts,
    reference_date = "2023-01-14",
    temporal_resolution = "weekly",
    horizons = 0:3,
    target = "inc hosp",
    quantile_levels = c(.1, .5, .9),
    n_samples = NULL
  ) |>
    expect_s3_class("model_out_tbl")
})

test_that("model IDs are as expected", {
  # one ID per model variation: same transformation/symmetry, window sizes 3 and 4
  forecasts <- fit_baseline_models(
    model_variations,
    target_ts,
    reference_date = "2023-01-14",
    temporal_resolution = "weekly",
    horizons = 0:3,
    target = "inc hosp",
    quantile_levels = c(.1, .5, .9),
    n_samples = NULL
  )
  actual_model_ids <- unique(forecasts[["model_id"]])
  expected_model_ids <- paste(
    "UMass-baseline", "none", "sym", c(3, 4), "weekly",
    sep = "_"
  )

  expect_equal(actual_model_ids, expected_model_ids)
})

test_that("mapping over locations works as expected", {
  # fit a single location by hand with the same settings and seed
  fit_one_location <- function(loc) {
    fit_baselines_one_location(
      model_variations,
      dplyr::filter(target_ts, .data[["location"]] == loc),
      reference_date = "2023-01-14",
      temporal_resolution = "weekly",
      horizons = 0:3,
      quantile_levels = c(.1, .5, .9),
      n_samples = NULL,
      seed = 1234
    )
  }

  # stack the hand-fitted locations in the same order as the fixture data
  stacked_fits <- do.call(rbind, lapply(c("ak", "al"), fit_one_location))

  # reproduce the labelling and column selection done by fit_baseline_models()
  expected_outputs <- stacked_fits |>
    dplyr::mutate(
      model_id = paste(
        "UMass-baseline",
        .data[["transformation"]],
        ifelse(.data[["symmetrize"]], "sym", "nonsym"),
        .data[["window_size"]],
        "weekly",
        sep = "_"
      ),
      target = "inc hosp",
      .before = 1
    ) |>
    dplyr::select(dplyr::all_of(c(
      "model_id", "location", "reference_date", "horizon", "target",
      "target_end_date", "output_type", "output_type_id", "value"
    ))) |>
    hubUtils::as_model_out_tbl()

  actual_outputs <- fit_baseline_models(
    model_variations,
    target_ts,
    reference_date = "2023-01-14",
    temporal_resolution = "weekly",
    horizons = 0:3,
    target = "inc hosp",
    quantile_levels = c(.1, .5, .9),
    n_samples = NULL,
    seed = 1234
  )

  expect_equal(actual_outputs, expected_outputs, tolerance = 1e-3)
})
Loading
Loading