Skip to content

Commit

Permalink
add method for dbarts::bart() (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Jul 15, 2024
1 parent 31f4502 commit da505eb
Show file tree
Hide file tree
Showing 18 changed files with 325 additions and 0 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Suggests:
callr,
caret,
covr,
dbarts,
embed,
h2o,
keras,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ S3method(bundle,H2OAutoML)
S3method(bundle,H2OBinomialModel)
S3method(bundle,H2OMultinomialModel)
S3method(bundle,H2ORegressionModel)
S3method(bundle,bart)
S3method(bundle,default)
S3method(bundle,keras.engine.training.Model)
S3method(bundle,luz_module_fitted)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# bundle (development version)

* Added bundle method for objects from `dbarts::bart()` and, by extension,
`parsnip::bart(engine = "dbarts")` (#64).

# bundle 0.1.1

* Fixed bundling of recipes steps situated inside of workflows.
Expand Down
60 changes: 60 additions & 0 deletions R/bundle_bart.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#' @templateVar class a `bart`
#' @template title_desc
#'
#' @templateVar outclass `bundled_bart`
#' @templateVar default .
#' @template return_bundle
#' @family bundlers
#'
#' @param x A `bart` object returned from [dbarts::bart()]. Notably, this ought
#' not to be the output of [parsnip::bart()].
#' @template param_unused_dots
#' @rdname bundle_bart
#' @template butcher_details
#' @examplesIf rlang::is_installed(c("dbarts"))
#' # fit model and bundle ------------------------------------------------
#' library(dbarts)
#'
#' mtcars$vs <- as.factor(mtcars$vs)
#'
#' set.seed(1)
#' fit <- dbarts::bart(mtcars[c("disp", "hp")], mtcars$vs, keeptrees = TRUE)
#'
#' fit_bundle <- bundle(fit)
#'
#' # then, after saveRDS + readRDS or passing to a new session ----------
#' fit_unbundled <- unbundle(fit_bundle)
#'
#' fit_unbundled_preds <- predict(fit_unbundled, mtcars)
#' @aliases bundle.bart
#' @method bundle bart
#' @export
bundle.bart <- function(x, ...) {
rlang::check_installed("dbarts")
rlang::check_dots_empty()

# `parsnip::bart()` and `dbarts::bart()` unfortunately both inherit from `bart`
if (inherits(x, "model_spec")) {
rlang::abort(c(
paste0("`x` should be the output of `dbarts::bart()`, not a model ",
"specification from `parsnip::bart()`."),
"To bundle `parsnip::bart()` output, train it with `parsnip::fit()` first."
))
}

if (is.null(x$fit)) {
rlang::abort(c(
"`x` can't be bundled.",
"`x` must have been fitted with argument `keeptrees = TRUE`."
))
}

# "touch" the object's state (#64)
invisible(x$fit$state)

bundle_constr(
object = x,
situate = situate_constr(identity),
desc_class = class(x)[1]
)
}
1 change: 1 addition & 0 deletions man/bundle.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

110 changes: 110 additions & 0 deletions man/bundle_bart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_caret.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_embed.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_h2o.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_keras.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_parsnip.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_recipe.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_stacks.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_torch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_workflows.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/bundle_xgboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions tests/testthat/_snaps/bundle_bart.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# bundle.bart errors informatively with model_spec input (#64)

Code
bundle(parsnip::bart())
Condition
Error in `bundle()`:
! `x` should be the output of `dbarts::bart()`, not a model specification from `parsnip::bart()`.
* To bundle `parsnip::bart()` output, train it with `parsnip::fit()` first.

# bundle.bart errors informatively when `keeptrees = FALSE` (#64)

Code
bundle(fit)
Condition
Error in `bundle()`:
! `x` can't be bundled.
* `x` must have been fitted with argument `keeptrees = TRUE`.

Loading

0 comments on commit da505eb

Please sign in to comment.