Skip to content

Commit

Permalink
Merge pull request #71 from tidymodels/xgboost-classification
Browse files Browse the repository at this point in the history
Xgboost classification
  • Loading branch information
EmilHvitfeldt authored Dec 17, 2024
2 parents f7eff40 + df167a6 commit 2d923ac
Show file tree
Hide file tree
Showing 7 changed files with 318 additions and 3 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ Suggests:
workflows
VignetteBuilder:
knitr
Remotes:
tidymodels/tidypredict
Config/Needs/website: tidyverse/tidytemplate, rmarkdown, gt
Config/testthat/edition: 3
Encoding: UTF-8
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ S3method(orbital,step_unknown)
S3method(orbital,step_upsample)
S3method(orbital,step_zv)
S3method(orbital,workflow)
S3method(orbital,xgb.Booster)
S3method(predict,orbital_class)
S3method(print,orbital_class)
export(augment)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

* `orbital()` has gained `type` argument to change prediction type. (#66)

* `orbital()` now works with `boost_tree(engine = "xgboost")` models for class prediction and probability predictions. (#71)


# orbital 0.2.0

* Support for `step_dummy()`, `step_impute_mean()`, `step_impute_median()`, `step_impute_mode()`, `step_unknown()`, `step_novel()`, `step_other()`, `step_BoxCox()`, `step_inverse()`, `step_mutate()`, `step_sqrt()`, `step_indicate_na()`, `step_range()`, `step_intercept()`, `step_ratio()`, `step_lag()`, `step_log()`, `step_rename()` has been added. (#17)
Expand Down
107 changes: 107 additions & 0 deletions R/model-xgboost.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#' @export
orbital.xgb.Booster <- function(
x,
...,
mode = c("classification", "regression"),
type = NULL,
lvl = NULL
) {
mode <- rlang::arg_match(mode)

if (mode == "classification") {
objective <- x$params$objective
objective <- rlang::arg_match0(
objective,
c("multi:softprob", "binary:logistic")
)

if (is.null(type)) {
type <- "class"
}

extractor <- switch(
objective,
"multi:softprob" = xgboost_multisoft,
"binary:logistic" = xgboost_logistic
)

res <- extractor(x, type, lvl)
} else if (mode == "regression") {
res <- tidypredict::tidypredict_fit(x)
}
res
}

xgboost_multisoft <- function(x, type, lvl) {
trees <- tidypredict::.extract_xgb_trees(x)

trees_split <- split(trees, rep(seq_along(lvl), x$niter))
trees_split <- vapply(trees_split, paste, character(1), collapse = " + ")
trees_split <- namespace_case_when(trees_split)

res <- stats::setNames(trees_split, lvl)

if ("class" %in% type) {
res <- c(
res,
".pred_class" = softmax(lvl)
)
}
if ("prob" %in% type) {
res <- c(
res,
"norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "),
stats::setNames(glue::glue("exp({lvl}) / norm"), NA)
)
}
res
}

xgboost_logistic <- function(x, type, lvl) {
eq <- tidypredict::tidypredict_fit(x)

eq <- deparse1(eq)
eq <- namespace_case_when(eq)

res <- NULL
if ("class" %in% type) {
levels <- glue::double_quote(lvl)

res <- c(
res,
.pred_class = glue::glue(
"dplyr::case_when({eq} > 0.5 ~ {levels[1]}, .default = {levels[2]})"
)
)
}
if ("prob" %in% type) {
res <- c(
res,
glue::glue("{eq}"),
glue::glue("1 - ({eq})")
)
}
res
}

namespace_case_when <- function(x) {
x <- gsub("dplyr::case_when", "case_when", x)
x <- gsub("case_when", "dplyr::case_when", x)
x
}

softmax <- function(lvl) {
res <- character(0)

for (i in seq(1, length(lvl) - 1)) {
line <- glue::glue("{lvl[i]} > {lvl[-i]}")
line <- glue::glue_collapse(line, sep = " & ")
line <- glue::glue("{line} ~ {glue::double_quote(lvl[i])}")
res[i] <- line
}

res <- glue::glue_collapse(res, ", ")
default <- glue::double_quote(lvl[length(lvl)])

glue::glue("dplyr::case_when({res}, .default = {default})")
}
23 changes: 21 additions & 2 deletions R/parsnip.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) {
check_mode(mode)
check_type(type, mode)

res <- try(orbital(x$fit, mode = mode, type = type), silent = TRUE)
if (mode == "classification") {
res <- try(
orbital(x$fit, mode = mode, type = type, lvl = x$lvl),
silent = TRUE
)
} else {
res <- try(orbital(x$fit, mode = mode, type = type), silent = TRUE)
}

if (inherits(res, "try-error")) {
res <- tryCatch(
Expand Down Expand Up @@ -48,7 +55,19 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) {
}

attr(res, "pred_names") <- names
res <- stats::setNames(res, names)
if (
inherits(x, "_xgb.Booster") &&
isTRUE(x$fit$params$objective == "multi:softprob")
) {
if (anyNA(names(res))) {
na_fields <- which(is.na(names(res)))
tmp_names <- names(res)
tmp_names[na_fields] <- paste0(prefix, "_", x$lvl)
names(res) <- tmp_names
}
} else {
res <- stats::setNames(res, names)
}

new_orbital_class(res)
}
Expand Down
183 changes: 183 additions & 0 deletions tests/testthat/test-model-xgboost.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
test_that("boost_tree(), objective = binary:logistic, works with type = class", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("xgboost")

mtcars$vs <- factor(mtcars$vs)

bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost")

bt_fit <- parsnip::fit(bt_spec, vs ~ disp + mpg + hp, mtcars)

orb_obj <- orbital(bt_fit, type = "class")

preds <- predict(orb_obj, mtcars)
exps <- predict(bt_fit, mtcars)

expect_named(preds, ".pred_class")
expect_type(preds$.pred_class, "character")

expect_identical(
preds$.pred_class,
as.character(exps$.pred_class)
)
})

test_that("boost_tree(), objective = binary:logistic, works with type = class", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("xgboost")

bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost")

bt_fit <- parsnip::fit(bt_spec, Species ~ ., iris)

orb_obj <- orbital(bt_fit, type = "class")

preds <- predict(orb_obj, iris)
exps <- predict(bt_fit, iris)

expect_named(preds, ".pred_class")
expect_type(preds$.pred_class, "character")

expect_identical(
preds$.pred_class,
as.character(exps$.pred_class)
)
})

test_that("boost_tree(), objective = binary:logistic, works with type = prob", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("xgboost")

mtcars$vs <- factor(mtcars$vs)

bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost")

bt_fit <- parsnip::fit(bt_spec, vs ~ disp + mpg + hp, mtcars)

orb_obj <- orbital(bt_fit, type = "prob")

preds <- predict(orb_obj, mtcars)
exps <- predict(bt_fit, mtcars, type = "prob")

expect_named(preds, c(".pred_0", ".pred_1"))
expect_type(preds$.pred_0, "double")
expect_type(preds$.pred_1, "double")

exps <- as.data.frame(exps)

rownames(preds) <- NULL
rownames(exps) <- NULL

expect_equal(
preds,
exps,
tolerance = 0.0000001
)
})

test_that("boost_tree(), objective = binary:logistic, works with type = prob", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("xgboost")

bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost")

bt_fit <- parsnip::fit(bt_spec, Species ~ ., iris)

orb_obj <- orbital(bt_fit, type = "prob")

preds <- predict(orb_obj, iris)
exps <- predict(bt_fit, iris, type = "prob")

expect_named(preds, paste0(".pred_", levels(iris$Species)))
expect_type(preds$.pred_setosa, "double")
expect_type(preds$.pred_versicolor, "double")
expect_type(preds$.pred_virginica, "double")

exps <- as.data.frame(exps)

rownames(preds) <- NULL
rownames(exps) <- NULL

expect_equal(
preds,
exps,
tolerance = 0.0000001
)
})

test_that("boost_tree(), objective = binary:logistic, works with type = c(class, prob)", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("xgboost")

mtcars$vs <- factor(mtcars$vs)

bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost")

bt_fit <- parsnip::fit(bt_spec, vs ~ disp + mpg + hp, mtcars)

orb_obj <- orbital(bt_fit, type = c("class", "prob"))

preds <- predict(orb_obj, mtcars)
exps <- dplyr::bind_cols(
predict(bt_fit, mtcars, type = c("class")),
predict(bt_fit, mtcars, type = c("prob"))
)

expect_named(preds, c(".pred_class", ".pred_0", ".pred_1"))
expect_type(preds$.pred_class, "character")
expect_type(preds$.pred_0, "double")
expect_type(preds$.pred_1, "double")

exps <- as.data.frame(exps)
exps$.pred_class <- as.character(exps$.pred_class)

rownames(preds) <- NULL
rownames(exps) <- NULL

expect_equal(
preds,
exps,
tolerance = 0.0000001
)
})

test_that("boost_tree(), objective = binary:logistic, works with type = c(class, prob)", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("xgboost")

bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost")

bt_fit <- parsnip::fit(bt_spec, Species ~ ., iris)

orb_obj <- orbital(bt_fit, type = c("class", "prob"))

preds <- predict(orb_obj, iris)
exps <- dplyr::bind_cols(
predict(bt_fit, iris, type = c("class")),
predict(bt_fit, iris, type = c("prob"))
)

expect_named(preds, c(".pred_class", paste0(".pred_", levels(iris$Species))))
expect_type(preds$.pred_class, "character")
expect_type(preds$.pred_setosa, "double")
expect_type(preds$.pred_versicolor, "double")
expect_type(preds$.pred_virginica, "double")

exps <- as.data.frame(exps)
exps$.pred_class <- as.character(exps$.pred_class)

rownames(preds) <- NULL
rownames(exps) <- NULL

expect_equal(
preds,
exps,
tolerance = 0.0000001
)
})
2 changes: 1 addition & 1 deletion vignettes/supported-models.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ library(dplyr)
tibble::tribble(
~parsnip, ~engine, ~numeric, ~class, ~prob,
"`boost_tree()`", "`\"xgboost\"`", "✅", "", "",
"`boost_tree()`", "`\"xgboost\"`", "✅", "", "",
"`cubist_rules()`", "`\"Cubist\"`", "✅", "❌", "❌",
"`decision_tree()`", "`\"partykit\"`", "✅", "⚪", "⚪",
"`linear_reg()`", "`\"lm\"`", "✅", "❌", "❌",
Expand Down

0 comments on commit 2d923ac

Please sign in to comment.