From 398adb8958854278a24be4a5a43d665326204e4f Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 16 Dec 2024 18:28:12 -0800 Subject: [PATCH] refactor out set_pred_names() --- R/model-glm.R | 6 +++--- R/model-xgboost.R | 13 +++++++----- R/parsnip.R | 54 ++++++++++++++++++++++++++--------------------- 3 files changed, 41 insertions(+), 32 deletions(-) diff --git a/R/model-glm.R b/R/model-glm.R index 0a753b4..865cb1a 100644 --- a/R/model-glm.R +++ b/R/model-glm.R @@ -19,7 +19,7 @@ orbital.glm <- function( if ("class" %in% type) { res <- c( res, - glue::glue( + orbital_tmp_class_name = glue::glue( "dplyr::case_when({eq} < 0.5 ~ {levels[1]}, .default = {levels[2]})" ) ) @@ -27,8 +27,8 @@ orbital.glm <- function( if ("prob" %in% type) { res <- c( res, - glue::glue("1 - ({eq})"), - glue::glue("{eq}") + orbital_tmp_prob_name1 = glue::glue("1 - ({eq})"), + orbital_tmp_prob_name2 = glue::glue("{eq}") ) } } else if (mode == "regression") { diff --git a/R/model-xgboost.R b/R/model-xgboost.R index ebb5b23..5e8492c 100644 --- a/R/model-xgboost.R +++ b/R/model-xgboost.R @@ -42,14 +42,17 @@ xgboost_multisoft <- function(x, type, lvl) { if ("class" %in% type) { res <- c( res, - ".pred_class" = softmax(lvl) + orbital_tmp_class_name = softmax(lvl) ) } if ("prob" %in% type) { + eqs <- glue::glue("exp({lvl}) / norm") + names(eqs) <- paste0("orbital_tmp_prob_name", seq_along(lvl)) + res <- c( res, "norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "), - stats::setNames(glue::glue("exp({lvl}) / norm"), NA) + eqs ) } res @@ -80,7 +83,7 @@ xgboost_logistic <- function(x, type, lvl) { res <- c( res, - .pred_class = glue::glue( + orbital_tmp_class_name = glue::glue( "dplyr::case_when({eq} > 0.5 ~ {levels[1]}, .default = {levels[2]})" ) ) @@ -88,8 +91,8 @@ xgboost_logistic <- function(x, type, lvl) { if ("prob" %in% type) { res <- c( res, - glue::glue("{eq}"), - glue::glue("1 - ({eq})") + orbital_tmp_prob_name1 = glue::glue("{eq}"), + orbital_tmp_prob_name2 = glue::glue("1 - ({eq})") ) } res diff --git a/R/parsnip.R b/R/parsnip.R index f888126..28bf46f 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -29,40 +29,46 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) { ) } + if (is.language(res)) { + res <- deparse1(res) + } + + res <- set_pred_names(res, x, mode, type, prefix) + + new_orbital_class(res) +} + +set_pred_names <- function(res, x, mode, type, prefix) { + if (mode == "regression") { + res <- stats::setNames(res, prefix) + attr(res, "pred_names") <- prefix + } + if (mode == "classification") { - names <- NULL + class_names <- NULL + prob_names <- NULL if ("class" %in% type) { - names <- c(names, paste0(prefix, "_class")) + class_names <- paste0(prefix, "_class") } if ("prob" %in% type) { - names <- c(names, paste0(prefix, "_", x$lvl)) + prob_names <- paste0(prefix, "_", x$lvl) } - } - if (mode == "regression") { - names <- prefix - } - if (is.language(res)) { - res <- deparse1(res) - } + attr(res, "pred_names") <- c(class_names, prob_names) - attr(res, "pred_names") <- names - if ( - inherits(x, "_xgb.Booster") && - isTRUE(x$fit$params$objective == "multi:softprob") - ) { - if (anyNA(names(res))) { - na_fields <- which(is.na(names(res))) - tmp_names <- names(res) - tmp_names[na_fields] <- paste0(prefix, "_", x$lvl) - names(res) <- tmp_names - } - } else { - res <- stats::setNames(res, names) + eq_names <- names(res) + + class_ind <- eq_names %in% "orbital_tmp_class_name" + prob_ind <- grepl("^orbital_tmp_prob_name", eq_names) + + eq_names[class_ind] <- class_names + eq_names[prob_ind] <- prob_names + + names(res) <- eq_names } - new_orbital_class(res) + res } #' @export