From d7b07cc1b4cef386757a1622f1d4ea98de4d470b Mon Sep 17 00:00:00 2001 From: Ashby Thorpe <61390103+ashbythorpe@users.noreply.github.com> Date: Sun, 19 Mar 2023 21:45:18 +0000 Subject: [PATCH] add butcher methods for nestedmodels (#256) * Added axe methods and tests * Fixed formatting and whitespace * Fixed formatting and whitespace * Updated NEWS * Fixed examples * Update NEWS * Fixed disabled attribute --------- Co-authored-by: Julia Silge --- DESCRIPTION | 2 + NAMESPACE | 5 ++ NEWS.md | 4 +- R/nested_model_fit.R | 115 +++++++++++++++++++++++++ man/axe-nested_model_fit.Rd | 59 +++++++++++++ tests/testthat/test-nested_model_fit.R | 53 ++++++++++++ 6 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 R/nested_model_fit.R create mode 100644 man/axe-nested_model_fit.Rd create mode 100644 tests/testthat/test-nested_model_fit.R diff --git a/DESCRIPTION b/DESCRIPTION index a1948af..2d69c97 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -52,6 +52,7 @@ Suggests: mda, mgcv, modeldata, + nestedmodels, nnet, parsnip (>= 0.1.6), pkgload, @@ -69,6 +70,7 @@ Suggests: survival (>= 3.2-10), testthat (>= 3.0.0), TH.data, + tidyr, usethis (>= 1.5.0), xgboost (>= 1.3.2.1), xrf diff --git a/NAMESPACE b/NAMESPACE index bf61088..ccecdb2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -23,6 +23,7 @@ S3method(axe_call,mixo_spls) S3method(axe_call,ml_model) S3method(axe_call,model_fit) S3method(axe_call,multnet) +S3method(axe_call,nested_model_fit) S3method(axe_call,nnet) S3method(axe_call,randomForest) S3method(axe_call,ranger) @@ -43,6 +44,7 @@ S3method(axe_ctrl,default) S3method(axe_ctrl,gam) S3method(axe_ctrl,ml_model) S3method(axe_ctrl,model_fit) +S3method(axe_ctrl,nested_model_fit) S3method(axe_ctrl,randomForest) S3method(axe_ctrl,regbagg) S3method(axe_ctrl,rpart) @@ -63,6 +65,7 @@ S3method(axe_data,mixo_pls) S3method(axe_data,mixo_spls) S3method(axe_data,ml_model) S3method(axe_data,model_fit) +S3method(axe_data,nested_model_fit) S3method(axe_data,regbagg) S3method(axe_data,rpart) S3method(axe_data,survbagg) @@ -85,6 +88,7 @@ S3method(axe_env,lda) S3method(axe_env,lm) S3method(axe_env,mda) S3method(axe_env,model_fit) +S3method(axe_env,nested_model_fit) S3method(axe_env,nnet) S3method(axe_env,qda) S3method(axe_env,quosure) @@ -132,6 +136,7 @@ S3method(axe_fitted,mixo_pls) S3method(axe_fitted,mixo_spls) S3method(axe_fitted,ml_model) S3method(axe_fitted,model_fit) +S3method(axe_fitted,nested_model_fit) S3method(axe_fitted,nnet) S3method(axe_fitted,ranger) S3method(axe_fitted,recipe) diff --git a/NEWS.md b/NEWS.md index 9d8c859..401dc4e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,8 +1,10 @@ # butcher (development version) +* Added methods for `nestedmodels::nested()` (@ashbythorpe, #256). + * Updated methods for `mgcv::gam()` to also remove the `hat` and `offset` components (@rdavis120, #255). - + # butcher 0.3.2 * Added butcher methods for `mixOmics::pls()`, `mixOmics::spls()`, diff --git a/R/nested_model_fit.R b/R/nested_model_fit.R new file mode 100644 index 0000000..3a8b247 --- /dev/null +++ b/R/nested_model_fit.R @@ -0,0 +1,115 @@ +#' Axing a nested_model_fit. +#' +#' nested_model_fit objects are created from the \pkg{nestedmodels} +#' package, which allows parsnip models to be fitted on nested data. Axing a +#' nested_model_fit object involves axing all the inner model_fit objects. +#' +#' @inheritParams butcher +#' +#' @seealso [axe-model_fit] +#' +#' @return Axed nested_model_fit object. +#' +#' @examplesIf rlang::is_installed(c("parsnip", "nestedmodels")) +#' +#' library(nestedmodels) +#' library(parsnip) +#' +#' model <- linear_reg() %>% +#' set_engine("lm") %>% +#' nested() +#' +#' nested_data <- tidyr::nest(example_nested_data, data = -id) +#' +#' fit <- fit(model, z ~ x + y + a + b, nested_data) +#' +#' # Reduce the model size +#' butcher(fit) +#' +#' @name axe-nested_model_fit +NULL + +#' Remove the call. +#' +#' @rdname axe-nested_model_fit +#' @export +axe_call.nested_model_fit <- function(x, verbose = FALSE, ...) { + old <- x + x$fit$.model_fit <- purrr::map( + x$fit$.model_fit, + axe_call, + verbose = FALSE, + ... + ) + + disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") + add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) +} + +#' Remove controls used for training. +#' +#' @rdname axe-nested_model_fit +#' @export +axe_ctrl.nested_model_fit <- function(x, verbose = FALSE, ...) { + old <- x + x$fit$.model_fit <- purrr::map( + x$fit$.model_fit, + axe_ctrl, + verbose = FALSE, + ... + ) + + disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") + add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) +} + +#' Remove the training data. +#' +#' @rdname axe-nested_model_fit +#' @export +axe_data.nested_model_fit <- function(x, verbose = FALSE, ...) { + old <- x + x$fit$.model_fit <- purrr::map( + x$fit$.model_fit, + axe_data, + verbose = FALSE, + ... + ) + + disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") + add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) +} + +#' Remove environments. +#' +#' @rdname axe-nested_model_fit +#' @export +axe_env.nested_model_fit <- function(x, verbose = FALSE, ...) { + old <- x + x$fit$.model_fit <- purrr::map( + x$fit$.model_fit, + axe_env, + verbose = FALSE, + ... + ) + + disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") + add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) +} + +#' Remove fitted values. +#' +#' @rdname axe-nested_model_fit +#' @export +axe_fitted.nested_model_fit <- function(x, verbose = FALSE, ...) { + old <- x + x$fit$.model_fit <- purrr::map( + x$fit$.model_fit, + axe_fitted, + verbose = FALSE, + ... + ) + + disabled <- attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") + add_butcher_attributes(x, old, disabled = disabled, verbose = verbose) +} diff --git a/man/axe-nested_model_fit.Rd b/man/axe-nested_model_fit.Rd new file mode 100644 index 0000000..34b6bfa --- /dev/null +++ b/man/axe-nested_model_fit.Rd @@ -0,0 +1,59 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/nested_model_fit.R +\name{axe-nested_model_fit} +\alias{axe-nested_model_fit} +\alias{axe_call.nested_model_fit} +\alias{axe_ctrl.nested_model_fit} +\alias{axe_data.nested_model_fit} +\alias{axe_env.nested_model_fit} +\alias{axe_fitted.nested_model_fit} +\title{Axing a nested_model_fit.} +\usage{ +\method{axe_call}{nested_model_fit}(x, verbose = FALSE, ...) + +\method{axe_ctrl}{nested_model_fit}(x, verbose = FALSE, ...) + +\method{axe_data}{nested_model_fit}(x, verbose = FALSE, ...) + +\method{axe_env}{nested_model_fit}(x, verbose = FALSE, ...) + +\method{axe_fitted}{nested_model_fit}(x, verbose = FALSE, ...) +} +\arguments{ +\item{x}{A model object.} + +\item{verbose}{Print information each time an axe method is executed. +Notes how much memory is released and what functions are +disabled. Default is \code{FALSE}.} + +\item{...}{Any additional arguments related to axing.} +} +\value{ +Axed nested_model_fit object. +} +\description{ +nested_model_fit objects are created from the \pkg{nestedmodels} +package, which allows parsnip models to be fitted on nested data. Axing a +nested_model_fit object involves axing all the inner model_fit objects. +} +\examples{ +\dontshow{if (rlang::is_installed(c("parsnip", "nestedmodels"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} + +library(nestedmodels) +library(parsnip) + +model <- linear_reg() \%>\% + set_engine("lm") \%>\% + nested() + +nested_data <- tidyr::nest(example_nested_data, data = -id) + +fit <- fit(model, z ~ x + y + a + b, nested_data) + +# Reduce the model size +butcher(fit) +\dontshow{\}) # examplesIf} +} +\seealso{ +[axe-model_fit] +} diff --git a/tests/testthat/test-nested_model_fit.R b/tests/testthat/test-nested_model_fit.R new file mode 100644 index 0000000..6aeb0ac --- /dev/null +++ b/tests/testthat/test-nested_model_fit.R @@ -0,0 +1,53 @@ +test_that("nested_model_fit + axe_() works", { + skip_if_not_installed("parsnip") + skip_if_not_installed("nestedmodels") + + model <- nestedmodels::nested( + parsnip::set_engine(parsnip::linear_reg(), "lm") + ) + + # tidyr is a dependency of nestedmodels + nested_data <- tidyr::nest(nestedmodels::example_nested_data, data = -id) + + nm_fit <- parsnip::fit(model, z ~ x + y + a + b, nested_data) + + x <- axe_call(nm_fit) + + expect_equal(x$fit$.model_fit[[1]], axe_call(nm_fit$fit$.model_fit[[1]])) + + x <- axe_ctrl(nm_fit) + + expect_equal(x$fit$.model_fit[[1]], axe_ctrl(nm_fit$fit$.model_fit[[1]])) + + x <- axe_data(nm_fit) + + expect_equal(x$fit$.model_fit[[1]], axe_data(nm_fit$fit$.model_fit[[1]])) + + x <- axe_env(nm_fit) + + expect_equal(x$fit$.model_fit[[1]], axe_env(nm_fit$fit$.model_fit[[1]])) + + x <- axe_fitted(nm_fit) + + expect_equal(x$fit$.model_fit[[1]], axe_fitted(nm_fit$fit$.model_fit[[1]])) + + expect_equal( + attr(x, "butcher_disabled"), + attr(x$fit$.model_fit[[5]]$fit, "butcher_disabled") + ) + + x <- butcher(nm_fit) + + expect_equal( + attr(x, "butcher_disabled"), + attr(x$fit$.model_fit[[1]]$fit, "butcher_disabled") + ) + + expect_equal(x$fit$.model_fit[[1]], butcher(nm_fit$fit$.model_fit[[1]])) + + # Predict + expect_equal( + predict(x, nestedmodels::example_nested_data), + predict(nm_fit, nestedmodels::example_nested_data) + ) +})