Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use hardhat::quantile_pred #332

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ Imports:
dplyr (>= 0.8.0.1),
generics,
glue,
hardhat (>= 1.1.0),
hardhat (>= 1.4.0.9002),
lifecycle,
mboost,
prodlim (>= 2023.03.31),
purrr,
rlang (>= 1.0.0),
stats,
tibble (>= 3.1.3),
tidyr (>= 1.0.0)
tidyr (>= 1.0.0),
vctrs
Suggests:
aorsf (>= 0.1.2),
coin,
Expand All @@ -48,6 +49,9 @@ Suggests:
rmarkdown,
rpart,
testthat (>= 3.0.0)
Remotes:
tidymodels/hardhat,
tidymodels/parsnip#1209
Config/Needs/website:
tidymodels,
tidyverse/tidytemplate
Expand Down
4 changes: 2 additions & 2 deletions R/censored-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ utils::globalVariables(
".id", ".tmp", "engine", "predictor_indicators", ".strata", "group",
".pred_quantile", ".quantile", "interval", "level", ".pred_linear_pred",
".pred_link", ".pred_time", ".pred_survival", "next_event_time",
"sum_component", "time_interval"
"sum_component", "time_interval", "quantile_levels"
)
)

# quiet R-CMD-check NOTEs that prodlim is unused
# (parsnip uses it for all censored regression models
# (parsnip uses it for all censored regression models
# but only has it in Suggests)
#' @importFrom prodlim prodlim
NULL
12 changes: 6 additions & 6 deletions R/survival_reg-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ make_survival_reg_survival <- function() {
type = "quantile",
value = list(
pre = NULL,
post = survreg_quant,
post = parsnip::matrix_to_quantile_pred,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "quantile",
p = expr(quantile)
p = expr(quantile_levels)
)
)
)
Expand Down Expand Up @@ -236,14 +236,14 @@ make_survival_reg_flexsurv <- function() {
type = "quantile",
value = list(
pre = NULL,
post = NULL,
post = flexsurv_to_quantile_pred,
func = c(fun = "predict"),
args =
list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "quantile",
p = rlang::expr(quantile),
p = rlang::expr(quantile_levels),
conf.int = rlang::expr(interval == "confidence"),
conf.level = rlang::expr(level)
)
Expand Down Expand Up @@ -393,14 +393,14 @@ make_survival_reg_flexsurvspline <- function() {
type = "quantile",
value = list(
pre = NULL,
post = NULL,
post = flexsurv_to_quantile_pred,
func = c(fun = "predict"),
args =
list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "quantile",
p = rlang::expr(quantile),
p = rlang::expr(quantile_levels),
conf.int = rlang::expr(interval == "confidence"),
conf.level = rlang::expr(level)
)
Expand Down
45 changes: 44 additions & 1 deletion R/survival_reg-flexsurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ flexsurv_post <- function(pred, object) {
tidyr::nest(.by = .row) %>%
dplyr::select(-.row)
}
pred
pred
}

flexsurv_rename_time <- function(pred){
Expand All @@ -27,3 +27,46 @@ flexsurv_rename_time <- function(pred){
dplyr::rename(.eval_time = .time)
}
}

# ------------------------------------------------------------------------------
# Conversion of quantile predictions to the vctrs format

# For single quantile levels, flexsurv returns a data frame with column
# ".pred_quantile" and perhaps also ".pred_lower" and ".pred_upper"

# With multiple quantile levels, flexsurv returns a data frame with a nested
# ".pred" column whose elements are data frames with columns ".quantile" and
# ".pred_quantile" and perhaps also ".pred_lower" and ".pred_upper"

# Post-processor turning flexsurv quantile predictions into a tibble whose
# columns are built by `nested_df_iter()` (one entry per row of `x`).
#
# x:      data frame of predictions; either already nested in a single
#         ".pred" column, or flat (anything else), in which case it is nested
#         first via `re_nest()`.
# object: not used in the body. NOTE(review): presumably present because the
#         caller invokes post-processors as `post(results, object)` - confirm.
#
# Returns a tibble with one column for each of ".pred_quantile",
# ".pred_lower" and ".pred_upper" that is present in the nested data frames.
flexsurv_to_quantile_pred <- function(x, object) {
  # Anything other than a lone ".pred" column is treated as the flat layout
  # and converted to the same nested format as multi-level predictions.
  already_nested <- identical(names(x), ".pred")
  if (!already_nested) {
    x <- re_nest(x)
  }

  # The first nested data frame decides which prediction columns to convert.
  nested_names <- names(x$.pred[[1]])
  convertible <- c(".pred_quantile", ".pred_lower", ".pred_upper")
  cols_to_convert <- intersect(convertible, nested_names)

  # One converted vector per prediction column, keyed by column name.
  converted <- lapply(
    cols_to_convert,
    function(col) purrr::map_vec(x$.pred, nested_df_iter, col = col)
  )
  names(converted) <- cols_to_convert

  tibble::new_tibble(converted)
}

# Convert a flat data frame into a nested layout: a data frame with a single
# ".pred" column holding one piece of `df` per row.
#
# df: a data frame.
#
# Returns the result of splitting `df` by row index with `vctrs::vec_split()`,
# with the `key` column dropped and the remaining column renamed to ".pred".
re_nest <- function(df) {
  # seq_len() instead of 1:nrow(df): for a zero-row input, 1:0 evaluates to
  # c(1, 0), which does not match the number of rows in `df`.
  row_id <- seq_len(nrow(df))

  # Every row index is unique, so each row ends up in its own group.
  nested <- vctrs::vec_split(df, by = row_id)
  nested$key <- NULL
  names(nested) <- ".pred"
  nested
}

# Build a `hardhat::quantile_pred()` object from one nested prediction data
# frame.
#
# df:  data frame with a `.quantile` column (used as the quantile levels) and
#      the column named by `col`.
# col: name of the column in `df` holding the values to convert.
#
# The values are laid out as a one-row matrix before being handed to
# `hardhat::quantile_pred()`.
nested_df_iter <- function(df, col) {
  level_values <- df$.quantile
  value_row <- matrix(df[[col]], nrow = 1)
  hardhat::quantile_pred(value_row, quantile_levels = level_values)
}
55 changes: 21 additions & 34 deletions tests/testthat/test-survival_reg-flexsurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ library(testthat)

test_that("model object", {
skip_if_not_installed("flexsurv")

set.seed(1234)
exp_f_fit <- flexsurv::flexsurvreg(
Surv(time, status) ~ age + ph.ecog,
Expand Down Expand Up @@ -149,7 +149,7 @@ test_that("survival probabilities for single eval time point", {

test_that("can predict for out-of-domain timepoints", {
skip_if_not_installed("flexsurv")

eval_time_obs_max_and_ood <- c(1022, 2000)
obs_without_NA <- lung[2,]

Expand Down Expand Up @@ -236,41 +236,28 @@ test_that("quantile predictions", {
)

expect_s3_class(pred, "tbl_df")
expect_equal(names(pred), ".pred")
expect_equal(names(pred), ".pred_quantile")
expect_equal(nrow(pred), 3)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(dim(.x) == c(9, 2))
))
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(".quantile", ".pred_quantile"))
))
)
expect_equal(
tidyr::unnest(pred, cols = .pred)$.pred_quantile,
do.call(rbind, exp_pred)$est
)
expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))

for (.row in 1:nrow(pred)) {
expect_equal(
unclass(pred$.pred_quantile[.row])[[1]],
exp_pred[[.row]]$est
)
}

# add confidence interval
pred <- predict(fit_s,
pred_ci <- predict(fit_s,
new_data = bladder[1:3, ], type = "quantile",
interval = "confidence", level = 0.7
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(
".quantile",
".pred_quantile",
".pred_lower",
".pred_upper"
))
))
)
expect_s3_class(pred_ci, "tbl_df")
expect_equal(names(pred_ci), c(".pred_quantile", ".pred_lower", ".pred_upper"))
expect_equal(nrow(pred_ci), 3)
expect_s3_class(pred_ci$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred_ci$.pred_lower, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred_ci$.pred_upper, c("quantile_pred", "vctrs_vctr", "list"))

# single observation
f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile")
Expand Down Expand Up @@ -354,7 +341,7 @@ test_that("hazard for single eval time point", {

test_that("`fix_xy()` works", {
skip_if_not_installed("flexsurv")

lung_x <- as.matrix(lung[, c("age", "ph.ecog")])
lung_y <- Surv(lung$time, lung$status)
lung_pred <- lung[1:5, ]
Expand Down Expand Up @@ -401,13 +388,13 @@ test_that("`fix_xy()` works", {
f_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
xy_pred_quantile <- predict(
xy_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
expect_equal(f_pred_quantile, xy_pred_quantile)

Expand Down
63 changes: 15 additions & 48 deletions tests/testthat/test-survival_reg-flexsurvspline.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ test_that("survival probability prediction", {
head(lung),
type = "survival",
times = c(0, 500, 1000)
)
)
if (packageVersion("flexsurv") < "2.3") {
exp_pred <- exp_pred %>%
dplyr::rowwise() %>%
Expand Down Expand Up @@ -211,59 +211,26 @@ test_that("quantile predictions", {
set_mode("censored regression") %>%
fit(Surv(stop, event) ~ rx + size + enum, data = bladder)
pred <- predict(fit_s, new_data = bladder[1:3, ], type = "quantile")

set.seed(1)
exp_fit <- flexsurv::flexsurvspline(
Surv(stop, event) ~ rx + size + enum,
data = bladder,
k = 1
)
exp_pred <- summary(
exp_fit,
newdata = bladder[1:3, ],
type = "quantile",
quantiles = (1:9) / 10
)

expect_s3_class(pred, "tbl_df")
expect_equal(names(pred), ".pred")
expect_equal(names(pred), ".pred_quantile")
expect_equal(nrow(pred), 3)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(dim(.x) == c(9, 2))
))
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(".quantile", ".pred_quantile"))
))
)
expect_equal(
tidyr::unnest(pred, cols = .pred)$.pred_quantile,
do.call(rbind, exp_pred)$est
)
expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))


# add confidence interval
pred <- predict(
pred_ci <- predict(
fit_s,
new_data = bladder[1:3, ],
type = "quantile",
interval = "confidence",
level = 0.7
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(
".quantile",
".pred_quantile",
".pred_lower",
".pred_upper"
))
))
)
expect_s3_class(pred_ci, "tbl_df")
expect_equal(names(pred_ci), c(".pred_quantile", ".pred_lower", ".pred_upper"))
expect_equal(nrow(pred_ci), 3)
expect_s3_class(pred_ci$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred_ci$.pred_lower, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred_ci$.pred_upper, c("quantile_pred", "vctrs_vctr", "list"))

# single observation
f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile")
Expand All @@ -284,7 +251,7 @@ test_that("hazard prediction", {
head(lung),
type = "hazard",
times = c(0, 500, 1000)
)
)
if (packageVersion("flexsurv") < "2.3") {
exp_pred <- exp_pred %>%
dplyr::rowwise() %>%
Expand Down Expand Up @@ -409,13 +376,13 @@ test_that("`fix_xy()` works", {
f_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
xy_pred_quantile <- predict(
xy_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
expect_equal(f_pred_quantile, xy_pred_quantile)

Expand All @@ -438,7 +405,7 @@ test_that("`fix_xy()` works", {

test_that("can handle case weights", {
skip_if_not_installed("flexsurv")

# flexsurv engine can only take weights > 0
set.seed(1)
wts <- runif(nrow(lung))
Expand Down
Loading
Loading