Skip to content

Commit

Permalink
changes for #972
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Dec 3, 2024
1 parent f6772f4 commit 22ea0ba
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 2 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ export(extract_workflow)
export(filter_parameters)
export(finalize_model)
export(finalize_recipe)
export(finalize_tailor)
export(finalize_workflow)
export(finalize_workflow_preprocessor)
export(first_eval_time)
Expand Down
38 changes: 38 additions & 0 deletions R/finalize.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,44 @@ finalize_workflow <- function(x, parameters) {
x <- set_workflow_recipe(x, rec)
}

if (has_postprocessor(x)) {
tailor <- extract_postprocessor(x)
tailor <- finalize_tailor(tailor, parameters)
x <- set_workflow_tailor(x, tailor)
}

x
}

#' @export
#' @rdname finalize_model
finalize_tailor <- function(x, parameters) {

if (!inherits(x, "tailor")) {
cli::cli_abort("{.arg x} should be a tailor, not {.obj_type_friendly {x}}.")
}
check_final_param(parameters)
pset <-
hardhat::extract_parameter_set_dials(x) %>%
dplyr::filter(id %in% names(parameters) & source == "tailor")

if (tibble::is_tibble(parameters)) {
parameters <- as.list(parameters)
}

parameters <- parameters[names(parameters) %in% pset$id]
parameters <- parameters[pset$id]

for (i in seq_along(x$adjustments)) {
adj <- x$adjustments[[i]]
adj_comps <- purrr::map_lgl(pset$component, ~ inherits(adj, .x))
if (any(adj_comps)) {
adj_ids <- pset$id[adj_comps]
adj_prms <- parameters[which(names(parameters) %in% adj_ids)]
adj$arguments <- purrr::list_modify(adj$arguments, !!!adj_prms)
x$adjustments[[i]] <- adj
}
}
x
}

Expand Down
9 changes: 7 additions & 2 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ predict_model <- function(new_data, orig_rows, workflow, grid, metrics,
msg <-
c(
msg,
i =
i =
"Consider using {.code skip = TRUE} on any recipe steps that
remove rows to avoid calling them on the assessment set."

)
} else {
msg <- c(msg, i = "Did your preprocessing steps filter or remove rows?")
Expand Down Expand Up @@ -464,3 +464,8 @@ set_workflow_recipe <- function(workflow, recipe) {
workflow$pre$actions$recipe$recipe <- recipe
workflow
}

set_workflow_tailor <- function(workflow, tailor) {
workflow$post$actions$tailor$tailor <- tailor
workflow
}
3 changes: 3 additions & 0 deletions man/finalize_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

176 changes: 176 additions & 0 deletions tests/testthat/_snaps/finalization.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,179 @@
! Some model parameters require finalization but there are recipe parameters that require tuning.
i Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument.

# finalize tailors

Code
print(adj_1)
Message
-- tailor ----------------------------------------------------------------------
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [2, ?].

---

Code
print(adj_2)
Message
-- tailor ----------------------------------------------------------------------
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [2, 3].

---

Code
print(adj_3)
Message
-- tailor ----------------------------------------------------------------------
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [2, 3].

---

Code
print(adj_4)
Message
-- tailor ----------------------------------------------------------------------
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [?, ?].

---

Code
finalize_tailor(linear_reg(), tibble())
Condition
Error in `finalize_tailor()`:
! `x` should be a tailor, not a <linear_reg> object.

# finalize workflows with tailors

Code
print(wflow_1)
Output
== Workflow ====================================================================
Preprocessor: Formula
Model: linear_reg()
Postprocessor: tailor
-- Preprocessor ----------------------------------------------------------------
y ~ .
-- Model -----------------------------------------------------------------------
Linear Regression Model Specification (regression)
Computational engine: lm
-- Postprocessor ---------------------------------------------------------------
Message
-- tailor ----------------------------------------------------------------------
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [2, ?].
Output
NA
NA
NA

---

Code
print(wflow_2)
Output
== Workflow ====================================================================
Preprocessor: Formula
Model: linear_reg()
Postprocessor: tailor
-- Preprocessor ----------------------------------------------------------------
y ~ .
-- Model -----------------------------------------------------------------------
Linear Regression Model Specification (regression)
Computational engine: lm
-- Postprocessor ---------------------------------------------------------------
Message
-- tailor ----------------------------------------------------------------------
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [2, 3].
Output
NA
NA
NA

---

Code
print(wflow_3)
Output
== Workflow ====================================================================
Preprocessor: Formula
Model: linear_reg()
Postprocessor: tailor
-- Preprocessor ----------------------------------------------------------------
y ~ .
-- Model -----------------------------------------------------------------------
Linear Regression Model Specification (regression)
Computational engine: lm
-- Postprocessor ---------------------------------------------------------------
Message
-- tailor ----------------------------------------------------------------------
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [2, 3].
Output
NA
NA
NA

---

Code
print(wflow_4)
Output
== Workflow ====================================================================
Preprocessor: Formula
Model: linear_reg()
Postprocessor: tailor
-- Preprocessor ----------------------------------------------------------------
y ~ .
-- Model -----------------------------------------------------------------------
Linear Regression Model Specification (regression)
Computational engine: lm
-- Postprocessor ---------------------------------------------------------------
Message
-- tailor ----------------------------------------------------------------------
A regression postprocessor with 1 adjustment:
* Constrain numeric predictions to be between [?, ?].
Output
NA
NA
NA

51 changes: 51 additions & 0 deletions tests/testthat/test-finalization.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,54 @@ test_that("finalize recipe step with multiple tune parameters", {
expect_equal(finalize_recipe(rec, best)$steps[[1]]$degree, 1)
expect_equal(finalize_recipe(rec, best)$steps[[1]]$deg_free, 2)
})

# ------------------------------------------------------------------------------
# post-processing

test_that("finalize tailors", {
library(tailor)

adjust_rng <-
tailor() %>%
adjust_numeric_range(lower_limit = tune(), upper_limit = tune())

adj_1 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2))
expect_snapshot(print(adj_1))

adj_2 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3))
expect_snapshot(print(adj_2))

adj_3 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3, a = 2))
expect_snapshot(print(adj_3))

adj_4 <- finalize_tailor(adjust_rng, tibble())
expect_snapshot(print(adj_4))

expect_snapshot(
finalize_tailor(linear_reg(), tibble()),
error = TRUE
)

})

test_that("finalize workflows with tailors", {
library(tailor)

adjust_rng <-
tailor() %>%
adjust_numeric_range(lower_limit = tune(), upper_limit = tune())
wflow <- workflow(y ~ ., linear_reg(), adjust_rng)

wflow_1 <- finalize_workflow(wflow, tibble(lower_limit = 2))
expect_snapshot(print(wflow_1))

wflow_2 <- finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3))
expect_snapshot(print(wflow_2))

wflow_3 <- finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3, a = 2))
expect_snapshot(print(wflow_3))

wflow_4 <- finalize_workflow(wflow, tibble())
expect_snapshot(print(wflow_4))

})

0 comments on commit 22ea0ba

Please sign in to comment.