Skip to content

Commit

Permalink
refactor out default_type()
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Dec 17, 2024
1 parent f7cf5cd commit 76db153
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
5 changes: 1 addition & 4 deletions R/model-glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ orbital.glm <- function(
type = NULL
) {
mode <- rlang::arg_match(mode)
type <- default_type(type)

if (mode == "classification") {
outcome <- names(attr(x$terms, "dataClasses"))[attr(x$terms, "response")]
Expand All @@ -14,10 +15,6 @@ orbital.glm <- function(
eq <- tidypredict::tidypredict_fit(x)
eq <- deparse1(eq)

if (is.null(type)) {
type <- "class"
}

res <- NULL
if ("class" %in% type) {
res <- c(
Expand Down
5 changes: 1 addition & 4 deletions R/model-xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ orbital.xgb.Booster <- function(
lvl = NULL
) {
mode <- rlang::arg_match(mode)
type <- default_type(type)

if (mode == "classification") {
objective <- x$params$objective
Expand All @@ -15,10 +16,6 @@ orbital.xgb.Booster <- function(
c("multi:softprob", "binary:logistic")
)

if (is.null(type)) {
type <- "class"
}

extractor <- switch(
objective,
"multi:softprob" = xgboost_multisoft,
Expand Down
13 changes: 9 additions & 4 deletions R/parsnip.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) {
mode <- x$spec$mode
check_mode(mode)
check_type(type, mode)
type <- default_type(type)

res <- try(
orbital(x$fit, mode = mode, type = type, lvl = x$lvl),
Expand Down Expand Up @@ -31,10 +32,6 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) {
if (mode == "classification") {
names <- NULL

if (is.null(type)) {
type <- "class"
}

if ("class" %in% type) {
names <- c(names, paste0(prefix, "_class"))
}
Expand Down Expand Up @@ -108,3 +105,11 @@ check_type <- function(type, mode, call = rlang::caller_env()) {
)
}
}

default_type <- function(type) {
if (is.null(type)) {
type <- "class"
}

type
}

0 comments on commit 76db153

Please sign in to comment.