Skip to content

Commit

Permalink
refactor xgboost
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Dec 17, 2024
1 parent 7460bc8 commit 1a7f1e6
Showing 1 changed file with 73 additions and 60 deletions.
133 changes: 73 additions & 60 deletions R/model-xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,73 +10,86 @@ orbital.xgb.Booster <- function(

if (mode == "classification") {
objective <- x$params$objective
# match arg objective

if (objective == "multi:softprob") {
trees <- tidypredict::.extract_xgb_trees(x)

trees_split <- split(trees, rep(seq_along(lvl), x$niter))
trees_split <- vapply(trees_split, paste, character(1), collapse = " + ")

trees_split <- gsub("dplyr::case_when", "case_when", trees_split)
trees_split <- gsub("case_when", "dplyr::case_when", trees_split)

res <- stats::setNames(trees_split, lvl)

if (is.null(type)) {
type <- "class"
}

if ("class" %in% type) {
res <- c(
res,
".pred_class" = softmax(lvl)
)
}
if ("prob" %in% type) {
res <- c(
res,
"norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "),
stats::setNames(glue::glue("exp({lvl}) / norm"), NA)
)
}
} else if (objective == "binary:logistic") {
eq <- tidypredict::tidypredict_fit(x)

eq <- deparse1(eq)

eq <- gsub("dplyr::case_when", "case_when", eq)
eq <- gsub("case_when", "dplyr::case_when", eq)

if (is.null(type)) {
type <- "class"
}

res <- NULL
if ("class" %in% type) {
levels <- glue::double_quote(lvl)

res <- c(
res,
.pred_class = glue::glue(
"dplyr::case_when({eq} > 0.5 ~ {levels[1]}, .default = {levels[2]})"
)
)
}
if ("prob" %in% type) {
res <- c(
res,
glue::glue("{eq}"),
glue::glue("1 - ({eq})")
)
}
objective <- rlang::arg_match0(
objective,
c("multi:softprob", "binary:logistic")
)

if (is.null(type)) {
type <- "class"
}

extractor <- switch(
objective,
"multi:softprob" = xgboost_multisoft,
"binary:logistic" = xgboost_logistic
)

res <- extractor(x, type, lvl)
} else if (mode == "regression") {
res <- tidypredict::tidypredict_fit(x)
}
res
}

# Build the prediction expressions for an xgboost model fit with the
# "multi:softprob" objective.
#
# x:    fitted booster; handed to tidypredict::.extract_xgb_trees() and read
#       for `x$niter`.
# type: character vector of requested prediction types; only "class" and
#       "prob" are checked here.
# lvl:  outcome levels; used to name the per-class expressions and
#       interpolated into the probability expressions.
#
# Returns a vector that starts with one summed tree expression per level
# (named by level), followed by the class expression (".pred_class") and/or
# the probability expressions ("norm" plus one entry per level), depending on
# `type`.
xgboost_multisoft <- function(x, type, lvl) {
  all_trees <- tidypredict::.extract_xgb_trees(x)

  # Group the trees by class index, then add up the trees inside each group.
  # NOTE(review): presumably `x$niter` is a single number, so the index
  # pattern cycles through the levels once per iteration -- confirm.
  class_index <- rep(seq_along(lvl), x$niter)
  class_sums <- vapply(
    split(all_trees, class_index),
    function(class_trees) paste(class_trees, collapse = " + "),
    character(1)
  )

  out <- stats::setNames(namespace_case_when(class_sums), lvl)

  wants_class <- "class" %in% type
  wants_prob <- "prob" %in% type

  if (wants_class) {
    out <- c(out, ".pred_class" = softmax(lvl))
  }

  if (wants_prob) {
    # "norm" is the shared denominator; each level's probability divides its
    # exponentiated score by it. The NA names are left for the caller to fill.
    norm_expr <- glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + ")
    prob_exprs <- stats::setNames(glue::glue("exp({lvl}) / norm"), NA)
    out <- c(out, "norm" = norm_expr, prob_exprs)
  }

  out
}

# Build the prediction expressions for an xgboost model fit with the
# "binary:logistic" objective.
#
# x:    fitted booster; handed to tidypredict::tidypredict_fit().
# type: character vector of requested prediction types; only "class" and
#       "prob" are checked here.
# lvl:  outcome levels; the first two are used for the class expression.
#
# Returns NULL when neither "class" nor "prob" is requested; otherwise a
# vector holding the ".pred_class" expression and/or two unnamed probability
# expressions (`eq` and `1 - (eq)`).
xgboost_logistic <- function(x, type, lvl) {
  # `eq` and `levels` are interpolated by name in the glue templates below,
  # so these local names must stay as they are.
  eq <- namespace_case_when(deparse1(tidypredict::tidypredict_fit(x)))

  wants_class <- "class" %in% type
  wants_prob <- "prob" %in% type

  out <- NULL

  if (wants_class) {
    levels <- glue::double_quote(lvl)
    class_expr <- glue::glue(
      "dplyr::case_when({eq} > 0.5 ~ {levels[1]}, .default = {levels[2]})"
    )
    out <- c(out, .pred_class = class_expr)
  }

  if (wants_prob) {
    fit_expr <- glue::glue("{eq}")
    complement_expr <- glue::glue("1 - ({eq})")
    out <- c(out, fit_expr, complement_expr)
  }

  out
}

# Ensure every `case_when` call in the expression text `x` carries exactly
# one `dplyr::` prefix.
#
# x: character vector of expression text.
#
# The optional group consumes an existing `dplyr::` prefix, so calls that are
# already namespaced are rewritten to themselves instead of being prefixed a
# second time.
namespace_case_when <- function(x) {
  gsub("(dplyr::)?case_when", "dplyr::case_when", x)
}

softmax <- function(lvl) {
res <- character(0)

Expand Down

0 comments on commit 1a7f1e6

Please sign in to comment.