From 22ea0ba0cae27873789c4dd24bb1c8daed9a09f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= Date: Tue, 3 Dec 2024 09:29:43 -0500 Subject: [PATCH] changes for #972 --- NAMESPACE | 1 + R/finalize.R | 38 ++++++ R/grid_helpers.R | 9 +- man/finalize_model.Rd | 3 + tests/testthat/_snaps/finalization.md | 176 ++++++++++++++++++++++++++ tests/testthat/test-finalization.R | 51 ++++++++ 6 files changed, 276 insertions(+), 2 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 2adf3835f..2bb03b083 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -193,6 +193,7 @@ export(extract_workflow) export(filter_parameters) export(finalize_model) export(finalize_recipe) +export(finalize_tailor) export(finalize_workflow) export(finalize_workflow_preprocessor) export(first_eval_time) diff --git a/R/finalize.R b/R/finalize.R index a0c365b41..fef25bb1d 100644 --- a/R/finalize.R +++ b/R/finalize.R @@ -95,6 +95,44 @@ finalize_workflow <- function(x, parameters) { x <- set_workflow_recipe(x, rec) } + if (has_postprocessor(x)) { + tailor <- extract_postprocessor(x) + tailor <- finalize_tailor(tailor, parameters) + x <- set_workflow_tailor(x, tailor) + } + + x +} + +#' @export +#' @rdname finalize_model +finalize_tailor <- function(x, parameters) { + + if (!inherits(x, "tailor")) { + cli::cli_abort("{.arg x} should be a tailor, not {.obj_type_friendly {x}}.") + } + check_final_param(parameters) + pset <- + hardhat::extract_parameter_set_dials(x) %>% + dplyr::filter(id %in% names(parameters) & source == "tailor") + + if (tibble::is_tibble(parameters)) { + parameters <- as.list(parameters) + } + + parameters <- parameters[names(parameters) %in% pset$id] + parameters <- parameters[pset$id] + + for (i in seq_along(x$adjustments)) { + adj <- x$adjustments[[i]] + adj_comps <- purrr::map_lgl(pset$component, ~ inherits(adj, .x)) + if (any(adj_comps)) { + adj_ids <- pset$id[adj_comps] + adj_prms <- parameters[which(names(parameters) %in% adj_ids)] + adj$arguments <- 
purrr::list_modify(adj$arguments, !!!adj_prms) + x$adjustments[[i]] <- adj + } + } x } diff --git a/R/grid_helpers.R b/R/grid_helpers.R index 91b6131e1..3f6ad66d2 100644 --- a/R/grid_helpers.R +++ b/R/grid_helpers.R @@ -21,10 +21,10 @@ predict_model <- function(new_data, orig_rows, workflow, grid, metrics, msg <- c( msg, - i = + i = "Consider using {.code skip = TRUE} on any recipe steps that remove rows to avoid calling them on the assessment set." - + ) } else { msg <- c(msg, i = "Did your preprocessing steps filter or remove rows?") @@ -464,3 +464,8 @@ set_workflow_recipe <- function(workflow, recipe) { workflow$pre$actions$recipe$recipe <- recipe workflow } + +set_workflow_tailor <- function(workflow, tailor) { + workflow$post$actions$tailor$tailor <- tailor + workflow +} diff --git a/man/finalize_model.Rd b/man/finalize_model.Rd index caf391633..2e5f1b1b5 100644 --- a/man/finalize_model.Rd +++ b/man/finalize_model.Rd @@ -4,6 +4,7 @@ \alias{finalize_model} \alias{finalize_recipe} \alias{finalize_workflow} +\alias{finalize_tailor} \title{Splice final parameters into objects} \usage{ finalize_model(x, parameters) @@ -11,6 +12,8 @@ finalize_model(x, parameters) finalize_recipe(x, parameters) finalize_workflow(x, parameters) + +finalize_tailor(x, parameters) } \arguments{ \item{x}{A recipe, \code{parsnip} model specification, or workflow.} diff --git a/tests/testthat/_snaps/finalization.md b/tests/testthat/_snaps/finalization.md index 38b240ac6..da1ad8347 100644 --- a/tests/testthat/_snaps/finalization.md +++ b/tests/testthat/_snaps/finalization.md @@ -7,3 +7,179 @@ ! Some model parameters require finalization but there are recipe parameters that require tuning. i Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument. 
+# finalize tailors + + Code + print(adj_1) + Message + + -- tailor ---------------------------------------------------------------------- + A regression postprocessor with 1 adjustment: + + * Constrain numeric predictions to be between [2, ?]. + +--- + + Code + print(adj_2) + Message + + -- tailor ---------------------------------------------------------------------- + A regression postprocessor with 1 adjustment: + + * Constrain numeric predictions to be between [2, 3]. + +--- + + Code + print(adj_3) + Message + + -- tailor ---------------------------------------------------------------------- + A regression postprocessor with 1 adjustment: + + * Constrain numeric predictions to be between [2, 3]. + +--- + + Code + print(adj_4) + Message + + -- tailor ---------------------------------------------------------------------- + A regression postprocessor with 1 adjustment: + + * Constrain numeric predictions to be between [?, ?]. + +--- + + Code + finalize_tailor(linear_reg(), tibble()) + Condition + Error in `finalize_tailor()`: + ! `x` should be a tailor, not a <linear_reg> object. + +# finalize workflows with tailors + + Code + print(wflow_1) + Output + == Workflow ==================================================================== + Preprocessor: Formula + Model: linear_reg() + Postprocessor: tailor + + -- Preprocessor ---------------------------------------------------------------- + y ~ . + + -- Model ----------------------------------------------------------------------- + Linear Regression Model Specification (regression) + + Computational engine: lm + + + -- Postprocessor --------------------------------------------------------------- + Message + + -- tailor ---------------------------------------------------------------------- + A regression postprocessor with 1 adjustment: + + * Constrain numeric predictions to be between [2, ?]. 
+ Output + NA + NA + NA + +--- + + Code + print(wflow_2) + Output + == Workflow ==================================================================== + Preprocessor: Formula + Model: linear_reg() + Postprocessor: tailor + + -- Preprocessor ---------------------------------------------------------------- + y ~ . + + -- Model ----------------------------------------------------------------------- + Linear Regression Model Specification (regression) + + Computational engine: lm + + + -- Postprocessor --------------------------------------------------------------- + Message + + -- tailor ---------------------------------------------------------------------- + A regression postprocessor with 1 adjustment: + + * Constrain numeric predictions to be between [2, 3]. + Output + NA + NA + NA + +--- + + Code + print(wflow_3) + Output + == Workflow ==================================================================== + Preprocessor: Formula + Model: linear_reg() + Postprocessor: tailor + + -- Preprocessor ---------------------------------------------------------------- + y ~ . + + -- Model ----------------------------------------------------------------------- + Linear Regression Model Specification (regression) + + Computational engine: lm + + + -- Postprocessor --------------------------------------------------------------- + Message + + -- tailor ---------------------------------------------------------------------- + A regression postprocessor with 1 adjustment: + + * Constrain numeric predictions to be between [2, 3]. + Output + NA + NA + NA + +--- + + Code + print(wflow_4) + Output + == Workflow ==================================================================== + Preprocessor: Formula + Model: linear_reg() + Postprocessor: tailor + + -- Preprocessor ---------------------------------------------------------------- + y ~ . 
+ + -- Model ----------------------------------------------------------------------- + Linear Regression Model Specification (regression) + + Computational engine: lm + + + -- Postprocessor --------------------------------------------------------------- + Message + + -- tailor ---------------------------------------------------------------------- + A regression postprocessor with 1 adjustment: + + * Constrain numeric predictions to be between [?, ?]. + Output + NA + NA + NA + diff --git a/tests/testthat/test-finalization.R b/tests/testthat/test-finalization.R index 36f0ca5e2..8b5af9932 100644 --- a/tests/testthat/test-finalization.R +++ b/tests/testthat/test-finalization.R @@ -73,3 +73,54 @@ test_that("finalize recipe step with multiple tune parameters", { expect_equal(finalize_recipe(rec, best)$steps[[1]]$degree, 1) expect_equal(finalize_recipe(rec, best)$steps[[1]]$deg_free, 2) }) + +# ------------------------------------------------------------------------------ +# post-processing + +test_that("finalize tailors", { + library(tailor) + + adjust_rng <- + tailor() %>% + adjust_numeric_range(lower_limit = tune(), upper_limit = tune()) + + adj_1 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2)) + expect_snapshot(print(adj_1)) + + adj_2 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3)) + expect_snapshot(print(adj_2)) + + adj_3 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3, a = 2)) + expect_snapshot(print(adj_3)) + + adj_4 <- finalize_tailor(adjust_rng, tibble()) + expect_snapshot(print(adj_4)) + + expect_snapshot( + finalize_tailor(linear_reg(), tibble()), + error = TRUE + ) + +}) + +test_that("finalize workflows with tailors", { + library(tailor) + + adjust_rng <- + tailor() %>% + adjust_numeric_range(lower_limit = tune(), upper_limit = tune()) + wflow <- workflow(y ~ ., linear_reg(), adjust_rng) + + wflow_1 <- finalize_workflow(wflow, tibble(lower_limit = 2)) + expect_snapshot(print(wflow_1)) + + wflow_2 <- 
finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3)) + expect_snapshot(print(wflow_2)) + + wflow_3 <- finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3, a = 2)) + expect_snapshot(print(wflow_3)) + + wflow_4 <- finalize_workflow(wflow, tibble()) + expect_snapshot(print(wflow_4)) + +})