Skip to content

Commit

Permalink
Smooth multinomial calibration (#95)
Browse files Browse the repository at this point in the history
* initial work on #74

* update unit tests

* Updated documentation

* add namespace

* update snapshot for new pillar

* update condition

* redoc
  • Loading branch information
topepo authored Apr 28, 2023
1 parent b2156cf commit 4600920
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 62 deletions.
13 changes: 10 additions & 3 deletions R/cal-apply-multi.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ cal_apply_multi.cal_estimate_multinomial <-
#---------------------------- >> Single Predict --------------------------------

# Adds calibrated class probabilities to `.data` using a single (ungrouped)
# multinomial calibration model.
#
# The fitted model is read from `object$estimates[[1]]$estimate` and the
# estimate columns from `object$levels`. The columns of `.data` named by
# `as.character(object$levels)` are overwritten with the model's predictions,
# and the modified `.data` is returned.
apply_multi_predict <- function(object, .data) {
  fit <- object$estimates[[1]]$estimate
  est_cols <- as.character(object$levels)

  # Models inheriting from "gam" are asked for `type = "response"`; every other
  # model is asked for `type = "probs"`.
  prob_type <- if (inherits(fit, "gam")) "response" else "probs"

  preds <- predict(fit, newdata = .data, type = prob_type)

  # A predict method may hand back a plain vector instead of a matrix, which
  # would make `colnames<-` fail. Restore one row per row of `.data`.
  # NOTE(review): nnet's multinom predict method is believed to drop one-row
  # results this way -- confirm.
  if (is.null(dim(preds))) {
    preds <- matrix(preds, nrow = nrow(.data))
  }

  # Prediction columns are matched to the estimate columns by position
  colnames(preds) <- est_cols

  for (col in est_cols) {
    .data[[col]] <- unname(preds[, col])
  }
  .data
}
131 changes: 107 additions & 24 deletions R/cal-estimate-multinom.R
Original file line number Diff line number Diff line change
@@ -1,23 +1,66 @@
#' Uses a Multinomial calibration model to calculate new probabilities
#' @details
#' When `smooth = FALSE`, the [nnet::multinom()] function is used to estimate
#' the model; otherwise [mgcv::gam()] is used.
#' @inheritParams cal_estimate_logistic
#' @param smooth A single logical. When `TRUE` (the default), the calibration
#' model is fit with [mgcv::gam()]; when `FALSE`, it is fit with
#' [nnet::multinom()].
#' @examplesIf !probably:::is_cran_check() && rlang::is_installed(c("modeldata", "parsnip", "randomForest"))
#' library(modeldata)
#' library(parsnip)
#' library(dplyr)
#'
#' f <-
#'   list(
#'     ~ -0.5 + 0.6 * abs(A),
#'     ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, -2),
#'     ~ -0.6 * A + 0.50 * B - A * B
#'   )
#'
#' set.seed(1)
#' tr_dat <- sim_multinomial(500, eqn_1 = f[[1]], eqn_2 = f[[2]], eqn_3 = f[[3]])
#' cal_dat <- sim_multinomial(500, eqn_1 = f[[1]], eqn_2 = f[[2]], eqn_3 = f[[3]])
#' te_dat <- sim_multinomial(500, eqn_1 = f[[1]], eqn_2 = f[[2]], eqn_3 = f[[3]])
#'
#' set.seed(2)
#' rf_fit <-
#'   rand_forest() %>%
#'   set_mode("classification") %>%
#'   set_engine("randomForest") %>%
#'   fit(class ~ ., data = tr_dat)
#'
#' cal_pred <-
#'   predict(rf_fit, cal_dat, type = "prob") %>%
#'   bind_cols(cal_dat)
#' te_pred <-
#'   predict(rf_fit, te_dat, type = "prob") %>%
#'   bind_cols(te_dat)
#'
#' cal_plot_windowed(cal_pred, truth = class, window_size = 0.1, step_size = 0.03)
#'
#' smoothed_mn <- cal_estimate_multinomial(cal_pred, truth = class)
#'
#' new_test_pred <- cal_apply(te_pred, smoothed_mn)
#'
#' cal_plot_windowed(new_test_pred, truth = class, window_size = 0.1, step_size = 0.03)
#'
#' @export
cal_estimate_multinomial <- function(.data,
                                     truth = NULL,
                                     estimate = dplyr::starts_with(".pred_"),
                                     smooth = TRUE,
                                     parameters = NULL,
                                     ...) {
  UseMethod("cal_estimate_multinomial")
}

#' @export
#' @rdname cal_estimate_multinomial
cal_estimate_multinomial.data.frame <- function(.data,
truth = NULL,
estimate = dplyr::starts_with(".pred_"),
parameters = NULL,
...) {
cal_estimate_multinomial.data.frame <-
function(.data,
truth = NULL,
estimate = dplyr::starts_with(".pred_"),
smooth = TRUE,
parameters = NULL,
...) {
stop_null_parameters(parameters)

truth <- enquo(truth)
Expand All @@ -26,18 +69,21 @@ cal_estimate_multinomial.data.frame <- function(.data,
truth = !!truth,
estimate = {{ estimate }},
source_class = cal_class_name(.data),
smooth = smooth,
...
)
}

#' @export
#' @rdname cal_estimate_multinomial
cal_estimate_multinomial.tune_results <- function(.data,
truth = NULL,
estimate = dplyr::starts_with(".pred_"),
parameters = NULL,
...) {
tune_args <- tune_results_args(
cal_estimate_multinomial.tune_results <-
function(.data,
truth = NULL,
estimate = dplyr::starts_with(".pred_"),
smooth = TRUE,
parameters = NULL,
...) {
tune_args <- tune_results_args(
.data = .data,
truth = {{ truth }},
estimate = {{ estimate }},
Expand All @@ -53,6 +99,7 @@ cal_estimate_multinomial.tune_results <- function(.data,
truth = !!tune_args$truth,
estimate = !!tune_args$estimate,
source_class = cal_class_name(.data),
smooth = smooth,
...
)
}
Expand All @@ -64,7 +111,7 @@ required_pkgs.cal_estimate_multinomial <- function(x, ...) {
c("nnet", "probably")
}

cal_multinom_impl <- function(.data, truth, estimate, source_class, ...) {
cal_multinom_impl <- function(.data, truth, estimate, source_class, smooth, ...) {
truth <- enquo(truth)

levels <- truth_estimate_map(.data, !!truth, {{ estimate }})
Expand All @@ -79,6 +126,7 @@ cal_multinom_impl <- function(.data, truth, estimate, source_class, ...) {
.data = .data,
truth = !!truth,
levels = levels,
smooth = smooth,
...
)

Expand All @@ -95,7 +143,7 @@ cal_multinom_impl <- function(.data, truth, estimate, source_class, ...) {
}


cal_multinom_impl_grp <- function(.data, truth, levels, ...) {
cal_multinom_impl_grp <- function(.data, truth, levels, smooth, ...) {
truth <- enquo(truth)
.data %>%
split_dplyr_groups() %>%
Expand All @@ -105,6 +153,7 @@ cal_multinom_impl_grp <- function(.data, truth, levels, ...) {
.data = x$data,
truth = !!truth,
levels = levels,
smooth = smooth,
... = ...
)
list(
Expand All @@ -118,21 +167,55 @@ cal_multinom_impl_grp <- function(.data, truth, levels, ...) {
# Fits the calibration model for one (ungrouped) set of predictions.
#
# `.data` holds the truth column and the estimate columns, `truth` is the
# (unquoted) truth column, and `levels` holds the estimate columns. The last
# element of `levels` is dropped and the remaining ones are the predictors.
#
# When `smooth` is `TRUE`, a multinomial GAM is fit with `mgcv::gam()` using one
# smooth term per predictor. Otherwise `nnet::multinom()` is fit on the same
# predictors, with `...` passed on to it (`...` is not used by the GAM fit).
#
# Returns the fitted model, with the environment attached to its `terms`
# replaced via `clean_env()`.
cal_multinom_impl_single <- function(.data,
                                     truth = NULL,
                                     levels = NULL,
                                     smooth = TRUE,
                                     ...) {
  truth <- enquo(truth)

  # All estimate columns except the last one are used as predictors
  levels <- utils::head(levels, -1)

  if (smooth) {
    # Multinomial gams in mgcv need zero-based integers as the outcome
    class_col <- deparse(ensym(truth))
    .data[[class_col]] <- as.numeric(.data[[class_col]]) - 1
    max_int <- max(.data[[class_col]], na.rm = TRUE)

    # It also needs a list of formulas, one for each level, and the first one
    # requires a LHS
    smooths <- purrr::map(levels, ~ call2(.fn = "s", expr(!!.x)))
    rhs_f <- purrr::reduce(smooths, function(x, y) expr(!!x + !!y))
    rhs_only <- new_formula(lhs = NULL, rhs = rhs_f)
    both_sides <- new_formula(lhs = ensym(truth), rhs = rhs_f)
    all_f <- purrr::map(seq_along(levels), ~ rhs_only)
    all_f[[1]] <- both_sides

    model <- mgcv::gam(all_f, data = .data, family = mgcv::multinom(max_int))

    # Nuke environments saved in formulas
    # # TODO This next line causes a failure for unknown reasons. Look into it more
    # model$formula <- purrr::map(model$formula, clean_env)
  } else {
    levels_formula <- purrr::reduce(
      levels,
      function(x, y) expr(!!x + !!y)
    )

    f_model <- expr(!!ensym(truth) ~ !!levels_formula)

    # capture.output() keeps the fitting trace off the console
    prevent_output <- utils::capture.output(
      model <- nnet::multinom(formula = f_model, data = .data, ...)
    )
  }

  # Both model types get the environment on their terms replaced
  model$terms <- clean_env(model$terms)

  model
}

# Returns `x` with its ".Environment" attribute replaced by the base
# environment, so `x` no longer carries the environment it was created in.
clean_env <- function(x) {
  `attr<-`(x, ".Environment", rlang::base_env())
}
51 changes: 49 additions & 2 deletions man/cal_estimate_multinomial.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 40 additions & 8 deletions tests/testthat/_snaps/cal-estimate.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,23 @@
`.pred_coyote` ==> coyote
`.pred_gray_fox` ==> gray_fox

---

Code
print(sp_smth_multi)
Message
-- Probability Calibration
Method: Multinomial
Type: Multiclass
Source class: Data Frame
Data points: 110
Truth variable: `Species`
Estimate variables:
`.pred_bobcat` ==> bobcat
`.pred_coyote` ==> coyote
`.pred_gray_fox` ==> gray_fox

# Multinomial estimates work - tune_results

Code
Expand All @@ -203,14 +220,29 @@
Method: Multinomial
Type: Multiclass
Source class: Tune Results
Data points: 2,930, split in 10 groups
Truth variable: `Bldg_Type`
Data points: 5,000, split in 10 groups
Truth variable: `class`
Estimate variables:
`.pred_one` ==> one
`.pred_two` ==> two
`.pred_three` ==> three

---

Code
print(tl_smth_multi)
Message
-- Probability Calibration
Method: Multinomial
Type: Multiclass
Source class: Tune Results
Data points: 5,000, split in 10 groups
Truth variable: `class`
Estimate variables:
`.pred_OneFam` ==> OneFam
`.pred_TwoFmCon` ==> TwoFmCon
`.pred_Duplex` ==> Duplex
`.pred_Twnhs` ==> Twnhs
`.pred_TwnhsE` ==> TwnhsE
`.pred_one` ==> one
`.pred_two` ==> two
`.pred_three` ==> three

# Linear estimates work - data.frame

Expand All @@ -228,7 +260,7 @@
# Linear estimates work - tune_results

Code
print(tl_logistic)
print(tl_linear)
Message
-- Regression Calibration
Expand Down
Loading

0 comments on commit 4600920

Please sign in to comment.