Skip to content

Commit

Permalink
Add fit_baseline_models() tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lshandross committed Oct 23, 2024
1 parent 48d74dd commit a7de97a
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 1 deletion.
2 changes: 1 addition & 1 deletion R/fit_baseline_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ fit_baseline_models <- function(model_variations,
round_predictions = FALSE,
seed = NULL) {
if (is.null(target)) {
cli::cli_abort("No {.arg target} provided; please provide one")
cli::cli_abort("{.arg target} is missing; please provide one")
}

# fit baseline models
Expand Down
94 changes: 94 additions & 0 deletions tests/testthat/test-fit_baseline_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# set up simple data for test cases
model_variations <- data.frame(stringsAsFactors = FALSE,
transformation = c("none", "none"),
symmetrize = c(TRUE, TRUE),
window_size = c(3, 4))
target_ts <- expand.grid(stringsAsFactors = FALSE,
location = c("ak", "al"),
time_index = as.Date("2022-11-05") + seq(0, 63, 7),
observation = NA)
target_ts$observation[target_ts$location == "ak"] <- c(15, 14, 38, 69, 53, 73, 51, 43, 43, 32)
target_ts$observation[target_ts$location == "al"] <- c(350, 312, 236, 237, 360, 234, 153, 153, 148, 125)

test_that("missing target throws an error", {
fit_baseline_models(model_variations,
target_ts,
reference_date = "2023-01-14",
temporal_resolution = "weekly",
horizons = 0:3,
target = NULL,
quantile_levels = c(.1, .5, .9),
n_samples = NULL) |>
expect_error(
regexp = "`target` is missing; please provide one", fixed = TRUE
)
})

test_that("model IDs are as expected", {
expected_model_ids <- paste("UMass-baseline", "none", "sym", c(3, 4), "weekly", sep = "_")
actual_model_ids <- fit_baseline_models(model_variations,
target_ts,
reference_date = "2023-01-14",
temporal_resolution = "weekly",
horizons = 0:3,
target = "inc hosp",
quantile_levels = c(.1, .5, .9),
n_samples = NULL) |>
dplyr::pull(.data[["model_id"]]) |>
unique()
expect_equal(actual_model_ids, expected_model_ids)
})

test_that("mapping over locations works as expected", {
expected_outputs <-
rbind(
fit_baselines_one_location(
model_variations,
target_ts |> dplyr::filter(.data[["location"]] == "ak"),
reference_date = "2023-01-14",
temporal_resolution = "weekly",
horizons = 0:3,
quantile_levels = c(.1, .5, .9),
n_samples = NULL,
seed = 1234
),
fit_baselines_one_location(
model_variations,
target_ts |> dplyr::filter(.data[["location"]] == "al"),
reference_date = "2023-01-14",
temporal_resolution = "weekly",
horizons = 0:3,
quantile_levels = c(.1, .5, .9),
n_samples = NULL,
seed = 1234
)
) |>
dplyr::mutate(
model_id = paste(
"UMass-baseline",
.data[["transformation"]],
ifelse(.data[["symmetrize"]], "sym", "nonsym"),
.data[["window_size"]],
"weekly",
sep = "_"
),
target = "inc hosp",
.before = 1
) |>
dplyr::select(dplyr::all_of(c(
"model_id", "location", "reference_date", "horizon", "target",
"target_end_date", "output_type", "output_type_id", "value"
))) |>
hubUtils::as_model_out_tbl()

actual_outputs <- fit_baseline_models(model_variations,
target_ts,
reference_date = "2023-01-14",
temporal_resolution = "weekly",
horizons = 0:3,
target = "inc hosp",
quantile_levels = c(.1, .5, .9),
n_samples = NULL,
seed = 1234)
expect_equal(actual_outputs, expected_outputs, tolerance = 1e-3)
})

Check warning on line 94 in tests/testthat/test-fit_baseline_models.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-fit_baseline_models.R,line=94,col=3,[trailing_blank_lines_linter] Missing terminal newline.

0 comments on commit a7de97a

Please sign in to comment.