Adds validation functions #63

Merged
merged 44 commits into from
Jan 5, 2023
Commits
c86346a
Initial sketch of grouped support for estimates
edgararuiz Dec 16, 2022
0310f01
Creates custom group splitting function to create a list of data fram…
edgararuiz Dec 16, 2022
218c16a
Finalizes estimate object
edgararuiz Dec 16, 2022
fe49e27
Adds support for new cal_object structure in cal_apply()
edgararuiz Dec 16, 2022
ee003d6
Adds support to modify the prediction
edgararuiz Dec 16, 2022
bf773b0
Adds tune_result method to logistic estimator
edgararuiz Dec 16, 2022
2e4ec67
Adds method for tune_results for cal_apply()
edgararuiz Dec 16, 2022
57dac59
Corrects the target prediction column name
edgararuiz Dec 16, 2022
fb22f32
Syncs with main, and updates tests
edgararuiz Dec 22, 2022
9985061
Adds support for dplyr groups and tune_results object to isotonic est…
edgararuiz Dec 22, 2022
d59e8c8
Adds support for tune_results object to cal_apply() w/o breaking curr…
edgararuiz Dec 22, 2022
4fb4e58
Adds support for tune_results on bootstrapped Isotonic
edgararuiz Dec 23, 2022
361d6da
Simplifies isoreg implementation functions, and standarizes their names
edgararuiz Dec 23, 2022
3bd0028
Consolidates bootstrapped estimation isotonic functions
edgararuiz Dec 23, 2022
3c49785
Flips order of levels to be first dplyr group, and then the estimates…
edgararuiz Dec 23, 2022
732231c
cal_apply() now works with isotonic boot and non-boot, for df's and t…
edgararuiz Dec 23, 2022
27ec17f
Passes checks
edgararuiz Dec 26, 2022
72c6112
Adds grouping support for beta, splits estimate methods into their ow…
edgararuiz Dec 26, 2022
6e90545
Adds support for tune_results to beta
edgararuiz Dec 26, 2022
29a6fc3
Fixes as_name() for tune_results, updates tests, moves count for tune…
edgararuiz Dec 26, 2022
1c2fdc0
Adds tune_results tests for cal_apply()
edgararuiz Dec 27, 2022
3f92932
Updates help, ver bump, updates news, and some styler updates
edgararuiz Dec 27, 2022
041411e
- Removes threshold from cal_apply()
edgararuiz Dec 28, 2022
8a19417
Adds exception for when using parameters on data.frames
edgararuiz Dec 28, 2022
bbe83b0
Creates centralized function to process validations, adds validation …
edgararuiz Dec 30, 2022
fa2ab8b
Adds custom pillar type_sum for calibrations
edgararuiz Dec 30, 2022
2147d49
Adds summarization argument
edgararuiz Dec 30, 2022
977f5d7
Adds validator for isotonic
edgararuiz Dec 31, 2022
b1c7e81
Adds beta, fixes docs, passes checks
edgararuiz Jan 1, 2023
7d5eb0e
Adds tests
edgararuiz Jan 1, 2023
66b4854
Addresses conflicts from main
edgararuiz Jan 2, 2023
255fca4
Adds examples, ver bump, NEWS item
edgararuiz Jan 2, 2023
cc7b49f
Adds yardstick dev dep
edgararuiz Jan 2, 2023
1c81c64
Update R/cal-validate.R
edgararuiz Jan 3, 2023
a9e427d
Update R/cal-validate.R
edgararuiz Jan 3, 2023
02d8143
Adds isotonic boot validation, separates documentation for each valid…
edgararuiz Jan 3, 2023
bd06868
Fixes param references
edgararuiz Jan 3, 2023
235d6a7
Adds direction to stats tibble output
edgararuiz Jan 3, 2023
298ab6f
Makes summarized results tidy
edgararuiz Jan 3, 2023
204b139
Adds save_details argument, updates tests
edgararuiz Jan 3, 2023
efe960d
Adds cal_validate_summarize()
edgararuiz Jan 3, 2023
40ccf1f
Centralizes pred_class update, adds step to validation column
edgararuiz Jan 4, 2023
7e3e48c
Adds note about tune_results
edgararuiz Jan 4, 2023
d8d1283
update docs
topepo Jan 5, 2023
7 changes: 5 additions & 2 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: probably
Title: Tools for Post-Processing Class Probability Estimates
Version: 0.1.0.9004
Version: 0.1.0.9005
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
@@ -23,11 +23,12 @@ Imports:
rlang (>= 1.0.4),
tidyselect (>= 1.1.2),
vctrs (>= 0.4.1),
yardstick (>= 1.0.0),
yardstick (> 1.1.0),
betacal,
ggplot2,
butcher,
tibble,
pillar,
withr,
purrr,
tune,
@@ -52,3 +53,5 @@ Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
Remotes:
tidymodels/yardstick@5f1b9ce
13 changes: 13 additions & 0 deletions NAMESPACE
@@ -27,6 +27,11 @@ S3method(cal_plot_logistic,data.frame)
S3method(cal_plot_logistic,tune_results)
S3method(cal_plot_windowed,data.frame)
S3method(cal_plot_windowed,tune_results)
S3method(cal_validate_beta,rset)
S3method(cal_validate_isotonic,rset)
S3method(cal_validate_isotonic_boot,rset)
S3method(cal_validate_logistic,rset)
S3method(cal_validate_summarize,cal_rset)
S3method(format,class_pred)
S3method(is_equivocal,class_pred)
S3method(is_equivocal,default)
@@ -41,6 +46,7 @@ S3method(print,cal_estimate_isotonic)
S3method(reportable_rate,class_pred)
S3method(reportable_rate,default)
S3method(threshold_perf,data.frame)
S3method(type_sum,cal_binary)
S3method(vec_cast,character.class_pred)
S3method(vec_cast,class_pred.character)
S3method(vec_cast,class_pred.class_pred)
@@ -74,6 +80,11 @@ export(cal_estimate_logistic)
export(cal_plot_breaks)
export(cal_plot_logistic)
export(cal_plot_windowed)
export(cal_validate_beta)
export(cal_validate_isotonic)
export(cal_validate_isotonic_boot)
export(cal_validate_logistic)
export(cal_validate_summarize)
export(class_pred)
export(is_class_pred)
export(is_equivocal)
@@ -88,6 +99,7 @@ import(vctrs)
importFrom(generics,as.factor)
importFrom(generics,as.ordered)
importFrom(magrittr,"%>%")
importFrom(pillar,type_sum)
importFrom(purrr,map)
importFrom(stats,as.stepfun)
importFrom(stats,binomial)
@@ -96,6 +108,7 @@ importFrom(stats,isoreg)
importFrom(stats,median)
importFrom(stats,predict)
importFrom(stats,qnorm)
importFrom(utils,head)
importFrom(yardstick,j_index)
importFrom(yardstick,sens)
importFrom(yardstick,spec)
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# probably (development version)

* Adds calibration validation functions for Logistic, Isotonic, and Beta. These functions take in a re-sampled data set, run the calibration on the testing data, and then apply the calibration to the assessment set. They return the average of the requested metrics, defaulting to the Brier score.

* Adds `cal_apply()` function. It uses the output of a calibration function, and applies it to a data frame, or a tune_results object.

* Adds 5 model calibration remediation methods: Logistic, Logistic Spline, Isotonic, Isotonic Bootstrapped, and Beta. They currently support data frames and binary models only.
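
The validation workflow described in the first NEWS item can be sketched in base R. This is an illustration only, not the package's code: it assumes simple 5-fold splits built with `sample()` instead of an `rset` object, a logistic calibrator fitted with `glm()`, and a hand-written `brier()` helper. The names `dat`, `raw_prob`, `fit_set`, and `held_out` are hypothetical.

```r
# Illustrative base-R sketch; all object names are hypothetical and the
# package itself operates on resampled `rset` objects instead.
set.seed(1)
n <- 200
truth <- rbinom(n, 1, 0.5)
# Simulated, uncalibrated probabilities for the event class
raw_prob <- plogis(3 * (truth - 0.5) + rnorm(n))
dat <- data.frame(truth = truth, raw_prob = raw_prob)

# Brier score: mean squared difference between outcome and probability
brier <- function(truth, prob) mean((truth - prob)^2)

# Assign each row to one of 5 folds
folds <- sample(rep(1:5, length.out = n))

fold_scores <- vapply(1:5, function(k) {
  fit_set <- dat[folds != k, ]
  held_out <- dat[folds == k, ]
  # Logistic calibration: regress the outcome on the raw probability
  cal_fit <- glm(truth ~ raw_prob, data = fit_set, family = binomial())
  cal_prob <- predict(cal_fit, newdata = held_out, type = "response")
  c(
    uncalibrated = brier(held_out$truth, held_out$raw_prob),
    calibrated = brier(held_out$truth, cal_prob)
  )
}, numeric(2))

# Average each metric across the folds
rowMeans(fold_scores)
```

Each column of `fold_scores` holds the scores for one held-out fold, so `rowMeans()` gives the per-metric averages that a summary step would report.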
58 changes: 29 additions & 29 deletions R/cal-apply.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,18 @@ cal_apply.data.frame <- function(.data,
) {
stop_null_parameters(parameters)
if (object$type == "binary") {
cal_add_adjust(
data_adjust <- cal_add_adjust(
object = object,
.data = .data,
pred_class = {{ pred_class }}
)

cal_update_prediction(
.data = data_adjust,
object = object,
pred_class = {{ pred_class }}
)

} else {
stop_multiclass()
}
@@ -71,7 +78,7 @@ cal_apply.tune_results <- function(.data,
pred_class <- rlang::parse_expr(".pred_class")
}

pred_classs <- tune::collect_predictions(
predictions <- tune::collect_predictions(
x = .data,
summarize = TRUE,
parameters = parameters,
@@ -80,7 +87,7 @@ cal_apply.tune_results <- function(.data,

cal_add_adjust(
object = object,
.data = pred_classs,
.data = predictions,
pred_class = !!pred_class
)
} else {
@@ -115,45 +122,21 @@ cal_add_adjust.cal_estimate_logistic <- function(object,
.data,
pred_class = NULL,
...) {
pred_class <- enquo(pred_class)

new_data <- cal_add_predict_impl(
cal_add_predict_impl(
object = object,
.data = .data
)

if (!quo_is_null(pred_class)) {
if (object$type == "binary") {
level1_gt <- new_data[[object$levels[[1]]]] > new_data[[object$levels[[2]]]]
new_data[level1_gt, as_name(pred_class)] <- names(object$levels[1])
new_data[!level1_gt, as_name(pred_class)] <- names(object$levels[2])
}
}

new_data
}

cal_add_adjust.cal_estimate_logistic_spline <- function(object,
.data,
pred_class = NULL,
...) {
pred_class <- enquo(pred_class)

new_data <- cal_add_predict_impl(
cal_add_predict_impl(
object = object,
.data = .data
)

if (!quo_is_null(pred_class)) {
if (object$type == "binary") {
pred_name <- as_name(pred_class)
level1_gt <- new_data[[object$levels[[1]]]] > new_data[[object$levels[[2]]]]
new_data[level1_gt, pred_name] <- names(object$levels[1])
new_data[!level1_gt, pred_name] <- names(object$levels[2])
}
}

new_data
}

cal_add_adjust.cal_estimate_isotonic_boot <- function(object,
@@ -267,3 +250,20 @@ cal_get_intervals <- function(estimates_table, .data, estimate) {
find_interval[find_interval == 0] <- 1
y[find_interval]
}

cal_update_prediction <- function(.data, object, pred_class) {
res <- .data
if (!is.null(pred_class)) {
if (object$type == "binary") {
pred_name <- as_name(pred_class)
# Drop any existing class column from the result before rebuilding it
if (pred_name %in% colnames(res)) {
res[, pred_name] <- NULL
}
level1_gt <- res[[object$levels[[1]]]] > res[[object$levels[[2]]]]
res[level1_gt, pred_name] <- names(object$levels[1])
res[!level1_gt, pred_name] <- names(object$levels[2])
res[, pred_name] <- as.factor(res[, pred_name][[1]])
}
}
res
}
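
The class-update step centralized in `cal_update_prediction()` can be illustrated with a stand-alone base-R sketch. It is not the package's implementation: `update_pred_class()` is a hypothetical helper, and `levels` is assumed to be a named character vector whose names are the class labels and whose values are the probability column names.

```r
# Hypothetical stand-alone helper, written for illustration only.
update_pred_class <- function(data, levels, pred_name = ".pred_class") {
  # Remove any stale class column so it can be rebuilt from scratch
  data[[pred_name]] <- NULL
  # TRUE where the first level's probability is strictly larger
  level1_gt <- data[[levels[[1]]]] > data[[levels[[2]]]]
  labels <- names(levels)
  # Ties fall to the second level because the comparison is strict
  data[[pred_name]] <- factor(
    ifelse(level1_gt, labels[1], labels[2]),
    levels = labels
  )
  data
}

preds <- data.frame(
  .pred_yes = c(0.8, 0.3, 0.5),
  .pred_no = c(0.2, 0.7, 0.5),
  .pred_class = factor(c("no", "no", "yes"), levels = c("yes", "no"))
)

out <- update_pred_class(preds, levels = c(yes = ".pred_yes", no = ".pred_no"))
out$.pred_class
```

Building the factor in a single step avoids assigning character values into an existing factor column, which is one reason the old column is removed first.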