diff --git a/R/recipes.R b/R/recipes.R index e835589..79f52ce 100644 --- a/R/recipes.R +++ b/R/recipes.R @@ -9,7 +9,10 @@ orbital.recipe <- function(x, eqs = NULL, ..., prefix = ".pred") { terms <- x$term_info all_vars <- terms$variable[terms$role == "predictor"] } else { - all_vars <- all.vars(rlang::parse_expr(eqs)) + all_vars <- rlang::parse_exprs(eqs) + all_vars <- lapply(all_vars, all.vars) + all_vars <- unlist(all_vars, use.names = FALSE) + all_vars <- unique(all_vars) } n_steps <- length(x$steps) @@ -17,7 +20,7 @@ orbital.recipe <- function(x, eqs = NULL, ..., prefix = ".pred") { if (is.null(eqs)) { out <- c() } else { - out <- stats::setNames(unname(eqs), prefix) + out <- unclass(eqs) } for (step in rev(x$steps)) { diff --git a/tests/testthat/test-workflows.R b/tests/testthat/test-workflows.R index 44ef2e3..4924d00 100644 --- a/tests/testthat/test-workflows.R +++ b/tests/testthat/test-workflows.R @@ -57,3 +57,45 @@ test_that("type argument checking works", { orbital(wf_fit, type = c("class", "numeric")) ) }) + +test_that("pred_names) works with type = c(class, prob) and recipes", { + skip_if_not_installed("parsnip") + skip_if_not_installed("workflows") + skip_if_not_installed("recipes") + skip_if_not_installed("tidypredict") + + mtcars$vs <- factor(mtcars$vs) + + lr_spec <- parsnip::logistic_reg() + + rec_spec <- recipes::recipe(vs ~ disp + mpg + hp, mtcars) %>% + recipes::step_center(disp, mpg, hp) + + wf_spec <- workflows::workflow(rec_spec, lr_spec) + + wf_fit <- parsnip::fit(wf_spec, mtcars) + + orb_obj <- orbital(wf_fit, type = c("class", "prob")) + + preds <- predict(orb_obj, mtcars) + exps <- dplyr::bind_cols( + predict(wf_fit, mtcars, type = c("class")), + predict(wf_fit, mtcars, type = c("prob")) + ) + + expect_named(preds, c(".pred_class", ".pred_0", ".pred_1")) + expect_type(preds$.pred_class, "character") + expect_type(preds$.pred_0, "double") + expect_type(preds$.pred_1, "double") + + exps <- as.data.frame(exps) + exps$.pred_class <- as.character(exps$.pred_class) + + rownames(preds) <- NULL + rownames(exps) <- NULL + + expect_equal( + preds, + exps + ) +})