Skip to content

Commit

Permalink
Merge pull request #73 from tidymodels/set-names-parsnip
Browse files Browse the repository at this point in the history
refactor out set_pred_names()
  • Loading branch information
EmilHvitfeldt authored Dec 17, 2024
2 parents 76db153 + 398adb8 commit 4fea0f7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 32 deletions.
6 changes: 3 additions & 3 deletions R/model-glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ orbital.glm <- function(
if ("class" %in% type) {
res <- c(
res,
glue::glue(
orbital_tmp_class_name = glue::glue(
"dplyr::case_when({eq} < 0.5 ~ {levels[1]}, .default = {levels[2]})"
)
)
}
if ("prob" %in% type) {
res <- c(
res,
glue::glue("1 - ({eq})"),
glue::glue("{eq}")
orbital_tmp_prob_name1 = glue::glue("1 - ({eq})"),
orbital_tmp_prob_name2 = glue::glue("{eq}")
)
}
} else if (mode == "regression") {
Expand Down
13 changes: 8 additions & 5 deletions R/model-xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,17 @@ xgboost_multisoft <- function(x, type, lvl) {
if ("class" %in% type) {
res <- c(
res,
".pred_class" = softmax(lvl)
orbital_tmp_class_name = softmax(lvl)
)
}
if ("prob" %in% type) {
eqs <- glue::glue("exp({lvl}) / norm")
names(eqs) <- paste0("orbital_tmp_prob_name", seq_along(lvl))

res <- c(
res,
"norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "),
stats::setNames(glue::glue("exp({lvl}) / norm"), NA)
eqs
)
}
res
Expand Down Expand Up @@ -80,16 +83,16 @@ xgboost_logistic <- function(x, type, lvl) {

res <- c(
res,
.pred_class = glue::glue(
orbital_tmp_class_name = glue::glue(
"dplyr::case_when({eq} > 0.5 ~ {levels[1]}, .default = {levels[2]})"
)
)
}
if ("prob" %in% type) {
res <- c(
res,
glue::glue("{eq}"),
glue::glue("1 - ({eq})")
orbital_tmp_prob_name1 = glue::glue("{eq}"),
orbital_tmp_prob_name2 = glue::glue("1 - ({eq})")
)
}
res
Expand Down
54 changes: 30 additions & 24 deletions R/parsnip.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,46 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) {
)
}

if (is.language(res)) {
res <- deparse1(res)
}

res <- set_pred_names(res, x, mode, type, prefix)

new_orbital_class(res)
}

set_pred_names <- function(res, x, mode, type, prefix) {
if (mode == "regression") {
res <- stats::setNames(res, prefix)
attr(res, "pred_names") <- prefix
}

if (mode == "classification") {
names <- NULL
class_names <- NULL
prob_names <- NULL

if ("class" %in% type) {
names <- c(names, paste0(prefix, "_class"))
class_names <- paste0(prefix, "_class")
}
if ("prob" %in% type) {
names <- c(names, paste0(prefix, "_", x$lvl))
prob_names <- paste0(prefix, "_", x$lvl)
}
}
if (mode == "regression") {
names <- prefix
}

if (is.language(res)) {
res <- deparse1(res)
}
attr(res, "pred_names") <- c(class_names, prob_names)

attr(res, "pred_names") <- names
if (
inherits(x, "_xgb.Booster") &&
isTRUE(x$fit$params$objective == "multi:softprob")
) {
if (anyNA(names(res))) {
na_fields <- which(is.na(names(res)))
tmp_names <- names(res)
tmp_names[na_fields] <- paste0(prefix, "_", x$lvl)
names(res) <- tmp_names
}
} else {
res <- stats::setNames(res, names)
eq_names <- names(res)

class_ind <- eq_names %in% "orbital_tmp_class_name"
prob_ind <- grepl("^orbital_tmp_prob_name", eq_names)

eq_names[class_ind] <- class_names
eq_names[prob_ind] <- prob_names

names(res) <- eq_names
}

new_orbital_class(res)
res
}

#' @export
Expand Down

0 comments on commit 4fea0f7

Please sign in to comment.