From 645e92ede336561df2026b9ea793035a920e4833 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 20 Sep 2024 11:41:05 +0100 Subject: [PATCH 1/4] `repeats`: replace custom checker with standard one --- R/clustering.R | 2 +- R/vfold.R | 10 ++-------- tests/testthat/_snaps/clustering.md | 4 ++-- tests/testthat/_snaps/vfold.md | 4 ++-- 4 files changed, 7 insertions(+), 13 deletions(-) diff --git a/R/clustering.R b/R/clustering.R index 5ed07f60..dbfa2715 100644 --- a/R/clustering.R +++ b/R/clustering.R @@ -57,7 +57,7 @@ clustering_cv <- function(data, distance_function = "dist", cluster_function = c("kmeans", "hclust"), ...) { - check_repeats(repeats) + check_number_whole(repeats, min = 1) if (!rlang::is_function(cluster_function)) { cluster_function <- rlang::arg_match(cluster_function) diff --git a/R/vfold.R b/R/vfold.R index faded51e..9a23a712 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -72,7 +72,7 @@ vfold_cv <- function(data, v = 10, repeats = 1, } check_strata(strata, data) - check_repeats(repeats) + check_number_whole(repeats, min = 1) if (repeats == 1) { split_objs <- vfold_splits( @@ -213,7 +213,7 @@ vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1, pr #' @export group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance = c("groups", "observations"), ..., strata = NULL, pool = 0.1) { check_dots_empty() - check_repeats(repeats) + check_number_whole(repeats, min = 1) group <- validate_group({{ group }}, data) balance <- rlang::arg_match(balance) @@ -368,9 +368,3 @@ check_grouped_strata <- function(group, strata, pool, data) { strata } - -check_repeats <- function(repeats, call = rlang::caller_env()) { - if (!is.numeric(repeats) || length(repeats) != 1 || repeats < 1) { - cli_abort("{.arg repeats} must be a single positive integer.", call = call) - } -} diff --git a/tests/testthat/_snaps/clustering.md b/tests/testthat/_snaps/clustering.md index 048a87c1..ad66d49f 100644 --- a/tests/testthat/_snaps/clustering.md +++ 
b/tests/testthat/_snaps/clustering.md @@ -44,7 +44,7 @@ clustering_cv(Orange, repeats = 0) Condition Error in `clustering_cv()`: - ! `repeats` must be a single positive integer. + ! `repeats` must be a whole number larger than or equal to 1, not the number 0. --- @@ -52,7 +52,7 @@ clustering_cv(Orange, repeats = NULL) Condition Error in `clustering_cv()`: - ! `repeats` must be a single positive integer. + ! `repeats` must be a whole number, not `NULL`. --- diff --git a/tests/testthat/_snaps/vfold.md b/tests/testthat/_snaps/vfold.md index add53300..a2df8540 100644 --- a/tests/testthat/_snaps/vfold.md +++ b/tests/testthat/_snaps/vfold.md @@ -87,7 +87,7 @@ vfold_cv(Orange, repeats = 0) Condition Error in `vfold_cv()`: - ! `repeats` must be a single positive integer. + ! `repeats` must be a whole number larger than or equal to 1, not the number 0. --- @@ -95,7 +95,7 @@ vfold_cv(Orange, repeats = NULL) Condition Error in `vfold_cv()`: - ! `repeats` must be a single positive integer. + ! `repeats` must be a whole number, not `NULL`. 
--- From b5f0098be5c76ec79ad7d0d690687b1e344df807 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 20 Sep 2024 12:29:17 +0100 Subject: [PATCH 2/4] improve and test `check_grouped_strata()` --- R/vfold.R | 7 +++++-- tests/testthat/_snaps/vfold.md | 8 ++++++++ tests/testthat/test-vfold.R | 29 +++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/R/vfold.R b/R/vfold.R index 9a23a712..fc6a7db6 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -347,7 +347,7 @@ check_v <- function(v, max_v, rows = "rows", prevent_loo = TRUE, call = rlang::c } } -check_grouped_strata <- function(group, strata, pool, data) { +check_grouped_strata <- function(group, strata, pool, data, call = caller_env()) { strata <- tidyselect::vars_select(names(data), !!enquo(strata)) @@ -363,7 +363,10 @@ check_grouped_strata <- function(group, strata, pool, data) { if (nrow(vctrs::vec_unique(grouped_table)) != nrow(vctrs::vec_unique(grouped_table["group"]))) { - cli_abort("{.arg strata} must be constant across all members of each {.arg group}.") + cli_abort( + "{.field strata} must be constant across all members of each {.field group}.", + call = call + ) } strata diff --git a/tests/testthat/_snaps/vfold.md b/tests/testthat/_snaps/vfold.md index a2df8540..2e6dd6ed 100644 --- a/tests/testthat/_snaps/vfold.md +++ b/tests/testthat/_snaps/vfold.md @@ -286,6 +286,14 @@ 10 Resample10 # i 20 more rows +# grouping fails for strata not constant across group members + + Code + group_vfold_cv(sample_data, group, v = 5, strata = outcome) + Condition + Error in `group_vfold_cv()`: + ! strata must be constant across all members of each group. 
+ # grouping -- printing Code diff --git a/tests/testthat/test-vfold.R b/tests/testthat/test-vfold.R index f6e7bd1f..83a5899b 100644 --- a/tests/testthat/test-vfold.R +++ b/tests/testthat/test-vfold.R @@ -403,6 +403,35 @@ test_that("grouping -- strata", { ) }) +test_that("grouping fails for strata not constant across group members", { + set.seed(11) + + n_common_class <- 70 + n_rare_class <- 30 + + group_table <- tibble( + group = 1:100, + outcome = sample(c(rep(0, n_common_class), rep(1, n_rare_class))) + ) + observation_table <- tibble( + group = sample(1:100, 1e5, replace = TRUE), + observation = 1:1e5 + ) + sample_data <- dplyr::full_join( + group_table, + observation_table, + by = "group", + multiple = "all" + ) + + # violate requirement + sample_data$outcome[1] <- ifelse(sample_data$outcome[1], 0, 1) + + expect_snapshot(error = TRUE, { + group_vfold_cv(sample_data, group, v = 5, strata = outcome) + }) +}) + test_that("grouping -- repeated", { set.seed(11) rs2 <- group_vfold_cv(dat1, c, v = 3, repeats = 4) From 1cb0efd74f80100e49282bcd948be441e80e3acc Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 20 Sep 2024 18:18:39 +0100 Subject: [PATCH 3/4] update `check_v()` --- R/vfold.R | 13 +++++++------ tests/testthat/_snaps/clustering.md | 4 ++-- tests/testthat/_snaps/vfold.md | 8 ++++---- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/R/vfold.R b/R/vfold.R index fc6a7db6..0aa9f0c9 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -331,18 +331,19 @@ add_vfolds <- function(x, v) { } check_v <- function(v, max_v, rows = "rows", prevent_loo = TRUE, call = rlang::caller_env()) { - if (!is.numeric(v) || length(v) != 1 || v < 2) { - cli_abort("{.arg v} must be a single positive integer greater than 1.", call = call) - } else if (v > max_v) { + check_number_whole(v, min = 2, call = call) + + if (v > max_v) { cli_abort( "The number of {rows} is less than {.arg v} = {.val {v}}.", call = call ) - } else if (prevent_loo && isTRUE(v == max_v)) { + } + if 
(prevent_loo && isTRUE(v == max_v)) { cli_abort(c( "Leave-one-out cross-validation is not supported by this function.", - "x" = "You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.", - "i" = "Use `loo_cv()` in this case." + "x" = "You set {.arg v} to {.code nrow(data)}, which would result in a leave-one-out cross-validation.", + "i" = "Use {.fn loo_cv} in this case." ), call = call) } } diff --git a/tests/testthat/_snaps/clustering.md b/tests/testthat/_snaps/clustering.md index ad66d49f..68c20d93 100644 --- a/tests/testthat/_snaps/clustering.md +++ b/tests/testthat/_snaps/clustering.md @@ -12,7 +12,7 @@ clustering_cv(iris, Sepal.Length, v = -500) Condition Error in `clustering_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number -500. --- @@ -36,7 +36,7 @@ clustering_cv(Orange, v = 1, vars = "Tree") Condition Error in `clustering_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number 1. --- diff --git a/tests/testthat/_snaps/vfold.md b/tests/testthat/_snaps/vfold.md index 2e6dd6ed..ce340309 100644 --- a/tests/testthat/_snaps/vfold.md +++ b/tests/testthat/_snaps/vfold.md @@ -47,7 +47,7 @@ vfold_cv(iris, v = -500) Condition Error in `vfold_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number -500. --- @@ -55,7 +55,7 @@ vfold_cv(iris, v = 1) Condition Error in `vfold_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number 1. --- @@ -63,7 +63,7 @@ vfold_cv(iris, v = NULL) Condition Error in `vfold_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number, not `NULL`. 
--- @@ -191,7 +191,7 @@ group_vfold_cv(Orange, v = 1, group = "Tree") Condition Error in `group_vfold_cv()`: - ! `v` must be a single positive integer greater than 1. + ! `v` must be a whole number larger than or equal to 2, not the number 1. # grouping -- other balance methods From a74d3a4fa2ea5cfd321d4535a39e022d0a450bef Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 20 Sep 2024 18:21:19 +0100 Subject: [PATCH 4/4] split up test for more descriptive title --- tests/testthat/_snaps/vfold.md | 22 +++++++++++----------- tests/testthat/test-vfold.R | 11 +++++++---- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/testthat/_snaps/vfold.md b/tests/testthat/_snaps/vfold.md index ce340309..63664707 100644 --- a/tests/testthat/_snaps/vfold.md +++ b/tests/testthat/_snaps/vfold.md @@ -41,7 +41,7 @@ ! strata cannot be a <Surv> object. i Use the time or event variable directly. -# bad args +# v arg is checked Code vfold_cv(iris, v = -500) @@ -75,6 +75,16 @@ --- + Code + vfold_cv(mtcars, v = nrow(mtcars)) + Condition + Error in `vfold_cv()`: + ! Leave-one-out cross-validation is not supported by this function. + x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation. + i Use `loo_cv()` in this case. + +# repeats arg is checked + Code vfold_cv(iris, v = 150, repeats = 2) Condition @@ -97,16 +107,6 @@ Error in `vfold_cv()`: ! `repeats` must be a whole number, not `NULL`. ---- - - Code - vfold_cv(mtcars, v = nrow(mtcars)) - Condition - Error in `vfold_cv()`: - ! Leave-one-out cross-validation is not supported by this function. - x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation. - i Use `loo_cv()` in this case. 
- # printing Code diff --git a/tests/testthat/test-vfold.R b/tests/testthat/test-vfold.R index 83a5899b..56602300 100644 --- a/tests/testthat/test-vfold.R +++ b/tests/testthat/test-vfold.R @@ -104,7 +104,7 @@ test_that("strata arg is checked", { }) }) -test_that("bad args", { +test_that("v arg is checked", { expect_snapshot(error = TRUE, { vfold_cv(iris, v = -500) }) @@ -117,6 +117,12 @@ test_that("bad args", { expect_snapshot(error = TRUE, { vfold_cv(iris, v = 500) }) + expect_snapshot(error = TRUE, { + vfold_cv(mtcars, v = nrow(mtcars)) + }) +}) + +test_that("repeats arg is checked", { expect_snapshot(error = TRUE, { vfold_cv(iris, v = 150, repeats = 2) }) @@ -126,9 +132,6 @@ test_that("bad args", { expect_snapshot(error = TRUE, { vfold_cv(Orange, repeats = NULL) }) - expect_snapshot(error = TRUE, { - vfold_cv(mtcars, v = nrow(mtcars)) - }) }) test_that("printing", {