Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
unit tests for #170
Browse files Browse the repository at this point in the history
topepo committed Jan 30, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 3af5b6b commit e411810
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/testthat/test-survival-tune-compute-metrics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
suppressPackageStartupMessages(library(tidymodels))
suppressPackageStartupMessages(library(censored))

skip_if_not_installed("parsnip", minimum_version = "1.1.0.9003")
skip_if_not_installed("censored", minimum_version = "0.2.0.9000")
skip_if_not_installed("tune", minimum_version = "1.1.2.9014")
skip_if_not_installed("yardstick", minimum_version = "1.2.0.9001")

test_that("compute_metrics works with survival models", {
lung_surv <- lung %>%
dplyr::mutate(surv = Surv(time, status), .keep = "unused")

metrics <- metric_set(concordance_survival, brier_survival_integrated, brier_survival)

times <- c(2, 50, 100)

set.seed(2193)
tune_res <-
proportional_hazards(penalty = tune(), engine = "glmnet") %>%
tune_grid(
surv ~ .,
resamples = vfold_cv(lung_surv, 2),
grid = tibble(penalty = c(0.001, 0.1)),
control = control_grid(save_pred = TRUE),
metrics = metrics,
eval_time = times
)

# ------------------------------------------------------------------------------

recomp <- compute_metrics(tune_res, metrics, summarize = TRUE)
original <- collect_metrics(tune_res)
expect_equal(recomp, original)

# ------------------------------------------------------------------------------

recomp_rs <- compute_metrics(tune_res, metrics, summarize = TRUE)
original_rs <- collect_metrics(tune_res, summarize = TRUE)
expect_equal(recomp_rs, original_rs)

# ------------------------------------------------------------------------------

stc_only <- compute_metrics(tune_res, metric_set(concordance_survival), summarize = TRUE)
stc_original <- collect_metrics(tune_res) %>% filter(.metric == "concordance_survival")
expect_equal(stc_only, stc_original %>% select(-.eval_time))

})

0 comments on commit e411810

Please sign in to comment.