Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for 2024/25 season endpoints #20

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 221 additions & 1 deletion tests/testthat/test-get_flusight_bin_endpoints.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
test_that("get_flusight_bin_endpoints works, 2023/24", {
# definitions at
# https://github.com/cdcepi/FluSight-forecast-hub/tree/main/model-output#rate-trend-forecast-specifications
# https://github.com/cdcepi/FluSight-forecast-hub/blob/1c10e63fa1b115c71f274d90a2420cb599ce5f64/model-output/README.md#rate-trend-forecast-specifications
#
# our strategy is to:
# - construct data that should fall into known categories
Expand Down Expand Up @@ -216,3 +216,223 @@ test_that("get_flusight_bin_endpoints works, 2023/24", {
# expect no mismatches!
expect_equal(nrow(mismatched_categorizations), 0L)
})


test_that("get_flusight_bin_endpoints works, 2024/25", {
# definitions at
# https://github.com/cdcepi/FluSight-forecast-hub/tree/main/model-output#rate-trend-forecast-specifications
#
# our strategy is to:
# - construct data that should fall into known categories
# (with every horizon/category combination, and both criteria for stable)
# - compute bins and apply them to the data
# - check that we got the right answers
location_meta <- readr::read_csv(
file = testthat::test_path("fixtures", "location_meta_24.csv")
) |>
dplyr::mutate(
pop100k = .data[["population"]] / 100000
)

# locations for testing "stable", "increase" and "decrease" thresholds are:
# 56 = Wyoming, pop100k = 5.78
# 02 = Alaska, pop100k = 7.11 and 11 = District of Columbia, pop100k = 6.69
# these states trigger the "minimum count change at least 10" rule
locs <- c("US", "56", "02", "11", "05", "06")

# create data
# our reference date will be 2024-11-23.
# changes are relative to 2024-11-16.
# 2024-11-09 is throw-away, to make sure we grab the right "relative to" date
target_data <- tidyr::expand_grid(
location = locs,
date = as.Date("2024-11-09") + seq(from = 0, by = 7, length.out = 6),
value = NA
)

expected_categories <- NULL

# all expected category levels are "stable": rate change less than
# 0.3, 0.5, 0.7, or 1 * population rate
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[1]]
target_data$value[target_data$location == locs[1]] <- c(
0,
10000,
10000 + floor(0.29 * loc_pop100k),
10000 - floor(0.49 * loc_pop100k),
10000 + floor(0.69 * loc_pop100k),
10000 - floor(0.99 * loc_pop100k)
Comment on lines +261 to +264
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had copied the test for the 2023/24 season and modified the numbers for 2024/25, but the alternating signs for this portion of the original code (lines 41 to 44) confused me. Should they be alternating here?

)
expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[1],
.data[["date"]] >= "2024-11-23"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2024-11-23")) / 7),
output_type_id = "stable"
)
)

# all expected category levels are "stable": count change less than 10
target_data$value[target_data$location == locs[2]] <- c(0, 300, 304, 297, 309, 291)
expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[2],
.data[["date"]] >= "2024-11-23"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2024-11-23")) / 7),
output_type_id = "stable"
)
)

# all expected category levels are "increase": count change >= 10,
# horizon 0: 0.3 <= rate change < 1.7
# horizon 1: 0.5 <= rate change < 3
# horizon 2: 0.7 <= rate change < 4
# horizon 3: 1.0 <= rate change < 5
# note, loc_pop100k for this location is 7.11 < 10
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[3]]
target_data$value[target_data$location == locs[3]] <- c(
0,
10000,
10000 + 10,
10000 + floor(2.99 * loc_pop100k),
10000 + floor(3.99 * loc_pop100k),
10000 + ceiling(1.01 * loc_pop100k)
Comment on lines +305 to +307
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, I just modified the numbers here, but they didn't seem to be right in the original code (lines 113 to 115)

)
expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[3],
.data[["date"]] >= "2024-11-23"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2024-11-23")) / 7),
output_type_id = "increase"
)
)

# all expected category levels are "decrease": count change <= -11,
# horizon 0: -0.3 >= rate change > -1.7
# horizon 1: -0.5 >= rate change > -3
# horizon 2: -0.7 >= rate change > -4
# horizon 3: -1.0 >= rate change > -5
# note, loc_pop100k for this location is 73.4 >= 10
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[4]]
target_data$value[target_data$location == locs[4]] <- c(
0,
10000,
10000 - 11,
10000 - floor(2.99 * loc_pop100k),
10000 - floor(3.99 * loc_pop100k),
10000 - ceiling(1.01 * loc_pop100k)
)
expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[4],
.data[["date"]] >= "2024-11-23"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2024-11-23")) / 7),
output_type_id = "decrease"
)
)

# all expected category levels are "large increase": count change >= 10,
# horizon 0: 1.7 <= rate change
# horizon 1: 3 <= rate change
# horizon 2: 4 <= rate change
# horizon 3: 5 <= rate change
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[5]]
target_data$value[target_data$location == locs[5]] <- c(
0,
10000,
10000 + max(10, ceiling(1.7 * loc_pop100k)),
10000 + max(10, ceiling(3 * loc_pop100k)),
10000 + max(10, ceiling(4 * loc_pop100k)),
10000 + max(10, ceiling(5 * loc_pop100k))
)

expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[5],
.data[["date"]] >= "2024-11-23"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2024-11-23")) / 7),
output_type_id = "large_increase"
)
)

# all expected category levels are "large decrease": count change <= -10,
# horizon 0: -1.7 >= rate change
# horizon 1: -3 >= rate change
# horizon 2: -4 >= rate change
# horizon 3: -5 >= rate change
loc_pop100k <- location_meta$pop100k[location_meta$location == locs[6]]
target_data$value[target_data$location == locs[6]] <- c(
0,
10000,
10000 - max(10, ceiling(1.7 * loc_pop100k)),
10000 - max(10, ceiling(3 * loc_pop100k)),
10000 - max(10, ceiling(4 * loc_pop100k)),
10000 - max(10, ceiling(5 * loc_pop100k))
)

expected_categories <- dplyr::bind_rows(
expected_categories,
target_data |>
dplyr::filter(
.data[["location"]] == locs[6],
.data[["date"]] >= "2024-11-23"
) |>
dplyr::mutate(
horizon = as.integer((.data[["date"]] - as.Date("2024-11-23")) / 7),
output_type_id = "large_decrease"
)
)

bin_endpoints <- get_flusight_bin_endpoints(
target_ts = target_data |>
dplyr::filter(
.data[["date"]] < "2024-11-23"
),
location_meta = location_meta,
season = "2024/25"
)

actual_categories <- bin_endpoints |>
dplyr::mutate(
reference_date = as.Date("2024-11-23"),
target_end_date = as.Date("2024-11-23") + 7 * .data[["horizon"]]
) |>
dplyr::left_join(
target_data,
by = c("location", "target_end_date" = "date")
) |>
dplyr::filter(
.data[["lower"]] < .data[["value"]],
.data[["value"]] <= .data[["upper"]]
)

mismatched_categorizations <- expected_categories |>
dplyr::left_join(
actual_categories,
by = c("location", "date" = "reference_date", "horizon")
) |>
dplyr::filter(output_type_id.x != output_type_id.y)

# expect no mismatches!
expect_equal(nrow(mismatched_categorizations), 0L)
})
Loading