Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update input checks for vfold_cv.R #548

Merged
merged 4 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/clustering.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ clustering_cv <- function(data,
distance_function = "dist",
cluster_function = c("kmeans", "hclust"),
...) {
check_repeats(repeats)
check_number_whole(repeats, min = 1)

if (!rlang::is_function(cluster_function)) {
cluster_function <- rlang::arg_match(cluster_function)
Expand Down
30 changes: 14 additions & 16 deletions R/vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ vfold_cv <- function(data, v = 10, repeats = 1,
}

check_strata(strata, data)
check_repeats(repeats)
check_number_whole(repeats, min = 1)

if (repeats == 1) {
split_objs <- vfold_splits(
Expand Down Expand Up @@ -213,7 +213,7 @@ vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1, pr
#' @export
group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance = c("groups", "observations"), ..., strata = NULL, pool = 0.1) {
check_dots_empty()
check_repeats(repeats)
check_number_whole(repeats, min = 1)
group <- validate_group({{ group }}, data)
balance <- rlang::arg_match(balance)

Expand Down Expand Up @@ -331,23 +331,24 @@ add_vfolds <- function(x, v) {
}

check_v <- function(v, max_v, rows = "rows", prevent_loo = TRUE, call = rlang::caller_env()) {
if (!is.numeric(v) || length(v) != 1 || v < 2) {
cli_abort("{.arg v} must be a single positive integer greater than 1.", call = call)
} else if (v > max_v) {
check_number_whole(v, min = 2, call = call)

if (v > max_v) {
cli_abort(
"The number of {rows} is less than {.arg v} = {.val {v}}.",
call = call
)
} else if (prevent_loo && isTRUE(v == max_v)) {
}
if (prevent_loo && isTRUE(v == max_v)) {
cli_abort(c(
"Leave-one-out cross-validation is not supported by this function.",
"x" = "You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.",
"i" = "Use `loo_cv()` in this case."
"x" = "You set {.arg v} to {.code nrow(data)}, which would result in a leave-one-out cross-validation.",
"i" = "Use {.fn loo_cv} in this case."
), call = call)
}
}

check_grouped_strata <- function(group, strata, pool, data) {
check_grouped_strata <- function(group, strata, pool, data, call = caller_env()) {

strata <- tidyselect::vars_select(names(data), !!enquo(strata))

Expand All @@ -363,14 +364,11 @@ check_grouped_strata <- function(group, strata, pool, data) {

if (nrow(vctrs::vec_unique(grouped_table)) !=
nrow(vctrs::vec_unique(grouped_table["group"]))) {
cli_abort("{.arg strata} must be constant across all members of each {.arg group}.")
cli_abort(
"{.field strata} must be constant across all members of each {.field group}.",
hfrick marked this conversation as resolved.
Show resolved Hide resolved
call = call
)
}

strata
}

check_repeats <- function(repeats, call = rlang::caller_env()) {
if (!is.numeric(repeats) || length(repeats) != 1 || repeats < 1) {
cli_abort("{.arg repeats} must be a single positive integer.", call = call)
}
}
hfrick marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 4 additions & 4 deletions tests/testthat/_snaps/clustering.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
clustering_cv(iris, Sepal.Length, v = -500)
Condition
Error in `clustering_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number -500.

---

Expand All @@ -36,23 +36,23 @@
clustering_cv(Orange, v = 1, vars = "Tree")
Condition
Error in `clustering_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number 1.

---

Code
clustering_cv(Orange, repeats = 0)
Condition
Error in `clustering_cv()`:
! `repeats` must be a single positive integer.
! `repeats` must be a whole number larger than or equal to 1, not the number 0.

---

Code
clustering_cv(Orange, repeats = NULL)
Condition
Error in `clustering_cv()`:
! `repeats` must be a single positive integer.
! `repeats` must be a whole number, not `NULL`.

---

Expand Down
40 changes: 24 additions & 16 deletions tests/testthat/_snaps/vfold.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,29 +41,29 @@
! strata cannot be a <Surv> object.
i Use the time or event variable directly.

# bad args
# v arg is checked

Code
vfold_cv(iris, v = -500)
Condition
Error in `vfold_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number -500.

---

Code
vfold_cv(iris, v = 1)
Condition
Error in `vfold_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number 1.

---

Code
vfold_cv(iris, v = NULL)
Condition
Error in `vfold_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number, not `NULL`.

---

Expand All @@ -76,36 +76,36 @@
---

Code
vfold_cv(iris, v = 150, repeats = 2)
vfold_cv(mtcars, v = nrow(mtcars))
Condition
Error in `vfold_cv()`:
! Repeated resampling when `v` is 150 would create identical resamples.
! Leave-one-out cross-validation is not supported by this function.
x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.
i Use `loo_cv()` in this case.

---
# repeats arg is checked

Code
vfold_cv(Orange, repeats = 0)
vfold_cv(iris, v = 150, repeats = 2)
Condition
Error in `vfold_cv()`:
! `repeats` must be a single positive integer.
! Repeated resampling when `v` is 150 would create identical resamples.

---

Code
vfold_cv(Orange, repeats = NULL)
vfold_cv(Orange, repeats = 0)
Condition
Error in `vfold_cv()`:
! `repeats` must be a single positive integer.
! `repeats` must be a whole number larger than or equal to 1, not the number 0.

---

Code
vfold_cv(mtcars, v = nrow(mtcars))
vfold_cv(Orange, repeats = NULL)
Condition
Error in `vfold_cv()`:
! Leave-one-out cross-validation is not supported by this function.
x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.
i Use `loo_cv()` in this case.
! `repeats` must be a whole number, not `NULL`.

# printing

Expand Down Expand Up @@ -191,7 +191,7 @@
group_vfold_cv(Orange, v = 1, group = "Tree")
Condition
Error in `group_vfold_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number 1.

# grouping -- other balance methods

Expand Down Expand Up @@ -286,6 +286,14 @@
10 <split [96051/3949]> Resample10
# i 20 more rows

# grouping fails for strata not constant across group members

Code
group_vfold_cv(sample_data, group, v = 5, strata = outcome)
Condition
Error in `group_vfold_cv()`:
! strata must be constant across all members of each group.

# grouping -- printing

Code
Expand Down
40 changes: 36 additions & 4 deletions tests/testthat/test-vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ test_that("strata arg is checked", {
})
})

test_that("bad args", {
test_that("v arg is checked", {
expect_snapshot(error = TRUE, {
vfold_cv(iris, v = -500)
})
Expand All @@ -117,6 +117,12 @@ test_that("bad args", {
expect_snapshot(error = TRUE, {
vfold_cv(iris, v = 500)
})
expect_snapshot(error = TRUE, {
vfold_cv(mtcars, v = nrow(mtcars))
})
})

test_that("repeats arg is checked", {
expect_snapshot(error = TRUE, {
vfold_cv(iris, v = 150, repeats = 2)
})
Expand All @@ -126,9 +132,6 @@ test_that("bad args", {
expect_snapshot(error = TRUE, {
vfold_cv(Orange, repeats = NULL)
})
expect_snapshot(error = TRUE, {
vfold_cv(mtcars, v = nrow(mtcars))
})
})

test_that("printing", {
Expand Down Expand Up @@ -403,6 +406,35 @@ test_that("grouping -- strata", {
)
})

test_that("grouping fails for strata not constant across group members", {
set.seed(11)

n_common_class <- 70
n_rare_class <- 30

group_table <- tibble(
group = 1:100,
outcome = sample(c(rep(0, n_common_class), rep(1, n_rare_class)))
)
observation_table <- tibble(
group = sample(1:100, 1e5, replace = TRUE),
observation = 1:1e5
)
sample_data <- dplyr::full_join(
group_table,
observation_table,
by = "group",
multiple = "all"
)

# violate requirement
sample_data$outcome[1] <- ifelse(sample_data$outcome[1], 0, 1)

expect_snapshot(error = TRUE, {
group_vfold_cv(sample_data, group, v = 5, strata = outcome)
})
})

test_that("grouping -- repeated", {
set.seed(11)
rs2 <- group_vfold_cv(dat1, c, v = 3, repeats = 4)
Expand Down
Loading