Skip to content

Commit

Permalink
Merge pull request #69 from tidymodels/type-argument
Browse files Browse the repository at this point in the history
add type argument
  • Loading branch information
EmilHvitfeldt authored Dec 14, 2024
2 parents 2a2d50f + b2fb086 commit e652446
Show file tree
Hide file tree
Showing 14 changed files with 375 additions and 22 deletions.
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

* `orbital()` gained `prefix` argument to allow for renaming of prediction columns. (#59)

* `orbital()` now works with `logistic_reg()` models. (#62)
* `orbital()` now works with `logistic_reg()` models for class prediction and probability predictions. (#62, #66)

* `orbital()` has gained `type` argument to change prediction type. (#66)

# orbital 0.2.0

Expand Down
37 changes: 30 additions & 7 deletions R/model-glm.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,39 @@
#' @export
orbital.glm <- function(x, ..., mode = c("classification", "regression")) {
orbital.glm <- function(
x,
...,
mode = c("classification", "regression"),
type = NULL
) {
mode <- rlang::arg_match(mode)

if (mode == "classification") {
outcome <- names(attr(x$terms, "dataClasses"))[attr(x$terms, "response")]
outcome <- names(attr(x$terms, "dataClasses"))[attr(x$terms, "response")]
levels <- levels(x$data[[outcome]])
levels <- glue::double_quote(levels)
res <- tidypredict::tidypredict_fit(x)
res <- deparse1(res)
res <- glue::glue(
"dplyr::case_when({res} < 0.5 ~ {levels[1]}, .default = {levels[2]})"
)
eq <- tidypredict::tidypredict_fit(x)
eq <- deparse1(eq)

if (is.null(type)) {
type <- "class"
}

res <- NULL
if ("class" %in% type) {
res <- c(
res,
glue::glue(
"dplyr::case_when({eq} < 0.5 ~ {levels[1]}, .default = {levels[2]})"
)
)
}
if ("prob" %in% type) {
res <- c(
res,
glue::glue("1 - ({eq})"),
glue::glue("{eq}")
)
}
} else if (mode == "regression") {
res <- tidypredict::tidypredict_fit(x)
}
Expand Down
7 changes: 6 additions & 1 deletion R/orbital.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
#' If `x` produces a prediction, tidymodels standards dictate that the
#' predictions will start with `.pred`. This is not a valid name for
#' some data bases.
#' @param type A vector of strings, specifies the prediction type. Regression
#' models allow for `"numeric"` and classification models allow for `"class"`
#' and `"prob"`. Multiple values are allowed to produce hard and soft
#' predictions for classification models. Defaults to `NULL` which defaults to
#' `"numeric"` for regression models and `"class"` for classification models.
#'
#' @returns An [orbital] object.
#'
Expand Down Expand Up @@ -58,7 +63,7 @@
#' orbital()
#'
#' @export
orbital <- function(x, ..., prefix = ".pred") {
orbital <- function(x, ..., prefix = ".pred", type = NULL) {
UseMethod("orbital")
}

Expand Down
49 changes: 44 additions & 5 deletions R/parsnip.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#' @export
orbital.model_fit <- function(x, ..., prefix = ".pred") {
orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) {
mode <- x$spec$mode

check_mode(mode)
check_type(type, mode)

res <- try(orbital(x$fit, mode = mode), silent = TRUE)
res <- try(orbital(x$fit, mode = mode, type = type), silent = TRUE)

if (inherits(res, "try-error")) {
res <- tryCatch(
Expand All @@ -26,14 +26,29 @@ orbital.model_fit <- function(x, ..., prefix = ".pred") {
}

if (mode == "classification") {
prefix <- paste0(prefix, "_class")
names <- NULL

if (is.null(type)) {
type <- "class"
}

if ("class" %in% type) {
names <- c(names, paste0(prefix, "_class"))
}
if ("prob" %in% type) {
names <- c(names, paste0(prefix, "_", x$lvl))
}
}
if (mode == "regression") {
names <- prefix
}

if (is.language(res)) {
res <- deparse1(res)
}

res <- stats::setNames(res, prefix)
attr(res, "pred_names") <- names
res <- stats::setNames(res, names)

new_orbital_class(res)
}
Expand All @@ -54,3 +69,27 @@ check_mode <- function(mode, call = rlang::caller_env()) {
)
}
}

check_type <- function(type, mode, call = rlang::caller_env()) {
if (is.null(type)) {
return(invisible())
}

supported_types <- c("numeric", "class", "prob")
rlang::arg_match(type, supported_types, multiple = TRUE, error_call = call)

if (mode == "regression" && any(!type %in% "numeric")) {
cli::cli_abort(
"{.arg type} can only be {.val numeric} for model with mode
{.val regression}, not {.val {type}}.",
call = call
)
}
if (mode == "classification" && any(!type %in% c("class", "prob"))) {
cli::cli_abort(
"{.arg type} can only be {.val class} or {.val prob} for model with mode
{.val classification}, not {.val {type}}.",
call = call
)
}
}
6 changes: 4 additions & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ predict.orbital_class <- function(object, new_data, ...) {
rlang::check_dots_empty()
res <- dplyr::mutate(new_data, !!!orbital_inline(object))

pred_name <- names(object)[length(object)]
res <- dplyr::select(res, dplyr::any_of(pred_name))
pred_name <- attr(object, "pred_names")
if (!is.null(pred_name)) {
res <- dplyr::select(res, dplyr::any_of(pred_name))
}

res
}
6 changes: 4 additions & 2 deletions R/workflows.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' @export
orbital.workflow <- function(x, ..., prefix = ".pred") {
orbital.workflow <- function(x, ..., prefix = ".pred", type = NULL) {
if (!workflows::is_trained_workflow(x)) {
cli::cli_abort("{.arg x} must be a fully trained {.cls workflow}.")
}
Expand All @@ -9,14 +9,16 @@ orbital.workflow <- function(x, ..., prefix = ".pred") {
}

model_fit <- workflows::extract_fit_parsnip(x)
out <- orbital(model_fit, prefix = prefix)
out <- orbital(model_fit, prefix = prefix, type = type)

preprocessor <- workflows::extract_preprocessor(x)

if (inherits(preprocessor, "recipe")) {
recipe_fit <- workflows::extract_recipe(x)

pred_names <- attr(out, "pred_names")
out <- orbital(recipe_fit, out, prefix = prefix)
attr(out, "pred_names") <- pred_names
}

new_orbital_class(out)
Expand Down
8 changes: 7 additions & 1 deletion man/orbital.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions tests/testthat/_snaps/parsnip.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,51 @@
Error in `orbital()`:
! Only models with modes "regression" and "classification" are supported. Not "invalid mode".

# type argument checking works

Code
orbital(lm_fit, type = "invalid")
Condition
Error in `orbital()`:
! `type` must be one of "numeric", "class", or "prob", not "invalid".

---

Code
orbital(lm_fit, type = "class")
Condition
Error in `orbital()`:
! `type` can only be "numeric" for model with mode "regression", not "class".

---

Code
orbital(lm_fit, type = c("class", "numeric"))
Condition
Error in `orbital()`:
! `type` can only be "numeric" for model with mode "regression", not "class" and "numeric".

---

Code
orbital(lm_fit, type = "invalid")
Condition
Error in `orbital()`:
! `type` must be one of "numeric", "class", or "prob", not "invalid".

---

Code
orbital(lm_fit, type = "numeric")
Condition
Error in `orbital()`:
! `type` can only be "class" or "prob" for model with mode "classification", not "numeric".

---

Code
orbital(lm_fit, type = c("class", "numeric"))
Condition
Error in `orbital()`:
! `type` can only be "class" or "prob" for model with mode "classification", not "class" and "numeric".

48 changes: 48 additions & 0 deletions tests/testthat/_snaps/workflows.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# type argument checking works

Code
orbital(wf_fit, type = "invalid")
Condition
Error in `orbital()`:
! `type` must be one of "numeric", "class", or "prob", not "invalid".

---

Code
orbital(wf_fit, type = "class")
Condition
Error in `orbital()`:
! `type` can only be "numeric" for model with mode "regression", not "class".

---

Code
orbital(wf_fit, type = c("class", "numeric"))
Condition
Error in `orbital()`:
! `type` can only be "numeric" for model with mode "regression", not "class" and "numeric".

---

Code
orbital(wf_fit, type = "invalid")
Condition
Error in `orbital()`:
! `type` must be one of "numeric", "class", or "prob", not "invalid".

---

Code
orbital(wf_fit, type = "numeric")
Condition
Error in `orbital()`:
! `type` can only be "class" or "prob" for model with mode "classification", not "numeric".

---

Code
orbital(wf_fit, type = c("class", "numeric"))
Condition
Error in `orbital()`:
! `type` can only be "class" or "prob" for model with mode "classification", not "class" and "numeric".

3 changes: 3 additions & 0 deletions tests/testthat/test-json.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,8 @@ test_that("read and write json works", {

new <- orbital_json_read(tmp_file)

# temp fix
attr(orbital_obj, "pred_names") <- NULL

expect_identical(new, orbital_obj)
})
Loading

0 comments on commit e652446

Please sign in to comment.