diff --git a/R/validation.R b/R/validation.R index ad898e20..e268dac4 100644 --- a/R/validation.R +++ b/R/validation.R @@ -203,6 +203,51 @@ validate_surv_truth_list_estimate <- function(truth, call = call ) } + + all_eval_times_list <- lapply(estimate, function(x) x$.eval_time) + all_eval_times <- unlist(all_eval_times_list) + + if (any(is.na(all_eval_times))) { + cli::cli_abort( + c( + x = "Missing values in {.field .eval_time} are not allowed." + ), + call = call + ) + } + + if (any(all_eval_times < 0)) { + offenders <- unique(all_eval_times[all_eval_times < 0]) + + cli::cli_abort( + c( + x = "Negative values of {.field .eval_time} are not allowed.", + i = "The following negative values were found: {.val {offenders}}." + ), + call = call + ) + } + + if (any(is.infinite(all_eval_times))) { + cli::cli_abort( + c( + x = "Infinite values of {.field .eval_time} are not allowed." + ), + call = call + ) + } + + any_duplicates <- any( + vapply(all_eval_times_list, function(x) any(table(x) > 1), logical(1)) + ) + if (any_duplicates) { + cli::cli_abort( + c( + x = "Duplicate values of {.field .eval_time} are not allowed." + ), + call = call + ) + } } validate_surv_truth_numeric_estimate <- function(truth, diff --git a/tests/testthat/_snaps/validation.md b/tests/testthat/_snaps/validation.md index 9cc0bb3a..d81cdf04 100644 --- a/tests/testthat/_snaps/validation.md +++ b/tests/testthat/_snaps/validation.md @@ -249,6 +249,40 @@ Error: ! `estimate` should be a list, not a a double vector. +--- + + Code + validate_surv_truth_list_estimate(lung_surv_neg$surv_obj, lung_surv_neg$.pred) + Condition + Error: + x Negative values of .eval_time are not allowed. + i The following negative values were found: -100. + +--- + + Code + validate_surv_truth_list_estimate(lung_surv_na$surv_obj, lung_surv_na$.pred) + Condition + Error: + x Missing values in .eval_time are not allowed. + +--- + + Code + validate_surv_truth_list_estimate(lung_surv_inf$surv_obj, lung_surv_inf$.pred) + Condition + Error: + x Infinite values of .eval_time are not allowed. + +--- + + Code + validate_surv_truth_list_estimate(lung_surv_duplicate$surv_obj, + lung_surv_duplicate$.pred) + Condition + Error: + x Duplicate values of .eval_time are not allowed. + # validate_case_weights errors as expected Code diff --git a/tests/testthat/test-validation.R b/tests/testthat/test-validation.R index 0b373148..b4327b60 100644 --- a/tests/testthat/test-validation.R +++ b/tests/testthat/test-validation.R @@ -387,6 +387,46 @@ test_that("validate_surv_truth_list_estimate errors as expected", { lung_surv$.pred_time ) ) + + lung_surv_neg <- lung_surv + lung_surv_neg$.pred[[1]]$.eval_time[1] <- -100 + expect_snapshot( + error = TRUE, + validate_surv_truth_list_estimate( + lung_surv_neg$surv_obj, + lung_surv_neg$.pred + ) + ) + + lung_surv_na <- lung_surv + lung_surv_na$.pred[[1]]$.eval_time[1] <- NA + expect_snapshot( + error = TRUE, + validate_surv_truth_list_estimate( + lung_surv_na$surv_obj, + lung_surv_na$.pred + ) + ) + + lung_surv_inf <- lung_surv + lung_surv_inf$.pred[[1]]$.eval_time[1] <- Inf + expect_snapshot( + error = TRUE, + validate_surv_truth_list_estimate( + lung_surv_inf$surv_obj, + lung_surv_inf$.pred + ) + ) + + lung_surv_duplicate <- lung_surv + lung_surv_duplicate$.pred[[1]]$.eval_time[1] <- 200 + expect_snapshot( + error = TRUE, + validate_surv_truth_list_estimate( + lung_surv_duplicate$surv_obj, + lung_surv_duplicate$.pred + ) + ) }) test_that("validate_case_weights errors as expected", {