diff --git a/R/surv-roc_auc_survival.R b/R/surv-roc_auc_survival.R index da00a25a..cea06a43 100644 --- a/R/surv-roc_auc_survival.R +++ b/R/surv-roc_auc_survival.R @@ -94,7 +94,12 @@ roc_auc_survival_vec <- function(truth, case_weights = NULL, ...) { # No checking since roc_curve_survival_vec() does checking - curve <- roc_curve_survival_vec(truth, estimate) + curve <- roc_curve_survival_vec( + truth = truth, + estimate = estimate, + na_rm = na_rm, + case_weights = case_weights + ) curve %>% dplyr::group_by(.eval_time) %>% diff --git a/R/surv-roc_curve_survival.R b/R/surv-roc_curve_survival.R index 246a78c2..8ecaf4fc 100644 --- a/R/surv-roc_curve_survival.R +++ b/R/surv-roc_curve_survival.R @@ -128,14 +128,35 @@ roc_curve_survival_vec <- function(truth, } } - roc_curve_survival_impl(truth = truth, estimate = estimate) + roc_curve_survival_impl( + truth = truth, + estimate = estimate, + case_weights = case_weights + ) } roc_curve_survival_impl <- function(truth, - estimate) { + estimate, + case_weights) { event_time <- .extract_surv_time(truth) delta <- .extract_surv_status(truth) - data <- dplyr::tibble(event_time, delta, estimate) + case_weights <- vec_cast(case_weights, double()) + if (is.null(case_weights)) { + case_weights <- rep(1, length(delta)) + } + + # Drop any `0` weights. + # These shouldn't affect the result, but can result in wrong thresholds + detect_zero_weight <- case_weights == 0 + if (any(detect_zero_weight)) { + detect_non_zero_weight <- !detect_zero_weight + event_time <- event_time[detect_non_zero_weight] + delta <- delta[detect_non_zero_weight] + case_weights <- case_weights[detect_non_zero_weight] + estimate <- estimate[detect_non_zero_weight] + } + + data <- dplyr::tibble(event_time, delta, case_weights, estimate) data <- tidyr::unnest(data, cols = estimate) .eval_times <- unique(data$.eval_time) @@ -149,7 +170,8 @@ roc_curve_survival_impl <- function(truth, res <- roc_curve_survival_impl_one( data$event_time[.eval_time_ind], data$delta[.eval_time_ind], - data[.eval_time_ind, ] + data[.eval_time_ind, ], + data$case_weights[.eval_time_ind] ) res$.eval_time <- .eval_times[[i]] @@ -159,7 +181,7 @@ roc_curve_survival_impl <- function(truth, dplyr::bind_rows(out) } -roc_curve_survival_impl_one <- function(event_time, delta, data) { +roc_curve_survival_impl_one <- function(event_time, delta, data, case_weights) { res <- dplyr::tibble(.threshold = sort(unique(c(-Inf, data$.pred_survival, Inf)))) obs_time_le_time <- event_time <= data$.eval_time @@ -167,22 +189,25 @@ roc_curve_survival_impl_one <- function(event_time, delta, data) { n <- nrow(data) multiplier <- delta / (n * data$.weight_censored) - sensitivity_denom <- sum(obs_time_le_time * multiplier, na.rm = TRUE) - specificity_denom <- sum(obs_time_gt_time, na.rm = TRUE) + sensitivity_denom <- sum(obs_time_le_time * multiplier * case_weights, na.rm = TRUE) + specificity_denom <- sum(obs_time_gt_time * case_weights, na.rm = TRUE) data_df <- data.frame( le_time = obs_time_le_time, ge_time = obs_time_gt_time, - multiplier = multiplier + multiplier = multiplier, + case_weights = case_weights ) + data_split <- vec_split(data_df, data$.pred_survival) data_split <- data_split$val[order(data_split$key)] sensitivity <- vapply( data_split, - function(x) sum(x$le_time * x$multiplier, na.rm = TRUE), + function(x) sum(x$le_time * x$multiplier * x$case_weights, na.rm = TRUE), FUN.VALUE = numeric(1) ) + sensitivity <- cumsum(sensitivity) sensitivity <- sensitivity / sensitivity_denom sensitivity <- dplyr::if_else(sensitivity > 1, 1, sensitivity) @@ -192,7 +217,7 @@ roc_curve_survival_impl_one <- function(event_time, delta, data) { specificity <- vapply( data_split, - function(x) sum(x$ge_time, na.rm = TRUE), + function(x) sum(x$ge_time * x$case_weights, na.rm = TRUE), FUN.VALUE = numeric(1) ) specificity <- cumsum(specificity) diff --git a/tests/testthat/test-surv-roc_auc_survival.R b/tests/testthat/test-surv-roc_auc_survival.R index b0b0c9e4..afb6454a 100644 --- a/tests/testthat/test-surv-roc_auc_survival.R +++ b/tests/testthat/test-surv-roc_auc_survival.R @@ -21,6 +21,27 @@ test_that("roc_curve_auc() calculations", { ) }) +# case weights ----------------------------------------------------------------- + +test_that("case weights are applied", { + wts_res <- lung_surv %>% + dplyr::mutate(wts = hardhat::frequency_weights(rep(1:0, c(128, 100)))) %>% + roc_auc_survival( + truth = surv_obj, + .pred, + case_weights = wts + ) + + subset_res <- lung_surv %>% + dplyr::slice(1:128) %>% + roc_auc_survival( + truth = surv_obj, + .pred + ) + + expect_identical(subset_res, wts_res) +}) + # self checking ---------------------------------------------------------------- test_that("snapshot equivalent", { diff --git a/tests/testthat/test-surv-roc_curve_survival.R b/tests/testthat/test-surv-roc_curve_survival.R index 255623c4..d15a41ac 100644 --- a/tests/testthat/test-surv-roc_curve_survival.R +++ b/tests/testthat/test-surv-roc_curve_survival.R @@ -37,6 +37,27 @@ test_that("roc_curve_survival works", { } }) +# case weights ----------------------------------------------------------------- + +test_that("case weights are applied", { + wts_res <- lung_surv %>% + dplyr::mutate(wts = hardhat::frequency_weights(rep(1:0, c(128, 100)))) %>% + roc_curve_survival( + truth = surv_obj, + .pred, + case_weights = wts + ) + + subset_res <- lung_surv %>% + dplyr::slice(1:128) %>% + roc_curve_survival( + truth = surv_obj, + .pred + ) + + expect_identical(subset_res, wts_res) +}) + # self checking ---------------------------------------------------------------- test_that("snapshot equivalent", {