From b367a251a953bf9f9134ec59f30290acc20051f7 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 16 Dec 2024 12:44:00 -0800 Subject: [PATCH 1/4] add orbital.xgb.Booster --- DESCRIPTION | 2 ++ NAMESPACE | 1 + R/model-xgboost.R | 59 +++++++++++++++++++++++++++++++++++++++++++++++ R/parsnip.R | 18 +++++++++++++-- 4 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 R/model-xgboost.R diff --git a/DESCRIPTION b/DESCRIPTION index c72175d..564ffc7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -45,6 +45,8 @@ Suggests: workflows VignetteBuilder: knitr +Remotes: + tidymodels/tidypredict Config/Needs/website: tidyverse/tidytemplate, rmarkdown, gt Config/testthat/edition: 3 Encoding: UTF-8 diff --git a/NAMESPACE b/NAMESPACE index e58b2f8..7b82ceb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -54,6 +54,7 @@ S3method(orbital,step_unknown) S3method(orbital,step_upsample) S3method(orbital,step_zv) S3method(orbital,workflow) +S3method(orbital,xgb.Booster) S3method(predict,orbital_class) S3method(print,orbital_class) export(augment) diff --git a/R/model-xgboost.R b/R/model-xgboost.R new file mode 100644 index 0000000..b875fd2 --- /dev/null +++ b/R/model-xgboost.R @@ -0,0 +1,59 @@ +#' @export +orbital.xgb.Booster <- function( + x, + ..., + mode = c("classification", "regression"), + type = NULL, + lvl = NULL +) { + mode <- rlang::arg_match(mode) + + if (mode == "classification") { + trees <- tidypredict::.extract_xgb_trees(x) + + trees_split <- split(trees, rep(seq_along(lvl), x$niter)) + trees_split <- vapply(trees_split, paste, character(1), collapse = " + ") + + trees_split <- gsub("dplyr::case_when", "case_when", trees_split) + trees_split <- gsub("case_when", "dplyr::case_when", trees_split) + + res <- stats::setNames(trees_split, lvl) + + if (is.null(type)) { + type <- "class" + } + + if ("class" %in% type) { + res <- c( + res, + ".pred_class" = softmax(lvl) + ) + } + if ("prob" %in% type) { + res <- c( + res, + "norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "), + stats::setNames(glue::glue("exp({lvl}) / norm"), NA) + ) + } + } else if (mode == "regression") { + res <- tidypredict::tidypredict_fit(x) + } + res +} + +softmax <- function(lvl) { + res <- character(0) + + for (i in seq(1, length(lvl) - 1)) { + line <- glue::glue("{lvl[i]} > {lvl[-i]}") + line <- glue::glue_collapse(line, sep = " & ") + line <- glue::glue("{line} ~ {glue::double_quote(lvl[i])}") + res[i] <- line + } + + res <- glue::glue_collapse(res, ", ") + default <- glue::double_quote(lvl[length(lvl)]) + + glue::glue("dplyr::case_when({res}, .default = {default})") +} diff --git a/R/parsnip.R b/R/parsnip.R index f1acda6..3f9389e 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -4,7 +4,14 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) { check_mode(mode) check_type(type, mode) - res <- try(orbital(x$fit, mode = mode, type = type), silent = TRUE) + if (mode == "classification") { + res <- try( + orbital(x$fit, mode = mode, type = type, lvl = x$lvl), + silent = TRUE + ) + } else { + res <- try(orbital(x$fit, mode = mode, type = type), silent = TRUE) + } if (inherits(res, "try-error")) { res <- tryCatch( @@ -48,7 +55,14 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) { } attr(res, "pred_names") <- names - res <- stats::setNames(res, names) + if (inherits(x, "_xgb.Booster")) { + na_fields <- which(is.na(names(res))) + tmp_names <- names(res) + tmp_names[na_fields] <- paste0(prefix, "_", x$lvl) + names(res) <- tmp_names + } else { + res <- stats::setNames(res, names) + } new_orbital_class(res) } From 7460bc8dc54b4667e28d6168293627fd82588de1 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 16 Dec 2024 16:07:57 -0800 Subject: [PATCH 2/4] add xgboost classification support --- NEWS.md | 3 + R/model-xgboost.R | 77 ++++++++---- R/parsnip.R | 15 ++- tests/testthat/test-model-xgboost.R | 183 ++++++++++++++++++++++++++++ vignettes/supported-models.Rmd | 2 +- 5 files changed, 253 insertions(+), 27 deletions(-) create mode 100644 tests/testthat/test-model-xgboost.R diff --git a/NEWS.md b/NEWS.md index fd25e6b..715e31d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,9 @@ * `orbital()` has gained `type` argument to change prediction type. (#66) +* `orbital()` now works with `boost_tree(engine = "xgboost")` models for class prediction and probability predictions. (#71) + + # orbital 0.2.0 * Support for `step_dummy()`, `step_impute_mean()`, `step_impute_median()`, `step_impute_mode()`, `step_unknown()`, `step_novel()`, `step_other()`, `step_BoxCox()`, `step_inverse()`, `step_mutate()`, `step_sqrt()`, `step_indicate_na()`, `step_range()`, `step_intercept()`, `step_ratio()`, `step_lag()`, `step_log()`, `step_rename()` has been added. (#17) diff --git a/R/model-xgboost.R b/R/model-xgboost.R index b875fd2..2046513 100644 --- a/R/model-xgboost.R +++ b/R/model-xgboost.R @@ -9,32 +9,67 @@ orbital.xgb.Booster <- function( mode <- rlang::arg_match(mode) if (mode == "classification") { - trees <- tidypredict::.extract_xgb_trees(x) + objective <- x$params$objective + # match arg objective - trees_split <- split(trees, rep(seq_along(lvl), x$niter)) - trees_split <- vapply(trees_split, paste, character(1), collapse = " + ") + if (objective == "multi:softprob") { + trees <- tidypredict::.extract_xgb_trees(x) - trees_split <- gsub("dplyr::case_when", "case_when", trees_split) - trees_split <- gsub("case_when", "dplyr::case_when", trees_split) + trees_split <- split(trees, rep(seq_along(lvl), x$niter)) + trees_split <- vapply(trees_split, paste, character(1), collapse = " + ") - res <- stats::setNames(trees_split, lvl) + trees_split <- gsub("dplyr::case_when", "case_when", trees_split) + trees_split <- gsub("case_when", "dplyr::case_when", trees_split) - if (is.null(type)) { - type <- "class" - } + res <- stats::setNames(trees_split, lvl) - if ("class" %in% type) { - res <- c( - res, - ".pred_class" = softmax(lvl) - ) - } - if ("prob" %in% type) { - res <- c( - res, - "norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "), - stats::setNames(glue::glue("exp({lvl}) / norm"), NA) - ) + if (is.null(type)) { + type <- "class" + } + + if ("class" %in% type) { + res <- c( + res, + ".pred_class" = softmax(lvl) + ) + } + if ("prob" %in% type) { + res <- c( + res, + "norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "), + stats::setNames(glue::glue("exp({lvl}) / norm"), NA) + ) + } + } else if (objective == "binary:logistic") { + eq <- tidypredict::tidypredict_fit(x) + + eq <- deparse1(eq) + + eq <- gsub("dplyr::case_when", "case_when", eq) + eq <- gsub("case_when", "dplyr::case_when", eq) + + if (is.null(type)) { + type <- "class" + } + + res <- NULL + if ("class" %in% type) { + levels <- glue::double_quote(lvl) + + res <- c( + res, + .pred_class = glue::glue( + "dplyr::case_when({eq} > 0.5 ~ {levels[1]}, .default = {levels[2]})" + ) + ) + } + if ("prob" %in% type) { + res <- c( + res, + glue::glue("{eq}"), + glue::glue("1 - ({eq})") + ) + } } } else if (mode == "regression") { res <- tidypredict::tidypredict_fit(x) diff --git a/R/parsnip.R b/R/parsnip.R index 3f9389e..6545755 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -55,11 +55,16 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) { } attr(res, "pred_names") <- names - if (inherits(x, "_xgb.Booster")) { - na_fields <- which(is.na(names(res))) - tmp_names <- names(res) - tmp_names[na_fields] <- paste0(prefix, "_", x$lvl) - names(res) <- tmp_names + if ( + inherits(x, "_xgb.Booster") && + isTRUE(x$fit$params$objective == "multi:softprob") + ) { + if (anyNA(names(res))) { + na_fields <- which(is.na(names(res))) + tmp_names <- names(res) + tmp_names[na_fields] <- paste0(prefix, "_", x$lvl) + names(res) <- tmp_names + } } else { res <- stats::setNames(res, names) } diff --git a/tests/testthat/test-model-xgboost.R b/tests/testthat/test-model-xgboost.R new file mode 100644 index 0000000..a0e0dff --- /dev/null +++ b/tests/testthat/test-model-xgboost.R @@ -0,0 +1,183 @@ +test_that("boost_tree(), objective = binary:logistic, works with type = class", { + skip_if_not_installed("parsnip") + skip_if_not_installed("tidypredict") + skip_if_not_installed("xgboost") + + mtcars$vs <- factor(mtcars$vs) + + lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + + lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars) + + orb_obj <- orbital(lr_fit, type = "class") + + preds <- predict(orb_obj, mtcars) + exps <- predict(lr_fit, mtcars) + + expect_named(preds, ".pred_class") + expect_type(preds$.pred_class, "character") + + expect_identical( + preds$.pred_class, + as.character(exps$.pred_class) + ) +}) + +test_that("boost_tree(), objective = binary:logistic, works with type = class", { + skip_if_not_installed("parsnip") + skip_if_not_installed("tidypredict") + skip_if_not_installed("xgboost") + + lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + + lr_fit <- parsnip::fit(lr_spec, Species ~ ., iris) + + orb_obj <- orbital(lr_fit, type = "class") + + preds <- predict(orb_obj, iris) + exps <- predict(lr_fit, iris) + + expect_named(preds, ".pred_class") + expect_type(preds$.pred_class, "character") + + expect_identical( + preds$.pred_class, + as.character(exps$.pred_class) + ) +}) + +test_that("boost_tree(), objective = binary:logistic, works with type = prob", { + skip_if_not_installed("parsnip") + skip_if_not_installed("tidypredict") + skip_if_not_installed("xgboost") + + mtcars$vs <- factor(mtcars$vs) + + lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + + lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars) + + orb_obj <- orbital(lr_fit, type = "prob") + + preds <- predict(orb_obj, mtcars) + exps <- predict(lr_fit, mtcars, type = "prob") + + expect_named(preds, c(".pred_0", ".pred_1")) + expect_type(preds$.pred_0, "double") + expect_type(preds$.pred_1, "double") + + exps <- as.data.frame(exps) + + rownames(preds) <- NULL + rownames(exps) <- NULL + + expect_equal( + preds, + exps, + tolerance = 0.0000001 + ) +}) + +test_that("boost_tree(), objective = binary:logistic, works with type = prob", { + skip_if_not_installed("parsnip") + skip_if_not_installed("tidypredict") + skip_if_not_installed("xgboost") + + lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + + lr_fit <- parsnip::fit(lr_spec, Species ~ ., iris) + + orb_obj <- orbital(lr_fit, type = "prob") + + preds <- predict(orb_obj, iris) + exps <- predict(lr_fit, iris, type = "prob") + + expect_named(preds, paste0(".pred_", levels(iris$Species))) + expect_type(preds$.pred_setosa, "double") + expect_type(preds$.pred_versicolor, "double") + expect_type(preds$.pred_virginica, "double") + + exps <- as.data.frame(exps) + + rownames(preds) <- NULL + rownames(exps) <- NULL + + expect_equal( + preds, + exps, + tolerance = 0.0000001 + ) +}) + +test_that("boost_tree(), objective = binary:logistic, works with type = c(class, prob)", { + skip_if_not_installed("parsnip") + skip_if_not_installed("tidypredict") + skip_if_not_installed("xgboost") + + mtcars$vs <- factor(mtcars$vs) + + lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + + lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars) + + orb_obj <- orbital(lr_fit, type = c("class", "prob")) + + preds <- predict(orb_obj, mtcars) + exps <- dplyr::bind_cols( + predict(lr_fit, mtcars, type = c("class")), + predict(lr_fit, mtcars, type = c("prob")) + ) + + expect_named(preds, c(".pred_class", ".pred_0", ".pred_1")) + expect_type(preds$.pred_class, "character") + expect_type(preds$.pred_0, "double") + expect_type(preds$.pred_1, "double") + + exps <- as.data.frame(exps) + exps$.pred_class <- as.character(exps$.pred_class) + + rownames(preds) <- NULL + rownames(exps) <- NULL + + expect_equal( + preds, + exps, + tolerance = 0.0000001 + ) +}) + +test_that("boost_tree(), objective = binary:logistic, works with type = c(class, prob)", { + skip_if_not_installed("parsnip") + skip_if_not_installed("tidypredict") + skip_if_not_installed("xgboost") + + lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + + lr_fit <- parsnip::fit(lr_spec, Species ~ ., iris) + + orb_obj <- orbital(lr_fit, type = c("class", "prob")) + + preds <- predict(orb_obj, iris) + exps <- dplyr::bind_cols( + predict(lr_fit, iris, type = c("class")), + predict(lr_fit, iris, type = c("prob")) + ) + + expect_named(preds, c(".pred_class", paste0(".pred_", levels(iris$Species)))) + expect_type(preds$.pred_class, "character") + expect_type(preds$.pred_setosa, "double") + expect_type(preds$.pred_versicolor, "double") + expect_type(preds$.pred_virginica, "double") + + exps <- as.data.frame(exps) + exps$.pred_class <- as.character(exps$.pred_class) + + rownames(preds) <- NULL + rownames(exps) <- NULL + + expect_equal( + preds, + exps, + tolerance = 0.0000001 + ) +}) diff --git a/vignettes/supported-models.Rmd b/vignettes/supported-models.Rmd index 2f33511..f53e31d 100644 --- a/vignettes/supported-models.Rmd +++ b/vignettes/supported-models.Rmd @@ -29,7 +29,7 @@ library(dplyr) tibble::tribble( ~parsnip, ~engine, ~numeric, ~class, ~prob, - "`boost_tree()`", "`\"xgboost\"`", "✅", "⚪", "⚪", + "`boost_tree()`", "`\"xgboost\"`", "✅", "✅", "✅", "`cubist_rules()`", "`\"Cubist\"`", "✅", "❌", "❌", "`decision_tree()`", "`\"partykit\"`", "✅", "⚪", "⚪", "`linear_reg()`", "`\"lm\"`", "✅", "❌", "❌", From 1a7f1e6ebaac03542a6d3b7b14ec796cddf5c06e Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 16 Dec 2024 16:25:24 -0800 Subject: [PATCH 3/4] refactor xgboost --- R/model-xgboost.R | 133 +++++++++++++++++++++++++--------------------- 1 file changed, 73 insertions(+), 60 deletions(-) diff --git a/R/model-xgboost.R b/R/model-xgboost.R index 2046513..b8c51d6 100644 --- a/R/model-xgboost.R +++ b/R/model-xgboost.R @@ -10,73 +10,86 @@ orbital.xgb.Booster <- function( if (mode == "classification") { objective <- x$params$objective - # match arg objective - - if (objective == "multi:softprob") { - trees <- tidypredict::.extract_xgb_trees(x) - - trees_split <- split(trees, rep(seq_along(lvl), x$niter)) - trees_split <- vapply(trees_split, paste, character(1), collapse = " + ") - - trees_split <- gsub("dplyr::case_when", "case_when", trees_split) - trees_split <- gsub("case_when", "dplyr::case_when", trees_split) - - res <- stats::setNames(trees_split, lvl) - - if (is.null(type)) { - type <- "class" - } - - if ("class" %in% type) { - res <- c( - res, - ".pred_class" = softmax(lvl) - ) - } - if ("prob" %in% type) { - res <- c( - res, - "norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "), - stats::setNames(glue::glue("exp({lvl}) / norm"), NA) - ) - } - } else if (objective == "binary:logistic") { - eq <- tidypredict::tidypredict_fit(x) - - eq <- deparse1(eq) - - eq <- gsub("dplyr::case_when", "case_when", eq) - eq <- gsub("case_when", "dplyr::case_when", eq) - - if (is.null(type)) { - type <- "class" - } - - res <- NULL - if ("class" %in% type) { - levels <- glue::double_quote(lvl) - - res <- c( - res, - .pred_class = glue::glue( - "dplyr::case_when({eq} > 0.5 ~ {levels[1]}, .default = {levels[2]})" - ) - ) - } - if ("prob" %in% type) { - res <- c( - res, - glue::glue("{eq}"), - glue::glue("1 - ({eq})") - ) - } + objective <- rlang::arg_match0( + objective, + c("multi:softprob", "binary:logistic") + ) + + if (is.null(type)) { + type <- "class" } + + extractor <- switch( + objective, + "multi:softprob" = xgboost_multisoft, + "binary:logistic" = xgboost_logistic + ) + + res <- extractor(x, type, lvl) } else if (mode == "regression") { res <- tidypredict::tidypredict_fit(x) } res } +xgboost_multisoft <- function(x, type, lvl) { + trees <- tidypredict::.extract_xgb_trees(x) + + trees_split <- split(trees, rep(seq_along(lvl), x$niter)) + trees_split <- vapply(trees_split, paste, character(1), collapse = " + ") + trees_split <- namespace_case_when(trees_split) + + res <- stats::setNames(trees_split, lvl) + + if ("class" %in% type) { + res <- c( + res, + ".pred_class" = softmax(lvl) + ) + } + if ("prob" %in% type) { + res <- c( + res, + "norm" = glue::glue_collapse(glue::glue("exp({lvl})"), sep = " + "), + stats::setNames(glue::glue("exp({lvl}) / norm"), NA) + ) + } + res +} + +xgboost_logistic <- function(x, type, lvl) { + eq <- tidypredict::tidypredict_fit(x) + + eq <- deparse1(eq) + eq <- namespace_case_when(eq) + + res <- NULL + if ("class" %in% type) { + levels <- glue::double_quote(lvl) + + res <- c( + res, + .pred_class = glue::glue( + "dplyr::case_when({eq} > 0.5 ~ {levels[1]}, .default = {levels[2]})" + ) + ) + } + if ("prob" %in% type) { + res <- c( + res, + glue::glue("{eq}"), + glue::glue("1 - ({eq})") + ) + } + res +} + +namespace_case_when <- function(x) { + x <- gsub("dplyr::case_when", "case_when", x) + x <- gsub("case_when", "dplyr::case_when", x) + x +} + softmax <- function(lvl) { res <- character(0) From df167a6841605f7427a6552663ed3920962cd922 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 16 Dec 2024 16:25:34 -0800 Subject: [PATCH 4/4] rename objects in test xgboost --- tests/testthat/test-model-xgboost.R | 52 ++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/testthat/test-model-xgboost.R b/tests/testthat/test-model-xgboost.R index a0e0dff..4537a36 100644 --- a/tests/testthat/test-model-xgboost.R +++ b/tests/testthat/test-model-xgboost.R @@ -5,14 +5,14 @@ test_that("boost_tree(), objective = binary:logistic, works with type = class", mtcars$vs <- factor(mtcars$vs) - lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") - lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars) + bt_fit <- parsnip::fit(bt_spec, vs ~ disp + mpg + hp, mtcars) - orb_obj <- orbital(lr_fit, type = "class") + orb_obj <- orbital(bt_fit, type = "class") preds <- predict(orb_obj, mtcars) - exps <- predict(lr_fit, mtcars) + exps <- predict(bt_fit, mtcars) expect_named(preds, ".pred_class") expect_type(preds$.pred_class, "character") @@ -28,14 +28,14 @@ test_that("boost_tree(), objective = binary:logistic, works with type = class", skip_if_not_installed("tidypredict") skip_if_not_installed("xgboost") - lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") - lr_fit <- parsnip::fit(lr_spec, Species ~ ., iris) + bt_fit <- parsnip::fit(bt_spec, Species ~ ., iris) - orb_obj <- orbital(lr_fit, type = "class") + orb_obj <- orbital(bt_fit, type = "class") preds <- predict(orb_obj, iris) - exps <- predict(lr_fit, iris) + exps <- predict(bt_fit, iris) expect_named(preds, ".pred_class") expect_type(preds$.pred_class, "character") @@ -53,14 +53,14 @@ test_that("boost_tree(), objective = binary:logistic, works with type = prob", { mtcars$vs <- factor(mtcars$vs) - lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") - lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars) + bt_fit <- parsnip::fit(bt_spec, vs ~ disp + mpg + hp, mtcars) - orb_obj <- orbital(lr_fit, type = "prob") + orb_obj <- orbital(bt_fit, type = "prob") preds <- predict(orb_obj, mtcars) - exps <- predict(lr_fit, mtcars, type = "prob") + exps <- predict(bt_fit, mtcars, type = "prob") expect_named(preds, c(".pred_0", ".pred_1")) expect_type(preds$.pred_0, "double") @@ -83,14 +83,14 @@ test_that("boost_tree(), objective = binary:logistic, works with type = prob", { skip_if_not_installed("tidypredict") skip_if_not_installed("xgboost") - lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") - lr_fit <- parsnip::fit(lr_spec, Species ~ ., iris) + bt_fit <- parsnip::fit(bt_spec, Species ~ ., iris) - orb_obj <- orbital(lr_fit, type = "prob") + orb_obj <- orbital(bt_fit, type = "prob") preds <- predict(orb_obj, iris) - exps <- predict(lr_fit, iris, type = "prob") + exps <- predict(bt_fit, iris, type = "prob") expect_named(preds, paste0(".pred_", levels(iris$Species))) expect_type(preds$.pred_setosa, "double") @@ -116,16 +116,16 @@ test_that("boost_tree(), objective = binary:logistic, works with type = c(class, mtcars$vs <- factor(mtcars$vs) - lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") - lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars) + bt_fit <- parsnip::fit(bt_spec, vs ~ disp + mpg + hp, mtcars) - orb_obj <- orbital(lr_fit, type = c("class", "prob")) + orb_obj <- orbital(bt_fit, type = c("class", "prob")) preds <- predict(orb_obj, mtcars) exps <- dplyr::bind_cols( - predict(lr_fit, mtcars, type = c("class")), - predict(lr_fit, mtcars, type = c("prob")) + predict(bt_fit, mtcars, type = c("class")), + predict(bt_fit, mtcars, type = c("prob")) ) expect_named(preds, c(".pred_class", ".pred_0", ".pred_1")) @@ -151,16 +151,16 @@ test_that("boost_tree(), objective = binary:logistic, works with type = c(class, skip_if_not_installed("tidypredict") skip_if_not_installed("xgboost") - lr_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") + bt_spec <- parsnip::boost_tree(mode = "classification", engine = "xgboost") - lr_fit <- parsnip::fit(lr_spec, Species ~ ., iris) + bt_fit <- parsnip::fit(bt_spec, Species ~ ., iris) - orb_obj <- orbital(lr_fit, type = c("class", "prob")) + orb_obj <- orbital(bt_fit, type = c("class", "prob")) preds <- predict(orb_obj, iris) exps <- dplyr::bind_cols( - predict(lr_fit, iris, type = c("class")), - predict(lr_fit, iris, type = c("prob")) + predict(bt_fit, iris, type = c("class")), + predict(bt_fit, iris, type = c("prob")) ) expect_named(preds, c(".pred_class", paste0(".pred_", levels(iris$Species))))