Skip to content

Commit

Permalink
respect predict(type) with a postprocessor
Browse files Browse the repository at this point in the history
Closes #251, closes #234. Another route we could have taken here is to take a `type` in `predict.tailor()`, but this would lead to a different type needing to be supplied to `predict.model_fit()` than `predict.tailor()` in interactive usage if the user is only interested in hard class predictions.
  • Loading branch information
simonpcouch committed Jan 9, 2025
1 parent f070c64 commit be22aac
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
24 changes: 20 additions & 4 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,33 @@ predict.workflow <- function(object, new_data, type = NULL, opts = list(), ...)
return(predict(fit, new_data, type = type, opts = opts, ...))
}

# use `augment()` rather than `fit()` to get all possible prediction `type`s.
# likely, we actually want tailor to check for the existence of needed
# columns at predict time and just use `predict()` output here.
# use `augment()` rather than `fit()` to get all possible prediction `type`s (#234).
fit_aug <- augment(fit, new_data, opts = opts, ...)

post <- extract_postprocessor(workflow)
predict(post, fit_aug)[post$columns$estimate]
predict(post, fit_aug)[predict_type_column_names(type, post$columns)]
}

forge_predictors <- function(new_data, workflow) {
mold <- extract_mold(workflow)
forged <- hardhat::forge(new_data, blueprint = mold$blueprint)
forged$predictors
}

predict_type_column_names <- function(type, tailor_columns, call = caller_env()) {
check_string(type, allow_null = TRUE, call = call)

if (is.null(type)) {
return(tailor_columns$estimate)
}

switch(
type,
numeric = , class = tailor_columns$estimate,
prob = tailor_columns$probabilities,
cli::cli_abort(
"Unsupported prediction {.arg type} {.val {type}} for a workflow with a postprocessor.",
call = call
)
)
}
8 changes: 8 additions & 0 deletions tests/testthat/_snaps/predict.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@
! Can't predict on an untrained workflow.
i Do you need to call `fit()`?

# predict(type) is respected with a postprocessor (#251)

Code
predict(wflow_fit, d[1:5, ], type = "boop")
Condition
Error in `predict()`:
! Unsupported prediction `type` "boop" for a workflow with a postprocessor.

18 changes: 18 additions & 0 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,21 @@ test_that("monitoring: no double intercept due to dot expansion in model formula
# so lm()'s predict method won't error anymore here. (tidymodels/parsnip#1033)
expect_no_error(predict(fit_with_intercept, mtcars))
})

test_that("predict(type) is respected with a postprocessor (#251)", {
# create example data
y <- seq(0, 7, .1)
d <- data.frame(y = as.factor(ifelse(y > 3.5, "yes", "no")), x = y + (y-3)^2)
wflow <- workflow(y ~ ., parsnip::logistic_reg(), tailor::tailor())
wflow_fit <- fit(wflow, d)

pred_class <- predict(wflow_fit, d[1:5,], type = "class")
pred_prob <- predict(wflow_fit, d[1:5,], type = "prob")
pred_null <- predict(wflow_fit, d[1:5,])

expect_named(pred_class, ".pred_class")
expect_named(pred_prob, c(".pred_no", ".pred_yes"), ignore.order = TRUE)
expect_equal(pred_class, pred_null)

expect_snapshot(error = TRUE, predict(wflow_fit, d[1:5,], type = "boop"))
})

0 comments on commit be22aac

Please sign in to comment.