Skip to content

Commit

Permalink
Adds validation functions (#63)
Browse files Browse the repository at this point in the history
* Initial sketch of grouped support for estimates

* Creates a custom group-splitting function that builds a list of data frame sections, which are then processed by the model iteratively

* Finalizes estimate object

* Adds support for new cal_object structure in cal_apply()

* Adds support to modify the prediction

* Adds tune_result method to logistic estimator

* Adds method for tune_results for cal_apply()

* Corrects the target prediction column name

* Adds support for dplyr groups and tune_results object to isotonic estimator

* Adds support for tune_results object to cal_apply() w/o breaking current functionality

* Adds support for tune_results on bootstrapped Isotonic

* Simplifies isoreg implementation functions, and standardizes their names

* Consolidates bootstrapped estimation isotonic functions

* Flips order of levels to be first dplyr group, and then the estimates for boot isotonic

* cal_apply() now works with isotonic boot and non-boot, for df's and tune results objects

* Passes checks

* Adds grouping support for beta, splits estimate methods into their own script

* Adds support for tune_results to beta

* Fixes as_name() for tune_results, updates tests, moves count for tune_results to helper function

* Adds tune_results tests for cal_apply()

* Updates help, ver bump, updates news, and some styler updates

* - Removes threshold from cal_apply()
- Adds parameters argument, and passes it to collect_predictions()
- Renames `predictions` argument to `pred_class`

* Adds exception for when using parameters on data.frames

* Creates centralized function to process validations, adds validation for logistic models

* Adds custom pillar type_sum for calibrations

* Adds summarization argument

* Adds validator for isotonic

* Adds beta, fixes docs, passes checks

* Adds tests

* Adds examples, ver bump, NEWS item

* Adds yardstick dev dep

* Update R/cal-validate.R

Co-authored-by: Max Kuhn <[email protected]>

* Update R/cal-validate.R

Co-authored-by: Max Kuhn <[email protected]>

* Adds isotonic boot validation, separates documentation for each validation method, adds rdname to rset methods

* Fixes param references

* Adds direction to stats tibble output

* Makes summarized results tidy

* Adds save_details argument, updates tests

* Adds cal_validate_summarize()

* Centralizes pred_class update, adds step to validation column

* Adds note about tune_results

* update docs

Co-authored-by: Max Kuhn <[email protected]>
  • Loading branch information
edgararuiz and topepo authored Jan 5, 2023
1 parent 19907f7 commit 5b0b59d
Show file tree
Hide file tree
Showing 13 changed files with 821 additions and 31 deletions.
7 changes: 5 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: probably
Title: Tools for Post-Processing Class Probability Estimates
Version: 0.1.0.9004
Version: 0.1.0.9005
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand All @@ -23,11 +23,12 @@ Imports:
rlang (>= 1.0.4),
tidyselect (>= 1.1.2),
vctrs (>= 0.4.1),
yardstick (>= 1.0.0),
yardstick (> 1.1.0),
betacal,
ggplot2,
butcher,
tibble,
pillar,
withr,
purrr,
tune,
Expand All @@ -52,3 +53,5 @@ Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
Remotes:
tidymodels/yardstick@5f1b9ce
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ S3method(cal_plot_logistic,data.frame)
S3method(cal_plot_logistic,tune_results)
S3method(cal_plot_windowed,data.frame)
S3method(cal_plot_windowed,tune_results)
S3method(cal_validate_beta,rset)
S3method(cal_validate_isotonic,rset)
S3method(cal_validate_isotonic_boot,rset)
S3method(cal_validate_logistic,rset)
S3method(cal_validate_summarize,cal_rset)
S3method(format,class_pred)
S3method(is_equivocal,class_pred)
S3method(is_equivocal,default)
Expand All @@ -41,6 +46,7 @@ S3method(print,cal_estimate_isotonic)
S3method(reportable_rate,class_pred)
S3method(reportable_rate,default)
S3method(threshold_perf,data.frame)
S3method(type_sum,cal_binary)
S3method(vec_cast,character.class_pred)
S3method(vec_cast,class_pred.character)
S3method(vec_cast,class_pred.class_pred)
Expand Down Expand Up @@ -74,6 +80,11 @@ export(cal_estimate_logistic)
export(cal_plot_breaks)
export(cal_plot_logistic)
export(cal_plot_windowed)
export(cal_validate_beta)
export(cal_validate_isotonic)
export(cal_validate_isotonic_boot)
export(cal_validate_logistic)
export(cal_validate_summarize)
export(class_pred)
export(is_class_pred)
export(is_equivocal)
Expand All @@ -88,6 +99,7 @@ import(vctrs)
importFrom(generics,as.factor)
importFrom(generics,as.ordered)
importFrom(magrittr,"%>%")
importFrom(pillar,type_sum)
importFrom(purrr,map)
importFrom(stats,as.stepfun)
importFrom(stats,binomial)
Expand All @@ -96,6 +108,7 @@ importFrom(stats,isoreg)
importFrom(stats,median)
importFrom(stats,predict)
importFrom(stats,qnorm)
importFrom(utils,head)
importFrom(yardstick,j_index)
importFrom(yardstick,sens)
importFrom(yardstick,spec)
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# probably (development version)

* Adds calibration validation functions for Logistic, Isotonic, and Beta. These functions take in a re-sampled data set, run the calibration on the testing data, and then apply the calibration to the assessment set. They return the average of the requested metrics, defaulting to the Brier score.

* Adds `cal_apply()` function. It uses the output of a calibration function, and applies it to a data frame, or a tune_results object.

* Adds 5 model calibration remediation methods: Logistic, Logistic Spline, Isotonic, Isotonic Bootstrapped, and Beta. They currently support data frames, and binary models only.
Expand Down
58 changes: 29 additions & 29 deletions R/cal-apply.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,18 @@ cal_apply.data.frame <- function(.data,
) {
stop_null_parameters(parameters)
if (object$type == "binary") {
cal_add_adjust(
data_adjust <- cal_add_adjust(
object = object,
.data = .data,
pred_class = {{ pred_class }}
)

cal_update_prediction(
.data = data_adjust,
object = object,
pred_class = {{ pred_class }}
)

} else {
stop_multiclass()
}
Expand Down Expand Up @@ -71,7 +78,7 @@ cal_apply.tune_results <- function(.data,
pred_class <- rlang::parse_expr(".pred_class")
}

pred_classs <- tune::collect_predictions(
predictions <- tune::collect_predictions(
x = .data,
summarize = TRUE,
parameters = parameters,
Expand All @@ -80,7 +87,7 @@ cal_apply.tune_results <- function(.data,

cal_add_adjust(
object = object,
.data = pred_classs,
.data = predictions,
pred_class = !!pred_class
)
} else {
Expand Down Expand Up @@ -115,45 +122,21 @@ cal_add_adjust.cal_estimate_logistic <- function(object,
.data,
pred_class = NULL,
...) {
pred_class <- enquo(pred_class)

new_data <- cal_add_predict_impl(
cal_add_predict_impl(
object = object,
.data = .data
)

if (!quo_is_null(pred_class)) {
if (object$type == "binary") {
level1_gt <- new_data[[object$levels[[1]]]] > new_data[[object$levels[[2]]]]
new_data[level1_gt, as_name(pred_class)] <- names(object$levels[1])
new_data[!level1_gt, as_name(pred_class)] <- names(object$levels[2])
}
}

new_data
}

cal_add_adjust.cal_estimate_logistic_spline <- function(object,
.data,
pred_class = NULL,
...) {
pred_class <- enquo(pred_class)

new_data <- cal_add_predict_impl(
cal_add_predict_impl(
object = object,
.data = .data
)

if (!quo_is_null(pred_class)) {
if (object$type == "binary") {
pred_name <- as_name(pred_class)
level1_gt <- new_data[[object$levels[[1]]]] > new_data[[object$levels[[2]]]]
new_data[level1_gt, pred_name] <- names(object$levels[1])
new_data[!level1_gt, pred_name] <- names(object$levels[2])
}
}

new_data
}

cal_add_adjust.cal_estimate_isotonic_boot <- function(object,
Expand Down Expand Up @@ -267,3 +250,20 @@ cal_get_intervals <- function(estimates_table, .data, estimate) {
find_interval[find_interval == 0] <- 1
y[find_interval]
}

# Recompute the hard class prediction from the (calibrated) class
# probability columns.
#
# Arguments:
#   .data      A data frame (or tibble) that contains the two probability
#              columns named by the values of `object$levels`.
#   object     A calibration object. Only `object$type` and `object$levels`
#              are read here. `object$levels` is used as a named, length-2
#              list/vector: its *values* are probability column names in
#              `.data`, its *names* are the class labels to write out.
#   pred_class The prediction column to (re)create, in any form that
#              `as_name()` accepts, or `NULL` to leave `.data` untouched.
#
# Returns:
#   `.data` with the `pred_class` column replaced (or added, if absent) by a
#   factor holding the first class label where the first probability is
#   strictly greater than the second, and the second label otherwise. When
#   `pred_class` is `NULL` or `object$type` is not "binary", `.data` is
#   returned unchanged.
cal_update_prediction <- function(.data, object, pred_class) {
  # Nothing to recompute: no prediction column requested, or not a binary
  # calibration.
  if (is.null(pred_class) || object$type != "binary") {
    return(.data)
  }

  res <- .data
  pred_name <- as_name(pred_class)
  level_names <- names(object$levels)

  # TRUE where the first class is strictly more probable; ties go to the
  # second class.
  level1_gt <- res[[object$levels[[1]]]] > res[[object$levels[[2]]]]

  # Map TRUE -> 1st label, FALSE -> 2nd label. An NA comparison yields an NA
  # label rather than aborting on an NA row subscript.
  labels <- level_names[2L - as.integer(level1_gt)]

  # Replace the whole column in one step. The previous code tried to drop the
  # stale column from `.data` (not `res`), so the drop never took effect, and
  # then used `res[, pred_name][[1]]`, which on a base data.frame picks only
  # the first element. Whole-column assignment via `[[<-` avoids both problems
  # and does not depend on the type or levels of any pre-existing column.
  res[[pred_name]] <- as.factor(labels)

  res
}
Loading

0 comments on commit 5b0b59d

Please sign in to comment.