Skip to content

Commit

Permalink
Syncs with main, and updates tests
Browse files Browse the repository at this point in the history
Merge branch 'main' into cal-apply

# Conflicts:
#	NAMESPACE
#	R/cal-apply.R
  • Loading branch information
edgararuiz committed Dec 22, 2022
2 parents 57dac59 + bd0b262 commit fb22f32
Show file tree
Hide file tree
Showing 12 changed files with 245 additions and 13 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: probably
Title: Tools for Post-Processing Class Probability Estimates
Version: 0.1.0.9002
Version: 0.1.0.9003
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand All @@ -24,6 +24,7 @@ Imports:
tidyselect (>= 1.1.2),
vctrs (>= 0.4.1),
yardstick (>= 1.0.0),
betacal,
ggplot2,
butcher,
tibble,
Expand All @@ -50,4 +51,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.2
RoxygenNote: 7.2.3
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ S3method(as_class_pred,factor)
S3method(cal_apply,cal_object)
S3method(cal_apply,data.frame)
S3method(cal_apply,tune_results)
S3method(cal_estimate_beta,data.frame)
S3method(cal_estimate_isotonic,data.frame)
S3method(cal_estimate_isotonic_boot,data.frame)
S3method(cal_estimate_logistic,data.frame)
Expand Down Expand Up @@ -63,6 +64,7 @@ export(as.factor)
export(as.ordered)
export(as_class_pred)
export(cal_apply)
export(cal_estimate_beta)
export(cal_estimate_isotonic)
export(cal_estimate_isotonic_boot)
export(cal_estimate_logistic)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

* Adds `cal_apply()` function. It uses the output of a calibration function, and applies it to a data frame.

* Adds 4 model calibration remediation methods: Logistic, Logistic Spline, Isotonic, and Isotonic Bootstrapped. They currently support data.frame only, and binary models.
* Adds 5 model calibration remediation methods: Logistic, Logistic Spline, Isotonic, Isotonic Bootstrapped, and Beta. They currently support data frames and binary models only.

* Adds model calibration diagnostic functions. They implement three methods: binning probabilities, fitting a logistic spline model against the probabilities, and creating a running percentage of the data. There are three new plotting functions, and three table functions. They support data.frames and tune_results objects.

Expand Down
28 changes: 25 additions & 3 deletions R/cal-apply.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ cal_apply.tune_results<- function(.data, object, prediction = NULL, threshold =

# `cal_apply()` method reached when a calibration object is passed as the
# first argument. The data must come first and the calibration object second,
# so this method never returns a value: it always signals an error.
#
# When `object` is a data frame the two arguments were most likely swapped,
# and the error message says so explicitly.
#' @export
cal_apply.cal_object <- function(.data,
                                 object,
                                 prediction = NULL,
                                 threshold = NULL,
                                 ...) {
  # The specific hint must be checked first; an unconditional abort placed
  # ahead of it would make this branch unreachable.
  if (inherits(object, "data.frame")) {
    rlang::abort(paste0(
      "`cal_apply()` expects the data as the first argument,",
      " and the object as the second argument. Please reverse",
      " the order of the arguments and try again."
    ))
  }

  rlang::abort(paste0(
    "`cal_apply()` expects the data as the first argument,",
    " and the object as the second argument."
  ))
}

# ------------------------------- Adjust ---------------------------------------
Expand Down Expand Up @@ -149,6 +152,25 @@ cal_add_adjust.cal_estimate_isotonic <- function(object,
)
}

# Applies a fitted Beta calibration to the probability columns of `.data`.
#
# For a binary calibration object, the column named by the first level is
# replaced with the output of `betacal::beta_predict()`, and the column named
# by the second level with its complement. Any other `object$type` returns
# `.data` untouched. `prediction`, `threshold` and `...` are accepted but not
# used here.
cal_add_adjust.cal_estimate_beta <- function(object,
                                             .data,
                                             prediction = NULL,
                                             threshold = NULL,
                                             ...) {
  # Nothing to adjust unless this is a binary calibration.
  if (object$type != "binary") {
    return(.data)
  }

  first_level <- object$levels[[1]]

  calibrated <- betacal::beta_predict(
    p = dplyr::pull(.data, !!first_level),
    calib = object$estimates
  )

  # The two class probabilities must keep summing to 1.
  .data[first_level] <- calibrated
  .data[object$levels[[2]]] <- 1 - calibrated

  .data
}

#---------------------------- Adjust implementations ---------------------------

cal_add_predict_impl <- function(object, .data) {
Expand Down
110 changes: 105 additions & 5 deletions R/cal-estimate.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
#' cal_estimate_logistic(segment_logistic, Class, c(.pred_poor, .pred_good))
#' # dplyr selector functions are also supported
#' cal_estimate_logistic(segment_logistic, Class, dplyr::starts_with(".pred_"))
#' @details
#' This function uses existing modeling functions from other packages to create
#' the calibration:
#' - `stats::glm()` is used when `smooth` is set to `FALSE`
#' - `mgcv::gam()` is used when `smooth` is set to `TRUE`
#' @export
cal_estimate_logistic <- function(.data,
truth = NULL,
Expand Down Expand Up @@ -77,6 +82,12 @@ cal_estimate_logistic.tune_results<- function(.data,
#----------------------------- >> Isotonic -------------------------------------
#' Uses an Isotonic regression model to calibrate probabilities
#' @inheritParams cal_estimate_logistic
#' @details This function uses `stats::isoreg()` to obtain the calibration
#' values.
#' @references
#' Zadrozny, Bianca and Elkan, Charles. (2002). Transforming Classifier Scores
#' into Accurate Multiclass Probability Estimates. _Proceedings of the ACM SIGKDD
#' International Conference on Knowledge Discovery and Data Mining._
#' @examples
#' # It will automatically identify the probability columns
#' # if passed a model fitted with tidymodels
Expand Down Expand Up @@ -129,6 +140,9 @@ cal_estimate_isotonic.data.frame <- function(.data,
#' Uses a bootstrapped Isotonic regression model to calibrate probabilities
#' @param times Number of bootstraps.
#' @inheritParams cal_estimate_logistic
#' @details This function uses `stats::isoreg()` to obtain the calibration
#' values. It runs `isoreg()` multiple times, each time with a different
#' seed. The results are saved inside the returned `cal_object`.
#' @examples
#' # It will automatically identify the probability columns
#' # if passed a model fitted with tidymodels
Expand Down Expand Up @@ -180,6 +194,95 @@ cal_estimate_isotonic_boot.data.frame <- function(.data,
res
}

#------------------------------- >> Beta --------------------------------------
#' Uses a Beta calibration model to calculate new probabilities
#' @param shape_params Number of shape parameters to use. Accepted values are
#' 1 and 2. Defaults to 2.
#' @param location_params Number of location parameters to use. Accepted values
#' are 1 and 0. Defaults to 1.
#' @inheritParams cal_estimate_logistic
#' @details This function uses the `betacal::beta_calibration()` function, and
#' retains the resulting model.
#' @references Meelis Kull, Telmo M. Silva Filho, Peter Flach "Beyond sigmoids:
#' How to obtain well-calibrated probabilities from binary classifiers with beta
#' calibration," _Electronic Journal of Statistics_ 11(2), 5052-5080, (2017)
#' @examples
#' # It will automatically identify the probability columns
#' # if passed a model fitted with tidymodels
#' cal_estimate_beta(segment_logistic, Class)
#' @export
cal_estimate_beta <- function(.data,
truth = NULL,
shape_params = 2,
location_params = 1,
estimate = dplyr::starts_with(".pred_"),
...) {
# S3 generic: dispatches on the class of `.data`.
UseMethod("cal_estimate_beta")
}

# Data frame method for `cal_estimate_beta()`.
#
# Maps the `truth` column and the `estimate` columns to outcome levels, and,
# for a two-level outcome, fits `betacal::beta_calibration()` on the
# probabilities of the first level. The fitted model is stored in a binary
# calibration object carrying the extra class "cal_estimate_beta".
#
# `shape_params` must be 1 or 2 and `location_params` must be 0 or 1; any
# other value aborts. Outcomes that do not have exactly two levels are handed
# to `stop_multiclass()`.
#' @export
cal_estimate_beta.data.frame <- function(.data,
                                         truth = NULL,
                                         shape_params = 2,
                                         location_params = 1,
                                         estimate = dplyr::starts_with(".pred_"),
                                         ...) {
  truth <- enquo(truth)

  levels <- truth_estimate_map(.data, {{ truth }}, {{ estimate }})

  if (length(levels) == 2) {
    # Validate both arguments before building the `parameters` string.
    # A NULL check after the fact is not enough: `paste0(NULL, "m")` is "m",
    # so an invalid `shape_params` would slip through whenever
    # `location_params` is 1 (the default).
    if (!isTRUE(location_params %in% c(0, 1))) {
      rlang::abort("Invalid `location_params`, allowed values are 1 and 0")
    }

    if (!isTRUE(shape_params %in% c(1, 2))) {
      rlang::abort("Invalid `shape_params`, allowed values are 1 and 2")
    }

    x_factor <- dplyr::pull(.data, !!truth)
    # Logical outcome: TRUE when the observed class is the first level.
    x <- x_factor == names(levels[1])
    # Predicted probability of the first level.
    y <- dplyr::pull(.data, !!levels[[1]])

    # "a"/"ab" selects one or two shape parameters; a trailing "m" adds the
    # location parameter.
    parameters <- if (shape_params == 1) "a" else "ab"

    if (location_params == 1) {
      parameters <- paste0(parameters, "m")
    }

    # capture.output() swallows anything printed while fitting; the captured
    # text itself is not needed.
    utils::capture.output(
      beta_model <- betacal::beta_calibration(
        p = y,
        y = x,
        parameters = parameters
      )
    )

    # Strip the stored model down to what is needed later.
    beta_model$model <- butcher::butcher(beta_model$model)

    res <- as_binary_cal_object(
      estimate = beta_model,
      levels = levels,
      truth = !!truth,
      method = "Beta",
      rows = nrow(.data),
      additional_class = "cal_estimate_beta"
    )
  } else {
    stop_multiclass()
  }

  res
}

#------------------------- Estimate implementation -----------------------------
#------------------------------ >> Logistic ------------------------------------
cal_estimate_logistic_impl <- function(.data,
Expand Down Expand Up @@ -298,9 +401,7 @@ cal_isoreg_dataframe <- function(.data,
}

# cal_isoreg_boot() runs boot_iso() as many times specified by `times`.
# Each time it runs, it passes a different seed. boot_iso() then runs a
# single Isotonic model with using withr to set a new seed.

# Each time it runs, it passes a different seed.
cal_isoreg_boot <- function(.data,
truth,
estimate,
Expand All @@ -319,8 +420,7 @@ cal_isoreg_boot <- function(.data,

boot_iso <- function(.data, truth, estimate, seed) {
withr::with_seed(
seed,
{
seed, {
cal_isoreg_dataframe(
.data = .data,
truth = {{ truth }},
Expand Down
53 changes: 53 additions & 0 deletions man/cal_estimate_beta.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions man/cal_estimate_isotonic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions man/cal_estimate_isotonic_boot.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions man/cal_estimate_logistic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions tests/testthat/_snaps/cal-estimate.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@
`.pred_good` ==> good
`.pred_poor` ==> poor

# Beta estimates work

Code
print(sl_beta)
Message
-- Probability Calibration
Method: Beta
Type: Binary
Train set size: 1,010
Truth variable: `Class`
Estimate variables:
`.pred_good` ==> good
`.pred_poor` ==> poor

# Non-default names used for estimate columns

Code
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-cal-apply.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ test_that("Isotonic apply work", {
expect_equal(sd(pred_good), 0.2870775, tolerance = 0.000001)
})

test_that("Beta apply work", {
  # Fit the Beta calibration on segment_logistic, then apply it back to the
  # same data and check the summary statistics of the adjusted column.
  calibrated <- cal_apply(
    segment_logistic,
    cal_estimate_beta(segment_logistic, Class)
  )

  good_probs <- calibrated[[".pred_good"]]
  expect_equal(mean(good_probs), 0.3425743, tolerance = 1e-6)
  expect_equal(sd(good_probs), 0.294565, tolerance = 1e-6)
})

test_that("Isotonic Bootstrapped apply work", {
sl_boot <- cal_estimate_isotonic_boot(segment_logistic, Class)
ap_boot <- cal_apply(segment_logistic, sl_boot)
Expand Down
Loading

0 comments on commit fb22f32

Please sign in to comment.