-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add butcher methods for nestedmodels (#256)
* Added axe methods and tests * Fixed formatting and whitespace * Fixed formatting and whitespace * Updated NEWS * Fixed examples * Update NEWS * Fixed disabled attribute --------- Co-authored-by: Julia Silge <[email protected]>
- Loading branch information
1 parent
dbf59f3
commit d7b07cc
Showing
6 changed files
with
237 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
#' Axing a nested_model_fit. | ||
#' | ||
#' nested_model_fit objects are created from the \pkg{nestedmodels} | ||
#' package, which allows parsnip models to be fitted on nested data. Axing a | ||
#' nested_model_fit object involves axing all the inner model_fit objects. | ||
#' | ||
#' @inheritParams butcher | ||
#' | ||
#' @seealso [axe-model_fit] | ||
#' | ||
#' @return Axed nested_model_fit object. | ||
#' | ||
#' @examplesIf rlang::is_installed(c("parsnip", "nestedmodels")) | ||
#' | ||
#' library(nestedmodels) | ||
#' library(parsnip) | ||
#' | ||
#' model <- linear_reg() %>% | ||
#' set_engine("lm") %>% | ||
#' nested() | ||
#' | ||
#' nested_data <- tidyr::nest(example_nested_data, data = -id) | ||
#' | ||
#' fit <- fit(model, z ~ x + y + a + b, nested_data) | ||
#' | ||
#' # Reduce the model size | ||
#' butcher(fit) | ||
#' | ||
#' @name axe-nested_model_fit | ||
NULL | ||
|
||
#' Remove the call. | ||
#' | ||
#' @rdname axe-nested_model_fit | ||
#' @export | ||
axe_call.nested_model_fit <- function(x, verbose = FALSE, ...) { | ||
old <- x | ||
x$fit$.model_fit <- purrr::map( | ||
x$fit$.model_fit, | ||
axe_call, | ||
verbose = FALSE, | ||
... | ||
) | ||
|
||
disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") | ||
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) | ||
} | ||
|
||
#' Remove controls used for training. | ||
#' | ||
#' @rdname axe-nested_model_fit | ||
#' @export | ||
axe_ctrl.nested_model_fit <- function(x, verbose = FALSE, ...) { | ||
old <- x | ||
x$fit$.model_fit <- purrr::map( | ||
x$fit$.model_fit, | ||
axe_ctrl, | ||
verbose = FALSE, | ||
... | ||
) | ||
|
||
disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") | ||
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) | ||
} | ||
|
||
#' Remove the training data. | ||
#' | ||
#' @rdname axe-nested_model_fit | ||
#' @export | ||
axe_data.nested_model_fit <- function(x, verbose = FALSE, ...) { | ||
old <- x | ||
x$fit$.model_fit <- purrr::map( | ||
x$fit$.model_fit, | ||
axe_data, | ||
verbose = FALSE, | ||
... | ||
) | ||
|
||
disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") | ||
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) | ||
} | ||
|
||
#' Remove environments. | ||
#' | ||
#' @rdname axe-nested_model_fit | ||
#' @export | ||
axe_env.nested_model_fit <- function(x, verbose = FALSE, ...) { | ||
old <- x | ||
x$fit$.model_fit <- purrr::map( | ||
x$fit$.model_fit, | ||
axe_env, | ||
verbose = FALSE, | ||
... | ||
) | ||
|
||
disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") | ||
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) | ||
} | ||
|
||
#' Remove fitted values. | ||
#' | ||
#' @rdname axe-nested_model_fit | ||
#' @export | ||
axe_fitted.nested_model_fit <- function(x, verbose = FALSE, ...) { | ||
old <- x | ||
x$fit$.model_fit <- purrr::map( | ||
x$fit$.model_fit, | ||
axe_fitted, | ||
verbose = FALSE, | ||
... | ||
) | ||
|
||
disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") | ||
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
test_that("nested_model_fit + axe_() works", { | ||
skip_if_not_installed("parsnip") | ||
skip_if_not_installed("nestedmodels") | ||
|
||
model <- nestedmodels::nested( | ||
parsnip::set_engine(parsnip::linear_reg(), "lm") | ||
) | ||
|
||
# tidyr is a dependency of nestedmodels | ||
nested_data <- tidyr::nest(nestedmodels::example_nested_data, data = -id) | ||
|
||
nm_fit <- parsnip::fit(model, z ~ x + y + a + b, nested_data) | ||
|
||
x <- axe_call(nm_fit) | ||
|
||
expect_equal(x$fit$.model_fit[[1]], axe_call(nm_fit$fit$.model_fit[[1]])) | ||
|
||
x <- axe_ctrl(nm_fit) | ||
|
||
expect_equal(x$fit$.model_fit[[1]], axe_ctrl(nm_fit$fit$.model_fit[[1]])) | ||
|
||
x <- axe_data(nm_fit) | ||
|
||
expect_equal(x$fit$.model_fit[[1]], axe_data(nm_fit$fit$.model_fit[[1]])) | ||
|
||
x <- axe_env(nm_fit) | ||
|
||
expect_equal(x$fit$.model_fit[[1]], axe_env(nm_fit$fit$.model_fit[[1]])) | ||
|
||
x <- axe_fitted(nm_fit) | ||
|
||
expect_equal(x$fit$.model_fit[[1]], axe_fitted(nm_fit$fit$.model_fit[[1]])) | ||
|
||
expect_equal( | ||
attr(x, "butcher_disabled"), | ||
attr(x$fit$.model_fit[[5]]$fit, "butcher_disabled") | ||
) | ||
|
||
x <- butcher(nm_fit) | ||
|
||
expect_equal( | ||
attr(x, "butcher_disabled"), | ||
attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") | ||
) | ||
|
||
expect_equal(x$fit$.model_fit[[1]], butcher(nm_fit$fit$.model_fit[[1]])) | ||
|
||
# Predict | ||
expect_equal( | ||
predict(x, nestedmodels::example_nested_data), | ||
predict(nm_fit, nestedmodels::example_nested_data) | ||
) | ||
}) |