diff --git a/NEWS.md b/NEWS.md
index 48b35b52..633858c9 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -24,6 +24,8 @@ calculated with `roc_auc_survival()`.
 
 * Metrics now throw more informative error if `estimate` argument is wrongly used. (#443)
 
+* Curve metrics now throw an informative error instead of returning `NA` when missing values are found and `na_rm = FALSE`. (#344)
+
 # yardstick 1.2.0
 
 ## New Metrics
diff --git a/R/prob-gain_curve.R b/R/prob-gain_curve.R
index 569f9eae..19249c0e 100644
--- a/R/prob-gain_curve.R
+++ b/R/prob-gain_curve.R
@@ -148,7 +148,10 @@ gain_curve_vec <- function(truth,
     estimate <- result$estimate
     case_weights <- result$case_weights
   } else if (yardstick_any_missing(truth, estimate, case_weights)) {
-    return(NA_real_)
+    cli::cli_abort(c(
+      x = "Missing values were detected and {.code na_rm = FALSE}.",
+      i = "Not able to perform calculations."
+    ))
   }
 
   gain_curve_estimator_impl(
diff --git a/R/prob-pr_curve.R b/R/prob-pr_curve.R
index c666adae..36067c19 100644
--- a/R/prob-pr_curve.R
+++ b/R/prob-pr_curve.R
@@ -104,7 +104,10 @@ pr_curve_vec <- function(truth,
     estimate <- result$estimate
     case_weights <- result$case_weights
   } else if (yardstick_any_missing(truth, estimate, case_weights)) {
-    return(NA_real_)
+    cli::cli_abort(c(
+      x = "Missing values were detected and {.code na_rm = FALSE}.",
+      i = "Not able to perform calculations."
+    ))
   }
 
   pr_curve_estimator_impl(
diff --git a/R/prob-roc_curve.R b/R/prob-roc_curve.R
index 46786a4e..73dbc084 100644
--- a/R/prob-roc_curve.R
+++ b/R/prob-roc_curve.R
@@ -111,7 +111,10 @@ roc_curve_vec <- function(truth,
     estimate <- result$estimate
     case_weights <- result$case_weights
   } else if (yardstick_any_missing(truth, estimate, case_weights)) {
-    return(NA_real_)
+    cli::cli_abort(c(
+      x = "Missing values were detected and {.code na_rm = FALSE}.",
+      i = "Not able to perform calculations."
+    ))
   }
 
   # estimate here is a matrix of class prob columns
diff --git a/R/surv-roc_curve_survival.R b/R/surv-roc_curve_survival.R
index 6ac25c38..1a0edfc9 100644
--- a/R/surv-roc_curve_survival.R
+++ b/R/surv-roc_curve_survival.R
@@ -118,13 +118,11 @@ roc_curve_survival_vec <- function(truth,
     truth <- result$truth
     estimate <- estimate[result$estimate]
     case_weights <- result$case_weights
-  } else {
-    any_missing <- yardstick_any_missing(
-      truth, estimate, case_weights
-    )
-    if (any_missing) {
-      return(NA_real_)
-    }
+  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
+    cli::cli_abort(c(
+      x = "Missing values were detected and {.code na_rm = FALSE}.",
+      i = "Not able to perform calculations."
+    ))
   }
 
   roc_curve_survival_impl(
diff --git a/tests/testthat/_snaps/template.md b/tests/testthat/_snaps/template.md
index 88c0743a..38826bb9 100644
--- a/tests/testthat/_snaps/template.md
+++ b/tests/testthat/_snaps/template.md
@@ -100,6 +100,16 @@
       x This metric doesn't use the `estimate` argument.
       i Specify the columns without `estimate = `.
 
+# curve_metric_summarizer()'s na_rm argument work
+
+    Code
+      curve_metric_summarizer(name = "roc_curve", fn = roc_curve_vec, data = hpc_f1_na,
+        truth = obs, VF:L, na_rm = FALSE, case_weights = NULL)
+    Condition
+      Error:
+      x Missing values were detected and `na_rm = FALSE`.
+      i Not able to perform calculations.
+
 # curve_metric_summarizer()'s errors when wrong things are passes
 
     Code
@@ -207,6 +217,16 @@
       x Problematic argument:
       * obviouslywrong = TRUE
 
+# curve_survival_metric_summarizer()'s na_rm argument works
+
+    Code
+      curve_survival_metric_summarizer(name = "roc_curve_survival", fn = roc_curve_survival_vec,
+        data = lung_surv, truth = surv_obj, .pred, na_rm = FALSE, case_weights = NULL)
+    Condition
+      Error:
+      x Missing values were detected and `na_rm = FALSE`.
+      i Not able to perform calculations.
+
 # curve_survival_metric_summarizer()'s errors with bad input
 
     Code
diff --git a/tests/testthat/test-template.R b/tests/testthat/test-template.R
index 84358192..8d4642d1 100644
--- a/tests/testthat/test-template.R
+++ b/tests/testthat/test-template.R
@@ -899,23 +899,18 @@ test_that("curve_metric_summarizer()'s na_rm argument work", {
 
   expect_identical(roc_curve_res, roc_curve_exp)
 
-  roc_curve_res <- curve_metric_summarizer(
-    name = "roc_curve",
-    fn = roc_curve_vec,
-    data = hpc_f1_na,
-    truth = obs,
-    VF:L,
-    na_rm = FALSE,
-    case_weights = NULL
-  )
-
-  roc_curve_exp <- dplyr::tibble(
-    .metric = "roc_curve",
-    .estimator = "multiclass",
-    .estimate = na_dbl
+  expect_snapshot(
+    error = TRUE,
+    curve_metric_summarizer(
+      name = "roc_curve",
+      fn = roc_curve_vec,
+      data = hpc_f1_na,
+      truth = obs,
+      VF:L,
+      na_rm = FALSE,
+      case_weights = NULL
+    )
   )
-
-  expect_identical(roc_curve_res, roc_curve_exp)
 })
 
 test_that("curve_metric_summarizer()'s case_weights argument work", {
@@ -1613,23 +1608,18 @@ test_that("curve_survival_metric_summarizer()'s na_rm argument works", {
 
   expect_identical(roc_curve_survival_res, roc_curve_survival_exp)
 
-  roc_curve_survival_res <- curve_survival_metric_summarizer(
-    name = "roc_curve_survival",
-    fn = roc_curve_survival_vec,
-    data = lung_surv,
-    truth = surv_obj,
-    .pred,
-    na_rm = FALSE,
-    case_weights = NULL
-  )
-
-  roc_curve_survival_exp <- dplyr::tibble(
-    .metric = "roc_curve_survival",
-    .estimator = "standard",
-    .estimate = na_dbl
+  expect_snapshot(
+    error = TRUE,
+    curve_survival_metric_summarizer(
+      name = "roc_curve_survival",
+      fn = roc_curve_survival_vec,
+      data = lung_surv,
+      truth = surv_obj,
+      .pred,
+      na_rm = FALSE,
+      case_weights = NULL
+    )
   )
-
-  expect_identical(roc_curve_survival_res, roc_curve_survival_exp)
 })
 
 test_that("curve_survival_metric_summarizer()'s case_weights argument works", {