Skip to content

Commit

Permalink
CV+ conformal intervals (#85)
Browse files Browse the repository at this point in the history
* add conformal intervals for regression models

* fix reference

* missing pkgdown entry

* starts_with 🙄

* another missing topic

* split out a predict() method

* add cv method

* cv+ estimator

* update docs and tests

* another doc update

* update snapshots

* Apply suggestions from code review

Co-authored-by: Simon P. Couch <[email protected]>

* update snapshots

* small descr update

* remove redundant tests

* use call arg when stopping

* doc update

* are not using glue

* another snapshot refresh

* fix snapshot

* remove test to see if basic argument passing works

* updated with current CRAN versions

---------

Co-authored-by: Simon P. Couch <[email protected]>
  • Loading branch information
topepo and simonpcouch authored Mar 21, 2023
1 parent 9c582b1 commit 7c1b36f
Show file tree
Hide file tree
Showing 14 changed files with 592 additions and 29 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Description: Models can be improved by post-processing class
probabilities, by: recalibration, conversion to hard probabilities,
assessment of equivocal zones, and other activities. 'probably'
contains tools for conducting these operations as well as calibration
tools for regression models.
tools and conformal inference techniques for regression models.
License: MIT + file LICENSE
URL: https://github.com/tidymodels/probably/,
https://probably.tidymodels.org
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ S3method(cal_validate_summarize,cal_rset)
S3method(format,class_pred)
S3method(int_conformal_infer,default)
S3method(int_conformal_infer,workflow)
S3method(int_conformal_infer_cv,default)
S3method(int_conformal_infer_cv,resample_results)
S3method(int_conformal_infer_cv,tune_results)
S3method(is_equivocal,class_pred)
S3method(is_equivocal,default)
S3method(is_ordered,class_pred)
Expand All @@ -58,12 +61,14 @@ S3method(obj_print_data,class_pred)
S3method(obj_print_footer,class_pred)
S3method(obj_print_header,class_pred)
S3method(predict,int_conformal_infer)
S3method(predict,int_conformal_infer_cv)
S3method(print,cal_binary)
S3method(print,cal_estimate_isotonic)
S3method(print,cal_estimate_linear)
S3method(print,cal_multi)
S3method(print,cal_regression)
S3method(print,int_conformal_infer)
S3method(print,int_conformal_infer_cv)
S3method(reportable_rate,class_pred)
S3method(reportable_rate,default)
S3method(required_pkgs,cal_estimate_beta)
Expand Down Expand Up @@ -121,6 +126,7 @@ export(class_pred)
export(control_conformal_infer)
export(fit)
export(int_conformal_infer)
export(int_conformal_infer_cv)
export(is_class_pred)
export(is_equivocal)
export(make_class_pred)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

* Based on the initial PR (#37) by Antonio R. Vargas, `threshold_perf()` now accepts a custom metric set (#25)

* Two functions were added to compute prediction intervals for regression models via conformal inference:

* `int_conformal_infer()`
* `int_conformal_infer_cv()`

# probably 0.1.0

* Max Kuhn is now the maintainer (#49).
Expand Down
11 changes: 7 additions & 4 deletions R/conformal_infer.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
#' distributional quantile, the value is one of the bounds.
#'
#' The literature proposed using a grid search of trial values to find the two
#' points that correspond to the prediction intervals. To use this apprach,
#' points that correspond to the prediction intervals. To use this approach,
#' set `method = "grid"` in [control_conformal_infer()]. However, the default method
#' `"search"` uses two different one-dimensional iterative searches on either
#' side of the predicted value to find values that correspond to the prediction intervals.

#'
#' For medium to large data sets, the iterative search method is likely to
#' generate slightly smaller intervals. For small training sets, grid search
Expand Down Expand Up @@ -104,8 +105,9 @@ int_conformal_infer.workflow <-
#' @export
print.int_conformal_infer <- function(x, ...) {
cat("Conformal inference\n")
# cat("preprocessor:", .get_pre_type(x$wflow), "\n")
# cat("model:", .get_fit_type(x$wflow), "\n")

cat("preprocessor:", .get_pre_type(x$wflow), "\n")
cat("model:", .get_fit_type(x$wflow), "\n")
cat("training set size:", format(nrow(x$training), big.mark = ","), "\n\n")

cat("Use `predict(object, new_data, level)` to compute prediction intervals\n")
Expand All @@ -121,7 +123,7 @@ print.int_conformal_infer <- function(x, ...) {
#' @param ... Not currently used.
#' @return A tibble with columns `.pred_lower` and `.pred_upper`. If
#' the computations for the prediction bound fail, a missing value is used.
#' @seealso [int_conformal_infer()]
#' @seealso [int_conformal_infer()], [int_conformal_infer_cv()]
#' @export
predict.int_conformal_infer <- function(object, new_data, level = 0.95, ...) {
check_data(new_data, object$wflow)
Expand Down Expand Up @@ -202,6 +204,7 @@ check_workflow <- function(x, call = rlang::caller_env()) {

var_model <- function(object, train_data, call = caller_env()) {


y_name <- get_outcome_name(object)

train_res <- predict(object, train_data)
Expand Down
280 changes: 280 additions & 0 deletions R/conformal_infer_cv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
#' Prediction intervals via conformal inference CV+
#'
#' Nonparametric prediction intervals can be computed for fitted regression
#' workflow objects using the CV+ conformal inference method described by
#' Barber _et al_ (2021).
#'
#' @param object An object from a tidymodels resampling or tuning function such
#'   as [tune::fit_resamples()], [tune::tune_grid()], or similar. The object
#'   should have been produced in a way that the `.extracts` column contains the
#'   fitted workflow for each resample (see the Details below).
#' @param parameters A tibble of tuning parameter values that can be
#'   used to filter the predicted values before processing. This tibble should
#'   select a single set of hyper-parameter values from the tuning results. This
#'   is only required when a tuning object is passed to `object`.
#' @param ... Not currently used.
#' @return An object of class `"int_conformal_infer_cv"` containing the
#'   information to create intervals. The `predict()` method is used to produce
#'   the intervals.
#' @details
#' This function implements the CV+ method found in Section 3 of Barber _et al_
#' (2021). It uses the resampled model fits and their associated holdout
#' residuals to make prediction intervals for regression models.
#'
#' This function prepares the objects for the computations. The [predict()]
#' method computes the intervals for new data.
#'
#' The `object` must contain both an `.extracts` column (holding the fitted
#' workflow objects) and a `.predictions` column (holding the holdout
#' predictions). These are produced by the `extract` and `save_pred` arguments
#' of the control function used for resampling or tuning (e.g.,
#' `control_resamples()` or `control_grid()`); the example below shows one way
#' to set them.
#'
#' This method was developed for V-fold cross-validation (no repeats). Interval
#' coverage is unknown for any other resampling methods. The function will not
#' stop the computations for other types of resamples, but we have no way of
#' knowing whether the results are appropriate.
#'
#' @seealso [predict.int_conformal_infer_cv()]
#' @references
#' Rina Foygel Barber, Emmanuel J. Candès, Aaditya Ramdas, Ryan J. Tibshirani
#' "Predictive inference with the jackknife+," _The Annals of Statistics_,
#' 49(1), 486-507, 2021
#' @examplesIf !probably:::is_cran_check()
#' library(workflows)
#' library(dplyr)
#' library(parsnip)
#' library(rsample)
#' library(tune)
#' library(modeldata)
#'
#' set.seed(2)
#' sim_train <- sim_regression(200)
#' sim_new <- sim_regression( 5) %>% select(-outcome)
#'
#' sim_rs <- vfold_cv(sim_train)
#'
#' # We'll use a neural network model
#' mlp_spec <-
#'   mlp(hidden_units = 5, penalty = 0.01) %>%
#'   set_mode("regression")
#'
#' # Use a control function that saves the predictions as well as the models.
#' # Consider using the butcher package in the extracts function to have smaller
#' # object sizes
#'
#' ctrl <- control_resamples(save_pred = TRUE, extract = I)
#'
#' set.seed(3)
#' nnet_res <-
#'   mlp_spec %>%
#'   fit_resamples(outcome ~ ., resamples = sim_rs, control = ctrl)
#'
#' nnet_int_obj <- int_conformal_infer_cv(nnet_res)
#' nnet_int_obj
#'
#' predict(nnet_int_obj, sim_new)
#' @export
int_conformal_infer_cv <- function(object, ...) {
  UseMethod("int_conformal_infer_cv")
}


#' @export
#' @rdname int_conformal_infer_cv
int_conformal_infer_cv.default <- function(object, ...) {
  # Fallback method: reached for any object without a dedicated method, so
  # always signal an error.
  msg <- "No known 'int_conformal_infer_cv' methods for this type of object."
  rlang::abort(msg)
}

#' @export
#' @rdname int_conformal_infer_cv
int_conformal_infer_cv.resample_results <- function(object, ...) {
  # Validate the resampling scheme and the presence of the required columns
  # before touching the results.
  check_resampling(object)
  check_extras(object)

  fitted_wflows <- .get_fitted_workflows(object)

  # Absolute holdout residuals: |prediction - observed outcome|.
  outcome_name <- tune::.get_tune_outcome_names(object)
  holdout <- tune::collect_predictions(object, summarize = TRUE)
  holdout <- dplyr::mutate(
    holdout,
    .abs_resid = abs(.pred - !!rlang::sym(outcome_name))
  )

  new_infer_cv(fitted_wflows, holdout$.abs_resid)
}

#' @export
#' @rdname int_conformal_infer_cv
int_conformal_infer_cv.tune_results <- function(object, parameters, ...) {
  # Validate the resampling scheme, the parameter selection, and the presence
  # of the required columns before touching the results.
  check_resampling(object)
  check_parameters(object, parameters)
  check_extras(object)

  # Keep only the fitted workflows that match the chosen tuning parameters.
  fitted_wflows <- .get_fitted_workflows(object, parameters)
  outcome_name <- tune::.get_tune_outcome_names(object)

  # Absolute holdout residuals for the selected configuration.
  holdout <- tune::collect_predictions(
    object,
    parameters = parameters,
    summarize = TRUE
  )
  holdout <- dplyr::mutate(
    holdout,
    .abs_resid = abs(.pred - !!rlang::sym(outcome_name))
  )

  new_infer_cv(fitted_wflows, holdout$.abs_resid)
}

#' @export
#' @rdname predict.int_conformal_infer
predict.int_conformal_infer_cv <- function(object, new_data, level = 0.95, ...) {
  # Predict `new_data` with every stored model, tagging each prediction with
  # its row index so that the per-model results can be averaged by row.
  per_model <- purrr::map_dfr(
    object$models,
    function(wflow) parsnip::add_rowindex(predict(wflow, new_data))
  )
  by_row <- dplyr::group_by(per_model, .row)
  averaged <- dplyr::summarize(
    by_row,
    estimate = mean(.pred, na.rm = TRUE),
    .groups = "drop"
  )
  center <- averaged[["estimate"]]

  # One lower and one upper bound per averaged prediction.
  row_ids <- as.list(seq_along(center))
  lower <- purrr::map_dbl(
    row_ids,
    function(i) .get_lower_cv_bound(center[i], object$abs_resid, level = level)
  )
  upper <- purrr::map_dbl(
    row_ids,
    function(i) .get_upper_cv_bound(center[i], object$abs_resid, level = level)
  )

  dplyr::tibble(.pred_lower = lower, .pred_upper = upper)
}

#' @export
print.int_conformal_infer_cv <- function(x, ...) {
  # Print a short summary; the preprocessor and model descriptions are taken
  # from the first stored workflow.
  cat("Conformal inference via CV+\n")

  first_wflow <- x$models[[1]]
  cat("preprocessor:", .get_pre_type(first_wflow), "\n")
  cat("model:", .get_fit_type(first_wflow), "\n")

  n_models <- format(length(x$models), big.mark = ",")
  cat("number of models:", n_models, "\n")
  n_resid <- format(length(x$abs_resid), big.mark = ",")
  cat("training set size:", n_resid, "\n\n")

  cat("Use `predict(object, new_data, level)` to compute prediction intervals\n")
  invisible(x)
}

# ------------------------------------------------------------------------------
# helpers

# Constructor for `int_conformal_infer_cv` objects.
#
# @param models A list; elements that are not trained workflows are dropped.
# @param resid A numeric vector of absolute residuals; missing values are
#   dropped.
# @return A list with elements `models` and `abs_resid` and class
#   `"int_conformal_infer_cv"`. Aborts when `resid` is not numeric or entirely
#   missing, or when `models` is not a list or holds no trained workflows.
new_infer_cv <- function(models, resid) {
  if (!is.numeric(resid)) {
    rlang::abort("Absolute residuals should be numeric")
  }
  missing_resid <- is.na(resid)
  if (all(missing_resid)) {
    rlang::abort("All of the absolute residuals are missing.")
  }

  if (!is.list(models)) {
    rlang::abort("The model list should be... a list")
  }
  trained <- purrr::map_lgl(models, workflows::is_trained_workflow)
  if (!any(trained)) {
    rlang::abort(".extracts does not contain fitted workflows")
  }
  # Only subset when something actually has to be removed.
  if (!all(trained)) {
    models <- models[trained]
  }

  structure(
    list(
      models = models,
      abs_resid = resid[!missing_resid]
    ),
    class = "int_conformal_infer_cv"
  )
}

# Lower interval bound for one prediction: the `1 - level` quantile of the
# prediction minus each absolute residual, returned as an unnamed number.
.get_lower_cv_bound <- function(pred, resid, level = 0.95) {
  shifted <- pred - resid
  stats::quantile(shifted, probs = 1 - level, names = FALSE)
}

# Upper interval bound for one prediction: the `level` quantile of the
# prediction plus each absolute residual, returned as an unnamed number.
.get_upper_cv_bound <- function(pred, resid, level = 0.95) {
  shifted <- pred + resid
  stats::quantile(shifted, probs = level, names = FALSE)
}

# First class of the workflow's preprocessor, ignoring any class whose name
# contains "butchered".
.get_pre_type <- function(x) {
  pre_classes <- class(workflows::extract_preprocessor(x))
  is_butchered <- grepl("butchered", pre_classes)
  pre_classes[!is_butchered][1]
}

# One-line label for the workflow's fitted parsnip model: the first class of
# its model specification followed by the engine name.
.get_fit_type <- function(x) {
  spec <- workflows::extract_fit_parsnip(x)$spec
  paste0(class(spec)[1], " (engine = ", spec$engine, ")")
}

# Pull the extracted objects out of the `.extracts` column of `x`.
#
# @param x An object with an `.extracts` list column.
# @param prm Optional data frame of parameter values. When supplied, the
#   unnested extracts are inner-joined to it (by all of its columns) so that
#   only matching rows are kept.
# @return A list of the extracted objects.
.get_fitted_workflows <- function(x, prm = NULL) {
  if (!is.null(prm)) {
    extracted <- dplyr::select(x, .extracts)
    extracted <- tidyr::unnest(extracted, .extracts)
    matched <- dplyr::inner_join(extracted, prm, by = names(prm))
    return(purrr::pluck(matched, ".extracts"))
  }

  # No parameter filter: take the first extract of every resample.
  purrr::map(x$.extracts, function(.x) .x$.extracts[[1]])
}

# ------------------------------------------------------------------------------
# checks

# Warn when the resampling scheme is not plain (unrepeated) V-fold
# cross-validation.
#
# Other schemes produce a warning rather than an error: the computations still
# run, but interval coverage is unknown.
#
# @param x An object carrying an `rset_info` attribute. Its `att` element is
#   read for the resampling `class` and number of `repeats`, and its `label`
#   element is used in the warning message.
# @return `NULL`, invisibly. Called for its warnings.
check_resampling <- function(x) {
  rs <- attr(x, "rset_info")

  # `identical()` always yields a single TRUE/FALSE. A plain `!=` comparison
  # yields a vector condition whenever the recorded class has more than one
  # element, which makes `if ()` fail instead of producing the warning.
  is_vfold <- identical(as.character(rs$att$class), "vfold_cv")

  if (!is_vfold) {
    msg <- paste0(
      "The data were resampled using ", rs$label,
      ". This method was developed for V-fold cross-validation. Interval ",
      "coverage is unknown for your resampling method."
    )
    rlang::warn(msg)
  } else if (isTRUE(rs$att$repeats > 1)) {
    # `isTRUE()` keeps a missing `repeats` value from breaking the condition.
    msg <- paste0(
      rs$att$repeats, " repeats were used. This method was developed for ",
      "basic V-fold cross-validation. Interval coverage is unknown for multiple ",
      "repeats."
    )
    rlang::warn(msg)
  }
  invisible(NULL)
}

# Check that `param` picks out exactly one tuning configuration in `x`.
#
# @param x A tuning result accepted by `tune::collect_metrics()`.
# @param param A data frame of tuning parameter values; all of its columns are
#   used as join keys against the distinct configurations in `x`.
# @param call The environment reported as the source of the error.
# @return `NULL`, invisibly. Aborts unless exactly one configuration matches.
check_parameters <- function(x, param, call = rlang::caller_env()) {
  prms <- tune::.get_tune_parameter_names(x)
  mtr <- tune::collect_metrics(x) %>%
    dplyr::distinct(.config, !!!rlang::syms(prms))
  remain <- dplyr::inner_join(mtr, param, by = names(param))

  # Reject zero matches as well as several: with no matching configuration
  # there are no models or residuals to work with, and failing here points at
  # the `parameters` argument rather than at a downstream symptom.
  n_selected <- nrow(remain)
  if (n_selected != 1) {
    msg <-
      paste0(
        "The `parameters` argument selected ", n_selected, " submodels. Only ",
        "1 should be selected."
      )
    rlang::abort(msg, call = call)
  }
  invisible(NULL)
}

# Check that `x` has the columns needed for CV+ intervals.
#
# @param x An object with names; it must include `.extracts` (fitted workflow
#   objects) and `.predictions` (holdout predictions).
# @param call The environment reported as the source of the error.
# @return `NULL`, invisibly. Aborts when either column is absent.
check_extras <- function(x, call = rlang::caller_env()) {
  if (!(".extracts" %in% names(x))) {
    msg <-
      paste0(
        "The output must contain a column called '.extracts' that contains the ",
        "fitted workflow objects. See the documentation on the 'extract' ",
        "argument of the control function (e.g., `control_grid()` or ",
        "`control_resamples()`, etc.)."
      )
    # Pass `call` so the error is attributed to the caller of this check.
    rlang::abort(msg, call = call)
  }
  if (!(".predictions" %in% names(x))) {
    msg <-
      paste0(
        "The output must contain a column called '.predictions' that contains the ",
        "holdout predictions. See the documentation on the 'save_pred' ",
        "argument of the control function (e.g., `control_grid()` or ",
        "`control_resamples()`, etc.)."
      )
    # The argument name must be spelled `call`; a misspelled name is stored as
    # an extra condition field and the call context is lost.
    rlang::abort(msg, call = call)
  }
  invisible(NULL)
}
2 changes: 1 addition & 1 deletion R/probably-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ NULL
utils::globalVariables(c(
".bin", ".is_val", "event_rate", "events", "lower",
"predicted_midpoint", "total", "upper", ".config",
".adj_estimate", ".rounded", '.pred', '.bound', 'pred_val'
".adj_estimate", ".rounded", '.pred', '.bound', 'pred_val', '.extracts'
))
12 changes: 12 additions & 0 deletions R/reexports.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,15 @@ generics::augment
#' @export
generics::required_pkgs

# from tune
# nocov start

# Report whether the code appears to be running under a CRAN-style check:
# `NOT_CRAN` is not "true" and `_R_CHECK_PACKAGE_NAME_` is set to a non-empty
# value.
is_cran_check <- function() {
  if (identical(Sys.getenv("NOT_CRAN"), "true")) {
    return(FALSE)
  }
  nzchar(Sys.getenv("_R_CHECK_PACKAGE_NAME_", ""))
}

#nocov end
Loading

0 comments on commit 7c1b36f

Please sign in to comment.