From 0bbaf4c61b680495161b61a160e08079dedce2b3 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 19 Dec 2024 17:39:56 -0800 Subject: [PATCH] use namespace_case_when() more generally to close #53 --- R/model-partykit.R | 6 ++---- R/model-xgboost.R | 8 -------- R/parsnip.R | 1 + R/utils.R | 7 +++++++ R/workflows.R | 2 ++ 5 files changed, 12 insertions(+), 12 deletions(-) create mode 100644 R/utils.R diff --git a/R/model-partykit.R b/R/model-partykit.R index e595056..cf36627 100644 --- a/R/model-partykit.R +++ b/R/model-partykit.R @@ -12,14 +12,12 @@ orbital.constparty <- function( if (mode == "classification") { res <- character() if ("class" %in% type) { - eq <- tidypredict::tidypredict_fit(x) - eq <- deparse1(eq) - eq <- namespace_case_when(eq) + eq <- tidypredict::tidypredict_fit(x) + eq <- deparse1(eq) res <- c(res, orbital_tmp_class_name = eq) } if ("prob" %in% type) { eqs <- tidypredict::.extract_partykit_classprob(x) - eqs <- namespace_case_when(eqs) names(eqs) <- paste0("orbital_tmp_prob_name", seq_along(lvl)) res <- c(res, eqs) } diff --git a/R/model-xgboost.R b/R/model-xgboost.R index 5e8492c..393f272 100644 --- a/R/model-xgboost.R +++ b/R/model-xgboost.R @@ -35,7 +35,6 @@ xgboost_multisoft <- function(x, type, lvl) { trees_split <- split(trees, rep(seq_along(lvl), x$niter)) trees_split <- lapply(trees_split, collapse_stumps) trees_split <- vapply(trees_split, paste, character(1), collapse = " + ") - trees_split <- namespace_case_when(trees_split) res <- stats::setNames(trees_split, lvl) @@ -75,7 +74,6 @@ xgboost_logistic <- function(x, type, lvl) { eq <- tidypredict::tidypredict_fit(x) eq <- deparse1(eq) - eq <- namespace_case_when(eq) res <- NULL if ("class" %in% type) { @@ -98,12 +96,6 @@ xgboost_logistic <- function(x, type, lvl) { res } -namespace_case_when <- function(x) { - x <- gsub("dplyr::case_when", "case_when", x) - x <- gsub("case_when", "dplyr::case_when", x) - x -} - softmax <- function(lvl) { res <- character(0) diff --git a/R/parsnip.R b/R/parsnip.R index 28bf46f..9082e29 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -33,6 +33,7 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) { res <- deparse1(res) } + res <- namespace_case_when(res) res <- set_pred_names(res, x, mode, type, prefix) new_orbital_class(res) diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 0000000..51494c6 --- /dev/null +++ b/R/utils.R @@ -0,0 +1,7 @@ +namespace_case_when <- function(x) { + names <- names(x) + x <- gsub("dplyr::case_when", "case_when", x) + x <- gsub("case_when", "dplyr::case_when", x) + names(x) <- names + x +} diff --git a/R/workflows.R b/R/workflows.R index eec0fe9..4adec9e 100644 --- a/R/workflows.R +++ b/R/workflows.R @@ -21,5 +21,7 @@ orbital.workflow <- function(x, ..., prefix = ".pred", type = NULL) { attr(out, "pred_names") <- pred_names } + out <- namespace_case_when(out) + new_orbital_class(out) }