Skip to content

Commit

Permalink
Check for no groups and multiple configurations (#96)
Browse files Browse the repository at this point in the history
* plot changes for #92

* add groups to estimate function

* Revert "add groups to estimate function"

This reverts commit 558fa43.

* tuning the quosure code and update tests

* update with new pillar
  • Loading branch information
topepo authored Apr 28, 2023
1 parent 4600920 commit dda8c10
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 0 deletions.
21 changes: 21 additions & 0 deletions R/cal-estimate-beta.R
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,24 @@ cal_beta_impl_single <- function(.data,

beta_model
}


check_cal_groups <- function(group, .data, call = rlang::env_parent()) {
group <- enquo(group)
if (!any(names(.data) == ".config")) {
return(invisible(NULL))
}
num_configs <- length(unique(.data$.config))
if (num_configs == 1) {
return(invisible(NULL))
}
has_no_groups <- rlang::quo_is_null(group)
if (has_no_groups) {
msg <- paste("The data have several values of '.config' but no 'groups'",
"argument was passed. This will inappropriately pool the",
"data.")
rlang::abort(msg, call = call)
}
invisible(NULL)
}

3 changes: 3 additions & 0 deletions R/cal-plot-breaks.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ cal_plot_breaks.data.frame <- function(.data,
include_points = TRUE,
event_level = c("auto", "first", "second"),
...) {

check_cal_groups({{ group }}, .data)

cal_plot_breaks_impl(
.data = .data,
truth = {{ truth }},
Expand Down
3 changes: 3 additions & 0 deletions R/cal-plot-logistic.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ cal_plot_logistic.data.frame <- function(.data,
include_ribbon = TRUE,
event_level = c("auto", "first", "second"),
...) {

check_cal_groups({{ group }}, .data)

cal_plot_logistic_impl(
.data = .data,
truth = {{ truth }},
Expand Down
3 changes: 3 additions & 0 deletions R/cal-plot-regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ cal_plot_regression_impl <- function(.data,
group = NULL,
smooth = TRUE,
...) {

check_cal_groups({{ group }}, .data)

truth <- enquo(truth)
estimate <- enquo(estimate)
group <- enquo(group)
Expand Down
2 changes: 2 additions & 0 deletions R/cal-plot-windowed.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ cal_plot_windowed.data.frame <- function(.data,
include_points = TRUE,
event_level = c("auto", "first", "second"),
...) {
check_cal_groups({{ group }}, .data)

cal_plot_windowed_impl(
.data = .data,
truth = {{ truth }},
Expand Down
44 changes: 44 additions & 0 deletions tests/testthat/_snaps/cal-plot.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,48 @@
# Binary breaks functions work

Code
testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_breaks(class,
estimate = .pred_class_1)
Condition
Error:
! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data.

# Multi-class breaks functions work

Code
testthat_cal_multiclass() %>% tune::collect_predictions() %>% cal_plot_breaks(
class, estimate = .pred_class_1)
Condition
Error:
! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data.

# Binary logistic functions work

Code
testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_logistic(
class, estimate = .pred_class_1)
Condition
Error:
! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data.

# Binary windowed functions work

Code
testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_windowed(
class, estimate = .pred_class_1)
Condition
Error:
! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data.

# Event level handling works

Invalid event_level entry: invalid. Valid entries are 'first', 'second', or 'auto'

# regression functions work

Code
obj %>% tune::collect_predictions() %>% cal_plot_windowed(outcome, estimate = .pred)
Condition
Error:
! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data.

36 changes: 36 additions & 0 deletions tests/testthat/test-cal-plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ test_that("Binary breaks functions work", {
cal_plot_breaks(testthat_cal_binary()),
"ggplot"
)

expect_snapshot(
error = TRUE,
testthat_cal_binary() %>%
tune::collect_predictions() %>%
cal_plot_breaks(class, estimate = .pred_class_1)
)

})

test_that("Multi-class breaks functions work", {
Expand Down Expand Up @@ -53,6 +61,13 @@ test_that("Multi-class breaks functions work", {
expect_error(
cal_plot_breaks(species_probs, Species, event_level = "second")
)

expect_snapshot(
error = TRUE,
testthat_cal_multiclass() %>%
tune::collect_predictions() %>%
cal_plot_breaks(class, estimate = .pred_class_1)
)
})

test_that("Binary logistic functions work", {
Expand Down Expand Up @@ -125,6 +140,13 @@ test_that("Binary logistic functions work", {
which(x25$prob == max(x25$prob)),
nrow(x25)
)

expect_snapshot(
error = TRUE,
testthat_cal_binary() %>%
tune::collect_predictions() %>%
cal_plot_logistic(class, estimate = .pred_class_1)
)
})

test_that("Binary windowed functions work", {
Expand Down Expand Up @@ -192,6 +214,13 @@ test_that("Binary windowed functions work", {
x33 <- cal_plot_windowed(testthat_cal_binary())

expect_s3_class(x33, "ggplot")

expect_snapshot(
error = TRUE,
testthat_cal_binary() %>%
tune::collect_predictions() %>%
cal_plot_windowed(class, estimate = .pred_class_1)
)
})

test_that("Event level handling works", {
Expand Down Expand Up @@ -330,4 +359,11 @@ test_that("regression functions work", {
"rs-scat-group-opts",
print(cal_plot_regression(obj), alpha = 1/5, smooth = FALSE)
)

expect_snapshot(
error = TRUE,
obj %>%
tune::collect_predictions() %>%
cal_plot_windowed(outcome, estimate = .pred)
)
})

0 comments on commit dda8c10

Please sign in to comment.