Skip to content

Commit

Permalink
Merge pull request #70 from tidymodels/pred_vars-recipes
Browse files Browse the repository at this point in the history
fix bug orbital.recipe couldn't handle multiple eqs
  • Loading branch information
EmilHvitfeldt authored Dec 14, 2024
2 parents e652446 + 45543e3 commit a0786ec
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
7 changes: 5 additions & 2 deletions R/recipes.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@ orbital.recipe <- function(x, eqs = NULL, ..., prefix = ".pred") {
terms <- x$term_info
all_vars <- terms$variable[terms$role == "predictor"]
} else {
all_vars <- all.vars(rlang::parse_expr(eqs))
all_vars <- rlang::parse_exprs(eqs)
all_vars <- lapply(all_vars, all.vars)
all_vars <- unlist(all_vars, use.names = FALSE)
all_vars <- unique(all_vars)
}

n_steps <- length(x$steps)

if (is.null(eqs)) {
out <- c()
} else {
out <- stats::setNames(unname(eqs), prefix)
out <- unclass(eqs)
}

for (step in rev(x$steps)) {
Expand Down
42 changes: 42 additions & 0 deletions tests/testthat/test-workflows.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,45 @@ test_that("type argument checking works", {
orbital(wf_fit, type = c("class", "numeric"))
)
})

test_that("pred_names) works with type = c(class, prob) and recipes", {
skip_if_not_installed("parsnip")
skip_if_not_installed("workflows")
skip_if_not_installed("recipes")
skip_if_not_installed("tidypredict")

mtcars$vs <- factor(mtcars$vs)

lr_spec <- parsnip::logistic_reg()

rec_spec <- recipes::recipe(vs ~ disp + mpg + hp, mtcars) %>%
recipes::step_center(disp, mpg, hp)

wf_spec <- workflows::workflow(rec_spec, lr_spec)

wf_fit <- parsnip::fit(wf_spec, mtcars)

orb_obj <- orbital(wf_fit, type = c("class", "prob"))

preds <- predict(orb_obj, mtcars)
exps <- dplyr::bind_cols(
predict(wf_fit, mtcars, type = c("class")),
predict(wf_fit, mtcars, type = c("prob"))
)

expect_named(preds, c(".pred_class", ".pred_0", ".pred_1"))
expect_type(preds$.pred_class, "character")
expect_type(preds$.pred_0, "double")
expect_type(preds$.pred_1, "double")

exps <- as.data.frame(exps)
exps$.pred_class <- as.character(exps$.pred_class)

rownames(preds) <- NULL
rownames(exps) <- NULL

expect_equal(
preds,
exps
)
})

0 comments on commit a0786ec

Please sign in to comment.