diff --git a/R/cal-estimate-beta.R b/R/cal-estimate-beta.R index 9594750a..90d21289 100644 --- a/R/cal-estimate-beta.R +++ b/R/cal-estimate-beta.R @@ -225,3 +225,24 @@ cal_beta_impl_single <- function(.data, beta_model } + + +check_cal_groups <- function(group, .data, call = rlang::env_parent()) { + group <- enquo(group) + if (!any(names(.data) == ".config")) { + return(invisible(NULL)) + } + num_configs <- length(unique(.data$.config)) + if (num_configs == 1) { + return(invisible(NULL)) + } + has_no_groups <- rlang::quo_is_null(group) + if (has_no_groups) { + msg <- paste("The data have several values of '.config' but no 'groups'", + "argument was passed. This will inappropriately pool the", + "data.") + rlang::abort(msg, call = call) + } + invisible(NULL) +} + diff --git a/R/cal-plot-breaks.R b/R/cal-plot-breaks.R index d4750847..cd641aee 100644 --- a/R/cal-plot-breaks.R +++ b/R/cal-plot-breaks.R @@ -112,6 +112,9 @@ cal_plot_breaks.data.frame <- function(.data, include_points = TRUE, event_level = c("auto", "first", "second"), ...) { + + check_cal_groups({{ group }}, .data) + cal_plot_breaks_impl( .data = .data, truth = {{ truth }}, diff --git a/R/cal-plot-logistic.R b/R/cal-plot-logistic.R index bebccb6d..0e188656 100644 --- a/R/cal-plot-logistic.R +++ b/R/cal-plot-logistic.R @@ -61,6 +61,9 @@ cal_plot_logistic.data.frame <- function(.data, include_ribbon = TRUE, event_level = c("auto", "first", "second"), ...) { + + check_cal_groups({{ group }}, .data) + cal_plot_logistic_impl( .data = .data, truth = {{ truth }}, diff --git a/R/cal-plot-regression.R b/R/cal-plot-regression.R index cee3b60d..8c72a683 100644 --- a/R/cal-plot-regression.R +++ b/R/cal-plot-regression.R @@ -40,6 +40,9 @@ cal_plot_regression_impl <- function(.data, group = NULL, smooth = TRUE, ...) { + + check_cal_groups({{ group }}, .data) + truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) diff --git a/R/cal-plot-windowed.R b/R/cal-plot-windowed.R index 6ca795f9..26dc1bec 100644 --- a/R/cal-plot-windowed.R +++ b/R/cal-plot-windowed.R @@ -68,6 +68,8 @@ cal_plot_windowed.data.frame <- function(.data, include_points = TRUE, event_level = c("auto", "first", "second"), ...) { + check_cal_groups({{ group }}, .data) + cal_plot_windowed_impl( .data = .data, truth = {{ truth }}, diff --git a/tests/testthat/_snaps/cal-plot.md b/tests/testthat/_snaps/cal-plot.md index bb8c1b16..64eda72c 100644 --- a/tests/testthat/_snaps/cal-plot.md +++ b/tests/testthat/_snaps/cal-plot.md @@ -1,4 +1,48 @@ +# Binary breaks functions work + + Code + testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_breaks(class, + estimate = .pred_class_1) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + +# Multi-class breaks functions work + + Code + testthat_cal_multiclass() %>% tune::collect_predictions() %>% cal_plot_breaks( + class, estimate = .pred_class_1) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + +# Binary logistic functions work + + Code + testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_logistic( + class, estimate = .pred_class_1) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + +# Binary windowed functions work + + Code + testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_windowed( + class, estimate = .pred_class_1) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + # Event level handling works Invalid event_level entry: invalid. Valid entries are 'first', 'second', or 'auto' +# regression functions work + + Code + obj %>% tune::collect_predictions() %>% cal_plot_windowed(outcome, estimate = .pred) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + diff --git a/tests/testthat/_snaps/cal-validate.md b/tests/testthat/_snaps/cal-validate.md index c1f5c3da..9ef60a34 100644 --- a/tests/testthat/_snaps/cal-validate.md +++ b/tests/testthat/_snaps/cal-validate.md @@ -5,19 +5,18 @@ Output # 10-fold cross-validation # A tibble: 10 x 6 - splits id calibration validation stats_after stats_~1 - - 1 Fold01 - 2 Fold02 - 3 Fold03 - 4 Fold04 - 5 Fold05 - 6 Fold06 - 7 Fold07 - 8 Fold08 - 9 Fold09 - 10 Fold10 - # ... with abbreviated variable name 1: stats_before + splits id calibration validation stats_after stats_before + + 1 Fold01 + 2 Fold02 + 3 Fold03 + 4 Fold04 + 5 Fold05 + 6 Fold06 + 7 Fold07 + 8 Fold08 + 9 Fold09 + 10 Fold10 # Logistic validation with `fit_resamples` diff --git a/tests/testthat/test-cal-plot.R b/tests/testthat/test-cal-plot.R index b4f5b702..24b18728 100644 --- a/tests/testthat/test-cal-plot.R +++ b/tests/testthat/test-cal-plot.R @@ -22,6 +22,14 @@ test_that("Binary breaks functions work", { cal_plot_breaks(testthat_cal_binary()), "ggplot" ) + + expect_snapshot( + error = TRUE, + testthat_cal_binary() %>% + tune::collect_predictions() %>% + cal_plot_breaks(class, estimate = .pred_class_1) + ) + }) test_that("Multi-class breaks functions work", { @@ -53,6 +61,13 @@ test_that("Multi-class breaks functions work", { expect_error( cal_plot_breaks(species_probs, Species, event_level = "second") ) + + expect_snapshot( + error = TRUE, + testthat_cal_multiclass() %>% + tune::collect_predictions() %>% + cal_plot_breaks(class, estimate = .pred_class_1) + ) }) test_that("Binary logistic functions work", { @@ -125,6 +140,13 @@ test_that("Binary logistic functions work", { which(x25$prob == max(x25$prob)), nrow(x25) ) + + expect_snapshot( + error = TRUE, + testthat_cal_binary() %>% + tune::collect_predictions() %>% + cal_plot_logistic(class, estimate = .pred_class_1) + ) }) test_that("Binary windowed functions work", { @@ -192,6 +214,13 @@ test_that("Binary windowed functions work", { x33 <- cal_plot_windowed(testthat_cal_binary()) expect_s3_class(x33, "ggplot") + + expect_snapshot( + error = TRUE, + testthat_cal_binary() %>% + tune::collect_predictions() %>% + cal_plot_windowed(class, estimate = .pred_class_1) + ) }) test_that("Event level handling works", { @@ -330,4 +359,11 @@ test_that("regression functions work", { "rs-scat-group-opts", print(cal_plot_regression(obj), alpha = 1/5, smooth = FALSE) ) + + expect_snapshot( + error = TRUE, + obj %>% + tune::collect_predictions() %>% + cal_plot_windowed(outcome, estimate = .pred) + ) })