Skip to content

Commit

Permalink
add butcher methods for nestedmodels (#256)
Browse files Browse the repository at this point in the history
* Added axe methods and tests

* Fixed formatting and whitespace

* Fixed formatting and whitespace

* Updated NEWS

* Fixed examples

* Update NEWS

* Fixed disabled attribute

---------

Co-authored-by: Julia Silge <[email protected]>
  • Loading branch information
ashbythorpe and juliasilge authored Mar 19, 2023
1 parent dbf59f3 commit d7b07cc
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 1 deletion.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Suggests:
mda,
mgcv,
modeldata,
nestedmodels,
nnet,
parsnip (>= 0.1.6),
pkgload,
Expand All @@ -69,6 +70,7 @@ Suggests:
survival (>= 3.2-10),
testthat (>= 3.0.0),
TH.data,
tidyr,
usethis (>= 1.5.0),
xgboost (>= 1.3.2.1),
xrf
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ S3method(axe_call,mixo_spls)
S3method(axe_call,ml_model)
S3method(axe_call,model_fit)
S3method(axe_call,multnet)
S3method(axe_call,nested_model_fit)
S3method(axe_call,nnet)
S3method(axe_call,randomForest)
S3method(axe_call,ranger)
Expand All @@ -43,6 +44,7 @@ S3method(axe_ctrl,default)
S3method(axe_ctrl,gam)
S3method(axe_ctrl,ml_model)
S3method(axe_ctrl,model_fit)
S3method(axe_ctrl,nested_model_fit)
S3method(axe_ctrl,randomForest)
S3method(axe_ctrl,regbagg)
S3method(axe_ctrl,rpart)
Expand All @@ -63,6 +65,7 @@ S3method(axe_data,mixo_pls)
S3method(axe_data,mixo_spls)
S3method(axe_data,ml_model)
S3method(axe_data,model_fit)
S3method(axe_data,nested_model_fit)
S3method(axe_data,regbagg)
S3method(axe_data,rpart)
S3method(axe_data,survbagg)
Expand All @@ -85,6 +88,7 @@ S3method(axe_env,lda)
S3method(axe_env,lm)
S3method(axe_env,mda)
S3method(axe_env,model_fit)
S3method(axe_env,nested_model_fit)
S3method(axe_env,nnet)
S3method(axe_env,qda)
S3method(axe_env,quosure)
Expand Down Expand Up @@ -132,6 +136,7 @@ S3method(axe_fitted,mixo_pls)
S3method(axe_fitted,mixo_spls)
S3method(axe_fitted,ml_model)
S3method(axe_fitted,model_fit)
S3method(axe_fitted,nested_model_fit)
S3method(axe_fitted,nnet)
S3method(axe_fitted,ranger)
S3method(axe_fitted,recipe)
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# butcher (development version)

* Added methods for `nestedmodels::nested()` (@ashbythorpe, #256).

* Updated methods for `mgcv::gam()` to also remove the `hat` and `offset`
components (@rdavis120, #255).

# butcher 0.3.2

* Added butcher methods for `mixOmics::pls()`, `mixOmics::spls()`,
Expand Down
115 changes: 115 additions & 0 deletions R/nested_model_fit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#' Axing a nested_model_fit.
#'
#' nested_model_fit objects are created from the \pkg{nestedmodels}
#' package, which allows parsnip models to be fitted on nested data. Axing a
#' nested_model_fit object involves axing all the inner model_fit objects.
#'
#' @inheritParams butcher
#'
#' @seealso [axe-model_fit]
#'
#' @return Axed nested_model_fit object.
#'
#' @examplesIf rlang::is_installed(c("parsnip", "nestedmodels"))
#'
#' library(nestedmodels)
#' library(parsnip)
#'
#' model <- linear_reg() %>%
#' set_engine("lm") %>%
#' nested()
#'
#' nested_data <- tidyr::nest(example_nested_data, data = -id)
#'
#' fit <- fit(model, z ~ x + y + a + b, nested_data)
#'
#' # Reduce the model size
#' butcher(fit)
#'
#' @name axe-nested_model_fit
NULL

#' Remove the call.
#'
#' @rdname axe-nested_model_fit
#' @export
axe_call.nested_model_fit <- function(x, verbose = FALSE, ...) {
old <- x
x$fit$.model_fit <- purrr::map(
x$fit$.model_fit,
axe_call,
verbose = FALSE,
...
)

disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled")
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose)
}

#' Remove controls used for training.
#'
#' @rdname axe-nested_model_fit
#' @export
axe_ctrl.nested_model_fit <- function(x, verbose = FALSE, ...) {
old <- x
x$fit$.model_fit <- purrr::map(
x$fit$.model_fit,
axe_ctrl,
verbose = FALSE,
...
)

disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled")
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose)
}

#' Remove the training data.
#'
#' @rdname axe-nested_model_fit
#' @export
axe_data.nested_model_fit <- function(x, verbose = FALSE, ...) {
old <- x
x$fit$.model_fit <- purrr::map(
x$fit$.model_fit,
axe_data,
verbose = FALSE,
...
)

disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled")
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose)
}

#' Remove environments.
#'
#' @rdname axe-nested_model_fit
#' @export
axe_env.nested_model_fit <- function(x, verbose = FALSE, ...) {
old <- x
x$fit$.model_fit <- purrr::map(
x$fit$.model_fit,
axe_env,
verbose = FALSE,
...
)

disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled")
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose)
}

#' Remove fitted values.
#'
#' @rdname axe-nested_model_fit
#' @export
axe_fitted.nested_model_fit <- function(x, verbose = FALSE, ...) {
old <- x
x$fit$.model_fit <- purrr::map(
x$fit$.model_fit,
axe_fitted,
verbose = FALSE,
...
)

disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled")
add_butcher_attributes(x, old, disabled = disabled, verbose = verbose)
}
59 changes: 59 additions & 0 deletions man/axe-nested_model_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions tests/testthat/test-nested_model_fit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
test_that("nested_model_fit + axe_() works", {
skip_if_not_installed("parsnip")
skip_if_not_installed("nestedmodels")

model <- nestedmodels::nested(
parsnip::set_engine(parsnip::linear_reg(), "lm")
)

# tidyr is a dependency of nestedmodels
nested_data <- tidyr::nest(nestedmodels::example_nested_data, data = -id)

nm_fit <- parsnip::fit(model, z ~ x + y + a + b, nested_data)

x <- axe_call(nm_fit)

expect_equal(x$fit$.model_fit[[1]], axe_call(nm_fit$fit$.model_fit[[1]]))

x <- axe_ctrl(nm_fit)

expect_equal(x$fit$.model_fit[[1]], axe_ctrl(nm_fit$fit$.model_fit[[1]]))

x <- axe_data(nm_fit)

expect_equal(x$fit$.model_fit[[1]], axe_data(nm_fit$fit$.model_fit[[1]]))

x <- axe_env(nm_fit)

expect_equal(x$fit$.model_fit[[1]], axe_env(nm_fit$fit$.model_fit[[1]]))

x <- axe_fitted(nm_fit)

expect_equal(x$fit$.model_fit[[1]], axe_fitted(nm_fit$fit$.model_fit[[1]]))

expect_equal(
attr(x, "butcher_disabled"),
attr(x$fit$.model_fit[[5]]$fit, "butcher_disabled")
)

x <- butcher(nm_fit)

expect_equal(
attr(x, "butcher_disabled"),
attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled")
)

expect_equal(x$fit$.model_fit[[1]], butcher(nm_fit$fit$.model_fit[[1]]))

# Predict
expect_equal(
predict(x, nestedmodels::example_nested_data),
predict(nm_fit, nestedmodels::example_nested_data)
)
})

0 comments on commit d7b07cc

Please sign in to comment.