From a7de97a2693247b976b87660bd5e78d38f27e5ea Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Wed, 23 Oct 2024 10:43:11 -0400 Subject: [PATCH] Add `fit_baseline_models()` tests --- R/fit_baseline_models.R | 2 +- tests/testthat/test-fit_baseline_models.R | 94 +++++++++++++++++++++++ 2 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 tests/testthat/test-fit_baseline_models.R diff --git a/R/fit_baseline_models.R b/R/fit_baseline_models.R index c2d13a6..7d33fed 100644 --- a/R/fit_baseline_models.R +++ b/R/fit_baseline_models.R @@ -27,7 +27,7 @@ fit_baseline_models <- function(model_variations, round_predictions = FALSE, seed = NULL) { if (is.null(target)) { - cli::cli_abort("No {.arg target} provided; please provide one") + cli::cli_abort("{.arg target} is missing; please provide one") } # fit baseline models diff --git a/tests/testthat/test-fit_baseline_models.R b/tests/testthat/test-fit_baseline_models.R new file mode 100644 index 0000000..6efe155 --- /dev/null +++ b/tests/testthat/test-fit_baseline_models.R @@ -0,0 +1,94 @@ +# set up simple data for test cases +model_variations <- data.frame(stringsAsFactors = FALSE, + transformation = c("none", "none"), + symmetrize = c(TRUE, TRUE), + window_size = c(3, 4)) +target_ts <- expand.grid(stringsAsFactors = FALSE, + location = c("ak", "al"), + time_index = as.Date("2022-11-05") + seq(0, 63, 7), + observation = NA) +target_ts$observation[target_ts$location == "ak"] <- c(15, 14, 38, 69, 53, 73, 51, 43, 43, 32) +target_ts$observation[target_ts$location == "al"] <- c(350, 312, 236, 237, 360, 234, 153, 153, 148, 125) + +test_that("missing target throws an error", { + fit_baseline_models(model_variations, + target_ts, + reference_date = "2023-01-14", + temporal_resolution = "weekly", + horizons = 0:3, + target = NULL, + quantile_levels = c(.1, .5, .9), + n_samples = NULL) |> + expect_error( + regexp = "`target` is missing; please provide one", fixed = TRUE + ) +}) + +test_that("model IDs are as expected", { + expected_model_ids <- paste("UMass-baseline", "none", "sym", c(3, 4), "weekly", sep = "_") + actual_model_ids <- fit_baseline_models(model_variations, + target_ts, + reference_date = "2023-01-14", + temporal_resolution = "weekly", + horizons = 0:3, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = NULL) |> + dplyr::pull(.data[["model_id"]]) |> + unique() + expect_equal(actual_model_ids, expected_model_ids) +}) + +test_that("mapping over locations works as expected", { + expected_outputs <- + rbind( + fit_baselines_one_location( + model_variations, + target_ts |> dplyr::filter(.data[["location"]] == "ak"), + reference_date = "2023-01-14", + temporal_resolution = "weekly", + horizons = 0:3, + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + seed = 1234 + ), + fit_baselines_one_location( + model_variations, + target_ts |> dplyr::filter(.data[["location"]] == "al"), + reference_date = "2023-01-14", + temporal_resolution = "weekly", + horizons = 0:3, + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + seed = 1234 + ) + ) |> + dplyr::mutate( + model_id = paste( + "UMass-baseline", + .data[["transformation"]], + ifelse(.data[["symmetrize"]], "sym", "nonsym"), + .data[["window_size"]], + "weekly", + sep = "_" + ), + target = "inc hosp", + .before = 1 + ) |> + dplyr::select(dplyr::all_of(c( + "model_id", "location", "reference_date", "horizon", "target", + "target_end_date", "output_type", "output_type_id", "value" + ))) |> + hubUtils::as_model_out_tbl() + + actual_outputs <- fit_baseline_models(model_variations, + target_ts, + reference_date = "2023-01-14", + temporal_resolution = "weekly", + horizons = 0:3, + target = "inc hosp", + quantile_levels = c(.1, .5, .9), + n_samples = NULL, + seed = 1234) + expect_equal(actual_outputs, expected_outputs, tolerance = 1e-3) +}) \ No newline at end of file