Skip to content

Commit

Permalink
Adds multiclass support for plots (#89)
Browse files Browse the repository at this point in the history
* Initial grouped process midpoint

* New approach to multi class windowed

* First successful run of facet grid for tuned results with multi class, break plots

* Ensures binary models plot, updates rds

* Switches out cols and rows

* Override of warning for tune results

* Stashing changes

* Updates test, adds qmd to ignore

* Sets event_level default to auto

* Moving each plot method to its own R script, renames original cal-plot, to cal-plot-utils

* Second level works again

* Fixes breaks for binary

* Windowed plots now work

* Adds model grps function

* Logistic works

* Passes cal-plot tests, but not all tests...yet :)

* Restoring previous level to column matching

* Passes all tests

* Re-orgs breaks script, tries to make it easier to read and debug

* Re-org windowed plot script

* Styles logistic, utils and windows, re-orgs logistic

* Properly fixes tune results plot grids

* Removes 'binary' from cal_table functions, and from the plot impl function

* Moves is_val processing up to grouping function

* Increase performance of tune results for breaks and windowed, cut process time by 2/3

* Passes checks

* utility for going between a quosure and a symbol

* update tests

* use dev branch for tune

* unquote in select()

* update quosure argument passing

* more quosure passing

---------

Co-authored-by: Max Kuhn <[email protected]>
  • Loading branch information
edgararuiz and topepo authored Mar 17, 2023
1 parent 64d329d commit 9c582b1
Show file tree
Hide file tree
Showing 20 changed files with 1,588 additions and 1,323 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ VignetteBuilder:
knitr
Remotes:
tidymodels/yardstick@5f1b9ce,
tidymodels/tune#624
tidymodels/tune
ByteCompile: true
Config/Needs/website: tidyverse/tidytemplate
Config/testthat/edition: 3
Expand Down
19 changes: 10 additions & 9 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Generated by roxygen2: do not edit by hand

S3method(.cal_binary_table_breaks,data.frame)
S3method(.cal_binary_table_breaks,tune_results)
S3method(.cal_binary_table_logistic,data.frame)
S3method(.cal_binary_table_logistic,tune_results)
S3method(.cal_binary_table_windowed,data.frame)
S3method(.cal_binary_table_windowed,tune_results)
S3method(.cal_table_breaks,data.frame)
S3method(.cal_table_breaks,tune_results)
S3method(.cal_table_logistic,data.frame)
S3method(.cal_table_logistic,tune_results)
S3method(.cal_table_windowed,data.frame)
S3method(.cal_table_windowed,tune_results)
S3method(any_equivocal,class_pred)
S3method(any_equivocal,default)
S3method(as_class_pred,default)
Expand Down Expand Up @@ -90,9 +90,9 @@ S3method(vec_ptype2,factor.class_pred)
S3method(vec_ptype_abbr,class_pred)
S3method(which_equivocal,class_pred)
S3method(which_equivocal,default)
export(.cal_binary_table_breaks)
export(.cal_binary_table_logistic)
export(.cal_binary_table_windowed)
export(.cal_table_breaks)
export(.cal_table_logistic)
export(.cal_table_windowed)
export(any_equivocal)
export(append_class_pred)
export(as.factor)
Expand Down Expand Up @@ -146,6 +146,7 @@ importFrom(stats,glm)
importFrom(stats,isoreg)
importFrom(stats,median)
importFrom(stats,predict)
importFrom(stats,prop.test)
importFrom(stats,qnorm)
importFrom(utils,head)
importFrom(yardstick,j_index)
Expand Down
139 changes: 47 additions & 92 deletions R/cal-estimate-utils.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#-------------------------- Binary Objects -------------------------------------
#-------------------------- Print methods -------------------------------------

#' @export
print.cal_binary <- function(x, ...) {
Expand All @@ -10,50 +10,19 @@ print.cal_estimate_isotonic <- function(x, ...) {
print_cls_cal(x, upv = TRUE, ...)
}

# ------------------------------- Multi ----------------------------------------

# Print method for `cal_multi` objects. It does no formatting of its own:
# `x` and any extra arguments in `...` are forwarded to print_cls_cal().
#' @export
print.cal_multi <- function(x, ...) {
print_cls_cal(x, ...)
}


# ------------------------------- Regression -----------------------------------

# Print method for `cal_regression` objects. It does no formatting of its own:
# `x` and any extra arguments in `...` are forwarded to print_reg_cal().
#' @export
print.cal_regression <- function(x, ...) {
print_reg_cal(x, ...)
}


# Thin wrapper around as_cal_object() that fixes `type = "regression"`.
#
# `truth` is forwarded unevaluated (embraced), so the caller's expression
# reaches as_cal_object() intact. Note the naming shift: this function's
# `additional_class` is handed over as `additional_classes`. Every other
# argument is passed through under its own name.
as_regression_cal_object <- function(estimate,
                                     truth,
                                     levels,
                                     method,
                                     rows,
                                     additional_class = NULL,
                                     source_class = NULL) {
  as_cal_object(
    type = "regression",
    estimate = estimate,
    truth = {{ truth }},
    levels = levels,
    method = method,
    rows = rows,
    additional_classes = additional_class,
    source_class = source_class
  )
}

# ------------------------------- Utils ----------------------------------------

print_cls_cal <- function(x, upv = FALSE, ...) {

print_type <-
switch(
x$type,
switch(x$type,
"binary" = "Binary",
"multiclass" = "Multiclass",
"one_vs_all" = "Multiclass (1 v All)",
Expand Down Expand Up @@ -123,6 +92,50 @@ print_reg_cal <- function(x, upv = FALSE, ...) {
cli::cli_end()
}

# ------------------------ Estimate name methods -------------------------------

# Maps the class of `x` to a short display label.
#
# S3 generic: dispatches on the class of `x`, and each method returns a
# length-one character string. No default method is defined, so an object of
# any other class fails with R's standard "no applicable method" error.
cal_class_name <- function(x) {
  UseMethod("cal_class_name")
}

# Label used for plain data frames.
cal_class_name.data.frame <- function(x) {
  "Data Frame"
}

# Label used for `tune_results` objects. (This method used to be defined twice
# in a row with identical bodies; the redundant copy has been removed.)
cal_class_name.tune_results <- function(x) {
  "Tune Results"
}

# Label used for `rset` objects.
cal_class_name.rset <- function(x) {
  "Resampled data set"
}

# ------------------------------- Utils ----------------------------------------

# Thin wrapper around as_cal_object() that fixes `type = "regression"`.
#
# `truth` is forwarded unevaluated (embraced), so the caller's expression
# reaches as_cal_object() intact. Note the naming shift: this function's
# `additional_class` is handed over as `additional_classes`. Every other
# argument is passed through under its own name.
as_regression_cal_object <- function(estimate,
                                     truth,
                                     levels,
                                     method,
                                     rows,
                                     additional_class = NULL,
                                     source_class = NULL) {
  as_cal_object(
    type = "regression",
    estimate = estimate,
    truth = {{ truth }},
    levels = levels,
    method = method,
    rows = rows,
    additional_classes = additional_class,
    source_class = source_class
  )
}

as_cal_object <- function(estimate,
truth,
Expand Down Expand Up @@ -169,45 +182,6 @@ stop_multiclass <- function() {
cli::cli_abort("Multiclass not supported...yet")
}

# Single place that decides which estimate (probability) column corresponds to
# which factor level of the `truth` column.
#
# * If every selected estimate column starts with ".pred_", each level is
#   matched to the column named ".pred_<level>".
# * Otherwise levels are matched to the estimate columns by position.
# * If `truth` has no levels, the result is a one-element list named
#   "predictions" (presumably the non-classification case -- confirm against
#   callers).
#
# Returns a named list: the names are the truth levels and the values are the
# matching column names as symbols. Aborts when `estimate` selects no columns.
truth_estimate_map <- function(.data, truth, estimate) {
  truth_pos <- tidyselect_cols(.data, {{ truth }})
  est_names <- names(tidyselect_cols(.data, {{ estimate }}))

  if (length(est_names) == 0) {
    cli::cli_abort("{.arg estimate} must select at least one column.")
  }

  lvls <- levels(.data[[truth_pos]])

  # No factor levels: nothing to map, hand back the estimate as-is.
  if (length(lvls) == 0) {
    return(list(predictions = sym(est_names)))
  }

  est_map <- if (all(startsWith(est_names, ".pred_"))) {
    # tidymodels naming convention: look each level up by name.
    purrr::map(
      lvls,
      function(lvl) sym(est_names[est_names == paste0(".pred_", lvl)])
    )
  } else {
    # Arbitrary column names: fall back to positional matching.
    purrr::map(
      seq_along(lvls),
      function(i) sym(est_names[[i]])
    )
  }

  set_names(est_map, lvls)
}

# Wraps tidyselect call to avoid code duplication in the function above
tidyselect_cols <- function(.data, x) {
tidyselect::eval_select(
Expand All @@ -217,6 +191,7 @@ tidyselect_cols <- function(.data, x) {
)
}


# dplyr::group_map() does not pass the parent function's `...`, it overrides it
# and there seems to be no way to change it. This function will split the
# data set by all the combinations of the grouped variables. It will respect
Expand Down Expand Up @@ -251,23 +226,3 @@ stop_null_parameters <- function(x) {
rlang::abort("The `parameters` argument is only valid for `tune_results`.")
}
}

# Maps the class of `x` to a short display label.
#
# S3 generic: dispatches on the class of `x`, and each method returns a
# length-one character string. No default method is defined, so an object of
# any other class fails with R's standard "no applicable method" error.
cal_class_name <- function(x) {
  UseMethod("cal_class_name")
}

# Label used for plain data frames.
cal_class_name.data.frame <- function(x) {
  "Data Frame"
}

# Label used for `tune_results` objects. (This method used to be defined twice
# in a row with identical bodies; the redundant copy has been removed.)
cal_class_name.tune_results <- function(x) {
  "Tune Results"
}

# Label used for `rset` objects.
cal_class_name.rset <- function(x) {
  "Resampled data set"
}
Loading

0 comments on commit 9c582b1

Please sign in to comment.