Skip to content

Commit

Permalink
Merge pull request #439 from tidymodels/roc_curve_survival-case_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Nov 19, 2023
2 parents 00ea962 + 1c8bfe6 commit c0a820e
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 11 deletions.
7 changes: 6 additions & 1 deletion R/surv-roc_auc_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ roc_auc_survival_vec <- function(truth,
case_weights = NULL,
...) {
# No checking since roc_curve_survival_vec() does checking
curve <- roc_curve_survival_vec(truth, estimate)
curve <- roc_curve_survival_vec(
truth = truth,
estimate = estimate,
na_rm = na_rm,
case_weights = case_weights
)

curve %>%
dplyr::group_by(.eval_time) %>%
Expand Down
45 changes: 35 additions & 10 deletions R/surv-roc_curve_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,35 @@ roc_curve_survival_vec <- function(truth,
}
}

roc_curve_survival_impl(truth = truth, estimate = estimate)
roc_curve_survival_impl(
truth = truth,
estimate = estimate,
case_weights = case_weights
)
}

roc_curve_survival_impl <- function(truth,
estimate) {
estimate,
case_weights) {
event_time <- .extract_surv_time(truth)
delta <- .extract_surv_status(truth)
data <- dplyr::tibble(event_time, delta, estimate)
case_weights <- vec_cast(case_weights, double())
if (is.null(case_weights)) {
case_weights <- rep(1, length(delta))
}

# Drop any `0` weights.
# These shouldn't affect the result, but can result in wrong thresholds
detect_zero_weight <- case_weights == 0
if (any(detect_zero_weight)) {
detect_non_zero_weight <- !detect_zero_weight
event_time <- event_time[detect_non_zero_weight]
delta <- delta[detect_non_zero_weight]
case_weights <- case_weights[detect_non_zero_weight]
estimate <- estimate[detect_non_zero_weight]
}

data <- dplyr::tibble(event_time, delta, case_weights, estimate)
data <- tidyr::unnest(data, cols = estimate)

.eval_times <- unique(data$.eval_time)
Expand All @@ -149,7 +170,8 @@ roc_curve_survival_impl <- function(truth,
res <- roc_curve_survival_impl_one(
data$event_time[.eval_time_ind],
data$delta[.eval_time_ind],
data[.eval_time_ind, ]
data[.eval_time_ind, ],
data$case_weights[.eval_time_ind]
)

res$.eval_time <- .eval_times[[i]]
Expand All @@ -159,30 +181,33 @@ roc_curve_survival_impl <- function(truth,
dplyr::bind_rows(out)
}

roc_curve_survival_impl_one <- function(event_time, delta, data) {
roc_curve_survival_impl_one <- function(event_time, delta, data, case_weights) {
res <- dplyr::tibble(.threshold = sort(unique(c(-Inf, data$.pred_survival, Inf))))

obs_time_le_time <- event_time <= data$.eval_time
obs_time_gt_time <- event_time > data$.eval_time
n <- nrow(data)
multiplier <- delta / (n * data$.weight_censored)

sensitivity_denom <- sum(obs_time_le_time * multiplier, na.rm = TRUE)
specificity_denom <- sum(obs_time_gt_time, na.rm = TRUE)
sensitivity_denom <- sum(obs_time_le_time * multiplier * case_weights, na.rm = TRUE)
specificity_denom <- sum(obs_time_gt_time * case_weights, na.rm = TRUE)

data_df <- data.frame(
le_time = obs_time_le_time,
ge_time = obs_time_gt_time,
multiplier = multiplier
multiplier = multiplier,
case_weights = case_weights
)

data_split <- vec_split(data_df, data$.pred_survival)
data_split <- data_split$val[order(data_split$key)]

sensitivity <- vapply(
data_split,
function(x) sum(x$le_time * x$multiplier, na.rm = TRUE),
function(x) sum(x$le_time * x$multiplier * x$case_weights, na.rm = TRUE),
FUN.VALUE = numeric(1)
)

sensitivity <- cumsum(sensitivity)
sensitivity <- sensitivity / sensitivity_denom
sensitivity <- dplyr::if_else(sensitivity > 1, 1, sensitivity)
Expand All @@ -192,7 +217,7 @@ roc_curve_survival_impl_one <- function(event_time, delta, data) {

specificity <- vapply(
data_split,
function(x) sum(x$ge_time, na.rm = TRUE),
function(x) sum(x$ge_time * x$case_weights, na.rm = TRUE),
FUN.VALUE = numeric(1)
)
specificity <- cumsum(specificity)
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-surv-roc_auc_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ test_that("roc_curve_auc() calculations", {
)
})

# case weights -----------------------------------------------------------------

test_that("case weights are applied", {
wts_res <- lung_surv %>%
dplyr::mutate(wts = hardhat::frequency_weights(rep(1:0, c(128, 100)))) %>%
roc_auc_survival(
truth = surv_obj,
.pred,
case_weights = wts
)

subset_res <- lung_surv %>%
dplyr::slice(1:128) %>%
roc_auc_survival(
truth = surv_obj,
.pred
)

expect_identical(subset_res, wts_res)
})

# self checking ----------------------------------------------------------------

test_that("snapshot equivalent", {
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-surv-roc_curve_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ test_that("roc_curve_survival works", {
}
})

# case weights -----------------------------------------------------------------

test_that("case weights are applied", {
wts_res <- lung_surv %>%
dplyr::mutate(wts = hardhat::frequency_weights(rep(1:0, c(128, 100)))) %>%
roc_curve_survival(
truth = surv_obj,
.pred,
case_weights = wts
)

subset_res <- lung_surv %>%
dplyr::slice(1:128) %>%
roc_curve_survival(
truth = surv_obj,
.pred
)

expect_identical(subset_res, wts_res)
})

# self checking ----------------------------------------------------------------

test_that("snapshot equivalent", {
Expand Down

0 comments on commit c0a820e

Please sign in to comment.