From cb786bf73be90833fb7d57db0f6633f7327ea6fa Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 28 Mar 2023 13:09:03 -0400 Subject: [PATCH] plot changes for #92 --- R/cal-estimate-beta.R | 20 +++++++++++++++++ R/cal-plot-breaks.R | 3 +++ R/cal-plot-logistic.R | 3 +++ R/cal-plot-regression.R | 3 +++ R/cal-plot-windowed.R | 2 ++ tests/testthat/_snaps/cal-plot.md | 36 +++++++++++++++++++++++++++++++ tests/testthat/test-cal-plot.R | 36 +++++++++++++++++++++++++++++++ 7 files changed, 103 insertions(+) diff --git a/R/cal-estimate-beta.R b/R/cal-estimate-beta.R index 9594750a..b884d94f 100644 --- a/R/cal-estimate-beta.R +++ b/R/cal-estimate-beta.R @@ -225,3 +225,23 @@ cal_beta_impl_single <- function(.data, beta_model } + + +check_cal_groups <- function(group, .data, call = rlang::env_parent()) { + if (!any(names(.data) == ".config")) { + return(invisible(NULL)) + } + num_configs <- length(unique(.data$.config)) + if (num_configs == 1) { + return(invisible(NULL)) + } + has_no_groups <- isTRUE(is.null(group)) + if (has_no_groups) { + msg <- paste("The data have several values of '.config' but no 'groups'", + "argument was passed. This will inappropriately pool the", + "data.") + rlang::abort(msg, call = call) + } + invisible(NULL) +} + diff --git a/R/cal-plot-breaks.R b/R/cal-plot-breaks.R index d4750847..638e3003 100644 --- a/R/cal-plot-breaks.R +++ b/R/cal-plot-breaks.R @@ -112,6 +112,9 @@ cal_plot_breaks.data.frame <- function(.data, include_points = TRUE, event_level = c("auto", "first", "second"), ...) { + + check_cal_groups(group, .data) + cal_plot_breaks_impl( .data = .data, truth = {{ truth }}, diff --git a/R/cal-plot-logistic.R b/R/cal-plot-logistic.R index bebccb6d..53793284 100644 --- a/R/cal-plot-logistic.R +++ b/R/cal-plot-logistic.R @@ -61,6 +61,9 @@ cal_plot_logistic.data.frame <- function(.data, include_ribbon = TRUE, event_level = c("auto", "first", "second"), ...) { + + check_cal_groups(group, .data) + cal_plot_logistic_impl( .data = .data, truth = {{ truth }}, diff --git a/R/cal-plot-regression.R b/R/cal-plot-regression.R index cee3b60d..8d538cf4 100644 --- a/R/cal-plot-regression.R +++ b/R/cal-plot-regression.R @@ -40,6 +40,9 @@ cal_plot_regression_impl <- function(.data, group = NULL, smooth = TRUE, ...) { + + check_cal_groups(group, .data) + truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) diff --git a/R/cal-plot-windowed.R b/R/cal-plot-windowed.R index 6ca795f9..020a19f4 100644 --- a/R/cal-plot-windowed.R +++ b/R/cal-plot-windowed.R @@ -68,6 +68,8 @@ cal_plot_windowed.data.frame <- function(.data, include_points = TRUE, event_level = c("auto", "first", "second"), ...) { + check_cal_groups(group, .data) + cal_plot_windowed_impl( .data = .data, truth = {{ truth }}, diff --git a/tests/testthat/_snaps/cal-plot.md b/tests/testthat/_snaps/cal-plot.md index bb8c1b16..87f8c04a 100644 --- a/tests/testthat/_snaps/cal-plot.md +++ b/tests/testthat/_snaps/cal-plot.md @@ -1,3 +1,39 @@ +# Binary breaks functions work + + Code + testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_breaks(class, + estimate = .pred_class_1) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + +# Multi-class breaks functions work + + Code + testthat_cal_multiclass() %>% tune::collect_predictions() %>% cal_plot_breaks( + class, estimate = .pred_class_1) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + +# Binary logistic functions work + + Code + testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_logistic( + class, estimate = .pred_class_1) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + +# Binary windowed functions work + + Code + testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_windowed( + class, estimate = .pred_class_1) + Condition + Error: + ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + # Event level handling works Invalid event_level entry: invalid. Valid entries are 'first', 'second', or 'auto' diff --git a/tests/testthat/test-cal-plot.R b/tests/testthat/test-cal-plot.R index b4f5b702..b6821ab9 100644 --- a/tests/testthat/test-cal-plot.R +++ b/tests/testthat/test-cal-plot.R @@ -22,6 +22,14 @@ test_that("Binary breaks functions work", { cal_plot_breaks(testthat_cal_binary()), "ggplot" ) + + expect_snapshot( + error = TRUE, + testthat_cal_binary() %>% + tune::collect_predictions() %>% + cal_plot_breaks(class, estimate = .pred_class_1) + ) + }) test_that("Multi-class breaks functions work", { @@ -53,6 +61,13 @@ test_that("Multi-class breaks functions work", { expect_error( cal_plot_breaks(species_probs, Species, event_level = "second") ) + + expect_snapshot( + error = TRUE, + testthat_cal_multiclass() %>% + tune::collect_predictions() %>% + cal_plot_breaks(class, estimate = .pred_class_1) + ) }) test_that("Binary logistic functions work", { @@ -125,6 +140,13 @@ test_that("Binary logistic functions work", { which(x25$prob == max(x25$prob)), nrow(x25) ) + + expect_snapshot( + error = TRUE, + testthat_cal_binary() %>% + tune::collect_predictions() %>% + cal_plot_logistic(class, estimate = .pred_class_1) + ) }) test_that("Binary windowed functions work", { @@ -192,6 +214,13 @@ test_that("Binary windowed functions work", { x33 <- cal_plot_windowed(testthat_cal_binary()) expect_s3_class(x33, "ggplot") + + expect_snapshot( + error = TRUE, + testthat_cal_binary() %>% + tune::collect_predictions() %>% + cal_plot_windowed(class, estimate = .pred_class_1) + ) }) test_that("Event level handling works", { @@ -330,4 +359,11 @@ test_that("regression functions work", { "rs-scat-group-opts", print(cal_plot_regression(obj), alpha = 1/5, smooth = FALSE) ) + + expect_snapshot( + error = TRUE, + obj %>% + tune::collect_predictions() %>% + cal_plot_windowed(class, estimate = .pred_class_1) + ) })