Skip to content

Commit

Permalink
merge PR #262: transition inner split to separate data argument
Browse files Browse the repository at this point in the history
* remove `make_inner_split()` and its tests: workflows will no longer take an `add_tailor(prop)` or `add_tailor(method)` argument, instead taking a `fit.workflow(calibration)` argument that supersedes both of them.
* transition from `add_tailor(prop)` and `method` to `fit.workflow(calibration)`
* removes `add_tailor(prop)` and `add_tailor(method)`
* adds `fit.workflow(calibration)`
* various documentation updates
* removes rsample Suggests
* `.should_inner_split()` -> `.workflow_includes_calibration()`
  • Loading branch information
simonpcouch authored Oct 1, 2024
2 parents 425f05a + b28a6c4 commit 7929511
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 255 deletions.
2 changes: 0 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ Suggests:
methods,
modeldata (>= 1.0.0),
recipes (>= 1.0.10.9000),
rsample (>= 1.2.1.9000),
rmarkdown,
testthat (>= 3.0.0)
VignetteBuilder:
Expand All @@ -54,7 +53,6 @@ Config/Needs/website:
tidyverse/tidytemplate,
yardstick
Remotes:
tidymodels/rsample,
tidymodels/recipes,
tidymodels/parsnip,
tidymodels/tailor,
Expand Down
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export(.fit_finalize)
export(.fit_model)
export(.fit_post)
export(.fit_pre)
export(.should_inner_split)
export(.workflow_includes_calibration)
export(add_case_weights)
export(add_formula)
export(add_model)
Expand Down
73 changes: 37 additions & 36 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
#' @param object A workflow
#'
#' @param data A data frame of predictors and outcomes to use when fitting the
#' workflow
#' preprocessor and model.
#'
#' @param ... Not used
#'
#' @param calibration A data frame of predictors and outcomes to use when
#' fitting the postprocessor. See the "Data Usage" section of [add_tailor()]
#' for more information.
#'
#' @param control A [control_workflow()] object
#'
#' @return
Expand Down Expand Up @@ -51,34 +55,31 @@
#' add_recipe(recipe)
#'
#' fit(recipe_wf, mtcars)
fit.workflow <- function(object, data, ..., control = control_workflow()) {
fit.workflow <- function(object, data, ..., calibration = NULL, control = control_workflow()) {
check_dots_empty()

if (is_missing(data)) {
cli_abort("{.arg data} must be provided to fit a workflow.")
}

validate_has_calibration(object, calibration)

if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data)
}

# If `calibration` is not overwritten in the following `if` statement, then the
# the postprocessor doesn't actually require training and the dataset
# passed to `.fit_post()` will have no effect.
calibration <- data
if (.should_inner_split(object)) {
inner_split <- make_inner_split(object, data)

data <- rsample::analysis(inner_split)
calibration <- rsample::assessment(inner_split)
}

workflow <- object
workflow <- .fit_pre(workflow, data)
workflow <- .fit_model(workflow, control)

if (!.workflow_includes_calibration(workflow)) {
# in this case, training the tailor on `data` will not leak data (#262)
calibration <- data
}
if (has_postprocessor(workflow)) {
workflow <- .fit_post(workflow, calibration)
}

workflow <- .fit_finalize(workflow)

workflow
Expand All @@ -87,31 +88,11 @@ fit.workflow <- function(object, data, ..., control = control_workflow()) {
#' @export
#' @rdname workflows-internals
#' @keywords internal
.should_inner_split <- function(workflow) {
.workflow_includes_calibration <- function(workflow) {
has_postprocessor(workflow) &&
tailor::tailor_requires_fit(
extract_postprocessor(workflow, estimated = FALSE)
)
}

make_inner_split <- function(object, data) {
validate_rsample_available()

method <- object$post$actions$tailor$method
mocked_split <-
rsample::make_splits(
list(analysis = seq_len(nrow(data)), assessment = integer()),
data = data,
class = if (is.null(method)) "mc_split" else method
tailor::tailor_requires_fit(
extract_postprocessor(workflow, estimated = FALSE)
)

# add_tailor(prop) is the proportion to train the postprocessor, while
# rsample::mc_cv(prop) is the proportion to train the model (#247)
prop <- object$post$actions$tailor$prop
rsample::inner_split(
mocked_split,
list(prop = if (is.null(prop)) 2/3 else 1 - prop)
)
}

# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -249,6 +230,26 @@ validate_has_model <- function(x, ..., call = caller_env()) {
invisible(x)
}

validate_has_calibration <- function(x, calibration, call = caller_env()) {
if (.workflow_includes_calibration(x) && is.null(calibration)) {
cli::cli_abort(
"The workflow requires a {.arg calibration} set to train but none
was supplied.",
call = call
)
}

if (!.workflow_includes_calibration(x) && !is.null(calibration)) {
cli::cli_warn(
"The workflow does not require a {.arg calibration} set to train
but one was supplied.",
call = call
)
}

invisible(x)
}

# ------------------------------------------------------------------------------

finalize_blueprint <- function(workflow) {
Expand Down
70 changes: 20 additions & 50 deletions R/post-action-tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,17 @@
#' should not have been trained already with [tailor::fit()]; workflows
#' will handle training internally.
#'
#' @param prop The proportion of the data in [fit.workflow()] that should be
#' held back specifically for estimating the postprocessor. Only relevant for
#' postprocessors that require estimation---see section Data Usage below to
#' learn more. Defaults to 1/3.
#'
#' @param method The method with which to split the data in [fit.workflow()],
#' as a character vector. Only relevant for postprocessors that
#' require estimation and not required when resampling the workflow with
#' tune. If `fit.workflow(data)` arose as `training(split_object)`, this argument can
#' usually be supplied as `class(split_object)`. Defaults to `"mc_split"`, which
#' randomly samples `fit.workflow(data)` into two sets, similarly to
#' [rsample::initial_split()]. See section Data Usage below to learn more.
#'
#' @section Data Usage:
#'
#' While preprocessors and models are trained on data in the usual sense,
#' postprocessors are training on _predictions_ on data. When a workflow
#' is fitted, the user supplies training data with the `data` argument.
#' is fitted, the user typically supplies training data with the `data` argument.
#' When workflows don't contain a postprocessor that requires training,
#' they can use all of the supplied `data` to train the preprocessor and model.
#' However, in the case where a postprocessor must be trained as well,
#' training the preprocessor and model on all of `data` would leave no data
#' left to train the postprocessor with---if that were the case, workflows
#' users can pass all of the available data to the `data` argument to train the
#' preprocessor and model. However, in the case where a postprocessor must be
#' trained as well, allotting all of the available data to the `data` argument
#' to train the preprocessor and model would leave no data
#' to train the postprocessor with---if that were the case, workflows
#' would need to `predict()` from the preprocessor and model on the same `data`
#' that they were trained on, with the postprocessor then training on those
#' predictions. Predictions on data that a model was trained on likely follow
Expand All @@ -49,22 +37,15 @@
#' is passed to that trained postprocessor and model to generate predictions,
#' which then form the training data for the postprocessor.
#'
#' The arguments `prop` and `method` parameterize how that data is split up.
#' `prop` determines the proportion of rows in `fit.workflow(data)` that are
#' allotted to training the preprocessor and model, while the rest are used to
#' train the postprocessor. `method` determines how that split occurs; since
#' `fit.workflow()` just takes in a data frame, the function doesn't have
#' any information on how that dataset came to be. For example, `data` could
#' have been created as:
#'
#' ```
#' split <- rsample::initial_split(some_other_data)
#' data <- rsample::training(split)
#' ```
#' When fitting a workflow with a postprocessor that requires training
#' (i.e. one that returns `TRUE` in `.workflow_includes_calibration(workflow)`), users
#' must pass two data arguments--the usual `fit.workflow(data)` will be used
#' to train the preprocessor and model while `fit.workflow(calibration)` will
#' be used to train the postprocessor.
#'
#' ...in which case it's okay to randomly allot some rows of `data` to train the
#' preprocessor and model and the rest to train the postprocessor. However,
#' `data` could also have arisen as:
#' In some situations, randomly splitting `fit.workflow(data)` (with
#' `rsample::initial_split()`, for example) is sufficient to prevent data
#' leakage. However, `fit.workflow(data)` could also have arisen as:
#'
#' ```
#' boots <- rsample::bootstraps(some_other_data)
Expand All @@ -78,8 +59,9 @@
#' datasets, resulting in the preprocessor and model generating predictions on
#' rows they've seen before. Similarly problematic situations could arise in the
#' context of other resampling situations, like time-based splits.
#' The `method` argument ensures that data is allotted properly (and is
#' internally handled by the tune package when resampling workflows).
#' In general, use the `rsample::inner_split()` function to prevent data
#' leakage when resampling; when workflows with postprocessors that require
#' training are passed to the tune package, this is handled internally.
#'
#' @param ... Not used.
#'
Expand All @@ -102,14 +84,11 @@
#' remove_tailor(workflow)
#'
#' update_tailor(workflow, adjust_probability_threshold(tailor, .2))
add_tailor <- function(x, tailor, prop = NULL, method = NULL, ...) {
add_tailor <- function(x, tailor, ...) {
check_dots_empty()
validate_tailor_available()
action <- new_action_tailor(tailor, prop = prop, method = method)
action <- new_action_tailor(tailor)
res <- add_action(x, action, "tailor")
if (.should_inner_split(res)) {
validate_rsample_available()
}
res
}

Expand Down Expand Up @@ -185,7 +164,7 @@ mock_trained_workflow <- function(workflow) {

# ------------------------------------------------------------------------------

new_action_tailor <- function(tailor, prop, method, ..., call = caller_env()) {
new_action_tailor <- function(tailor, ..., call = caller_env()) {
check_dots_empty()

if (!is_tailor(tailor)) {
Expand All @@ -196,17 +175,8 @@ new_action_tailor <- function(tailor, prop, method, ..., call = caller_env()) {
cli_abort("Can't add a trained tailor to a workflow.", call = call)
}

if (!is.null(prop) &&
(!rlang::is_double(prop, n = 1) || prop <= 0 || prop >= 1)) {
cli_abort("{.arg prop} must be a numeric on (0, 1).", call = call)
}

# todo: test method

new_action_post(
tailor = tailor,
prop = prop,
method = method,
subclass = "action_tailor"
)
}
Expand Down
14 changes: 0 additions & 14 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,6 @@ validate_tailor_available <- function(..., call = caller_env()) {
invisible()
}

validate_rsample_available <- function(..., call = caller_env()) {
check_dots_empty()

if (!requireNamespace("rsample", quietly = TRUE)) {
cli_abort(
"The {.pkg rsample} package must be available to add a tailor that
requires training.",
call = call
)
}

invisible()
}

# ------------------------------------------------------------------------------

# https://github.com/r-lib/tidyselect/blob/10e00cea2fff3585fc827b6a7eb5e172acadbb2f/R/utils.R#L109
Expand Down
53 changes: 18 additions & 35 deletions man/add_tailor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 7929511

Please sign in to comment.