Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure function and internal clean-up #151

Merged
merged 5 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: bmm
Title: Easy and Accesible Bayesian Measurement Models using 'brms'
Version: 0.4.3.9000
Version: 0.4.4.9000
Authors@R: c(
person("Vencislav", "Popov", , "[email protected]", role = c("aut", "cre", "cph")),
person("Gidon", "Frischkorn", , "[email protected]", role = c("aut", "cph")),
Expand Down Expand Up @@ -43,8 +43,7 @@ Imports:
stats,
matrixStats,
crayon,
methods,
assertthat
methods
URL: https://github.com/venpopov/bmm, https://venpopov.github.io/bmm/
BugReports: https://github.com/venpopov/bmm/issues
Additional_repositories:
Expand Down
8 changes: 7 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ S3method(postprocess_brm,default)
S3method(postprocess_brm,sdmSimple)
S3method(print,bmmsummary)
S3method(print,message)
S3method(reset_env,bmmfit)
S3method(reset_env,bmmformula)
S3method(reset_env,brmsfamily)
S3method(reset_env,brmsformula)
S3method(reset_env,formula)
S3method(revert_postprocess_brm,default)
S3method(revert_postprocess_brm,sdmSimple)
S3method(rhs_vars,bmmformula)
Expand Down Expand Up @@ -84,6 +89,7 @@ export(qmixture3p)
export(qsdm)
export(rIMM)
export(rad2deg)
export(restructure_bmm)
export(revert_postprocess_brm)
export(rmixture2p)
export(rmixture3p)
Expand All @@ -96,8 +102,8 @@ export(supported_models)
export(use_model_template)
export(wrap)
import(stats)
importFrom(assertthat,assert_that)
importFrom(brms,stancode)
importFrom(brms,standata)
importFrom(glue,glue)
importFrom(magrittr,"%>%")
importFrom(utils,packageVersion)
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* add a custom summary() method for bmm models (#144)
* add a global options bmm.summary_backend to control the backend used for the summary() method (choices are "bmm" and "brms")
* deprecate get_model_prior(), get_stancode() and get_standata(). These functions will be removed in future versions of the package. Due to [recent changes](https://github.com/paul-buerkner/brms/pull/1604) in *brms* version 2.20.14, you can now use the *brms* functions `default_prior`, `stancode` and `standata` directly with *bmm* models (alternatively, their older aliases, "get_prior", "make_stancode", "make_standata").
* function restructure() now allows to apply methods introduced in newer bmm versions to bmmfit objects created by older bmm versions

# bmm 0.4.0

Expand Down
59 changes: 29 additions & 30 deletions R/helpers-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -427,47 +427,48 @@ use_model_template <- function(model_name,
"# ?postprocess_brm for details\n\n")


model_object <- glue::glue(".model_<<model_name>> <- function(resp_var1 = NULL, required_arg1 = NULL, required_arg2 = NULL, ...) {\n",
" out <- list(\n",
" resp_vars = nlist(resp_var1),\n",
" other_vars = nlist(required_arg1, required_arg2),\n",
" info = list(\n",
" domain = '',\n",
" task = '',\n",
" name = '',\n",
" citation = '',\n",
" version = '',\n",
" requirements = '',\n",
" parameters = list(),\n",
" fixed_parameters = list()\n",
" ),\n",
" void_mu = FALSE\n",
" )\n",
" class(out) <- c('bmmmodel', '<<model_name>>')\n",
" out\n",
"}\n\n",
.open = "<<", .close = ">>")
model_object <- glue(".model_<<model_name>> <- function(resp_var1 = NULL, required_arg1 = NULL, required_arg2 = NULL, links = NULL, ...) {\n",
" out <- list(\n",
" resp_vars = nlist(resp_var1),\n",
" other_vars = nlist(required_arg1, required_arg2),\n",
" domain = '',\n",
" task = '',\n",
" name = '',\n",
" citation = '',\n",
" version = '',\n",
" requirements = '',\n",
" parameters = list(),\n",
" links = list(),\n",
" fixed_parameters = list()\n",
" void_mu = FALSE\n",
" )\n",
" class(out) <- c('bmmmodel', '<<model_name>>')\n",
" out$links[names(links)] <- links\n",
" out\n",
"}\n\n",
.open = "<<", .close = ">>")

user_facing_alias <- glue::glue("# user facing alias\n",
"# information in the title and details sections will be filled in\n",
"# automatically based on the information in the .model_<<model_name>>()$info\n \n",
"#' @title `r .model_<<model_name>>()$name`\n",
"#' @name Model Name",
"#' @details `r model_info(model_<<model_name>>())`\n",
"#' @details `r model_info(.model_<<model_name>>())`\n",
"#' @param resp_var1 A description of the response variable\n",
"#' @param required_arg1 A description of the required argument\n",
"#' @param required_arg2 A description of the required argument\n",
"#' @param links A list of links for the parameters.",
"#' @param ... used internally for testing, ignore it\n",
"#' @return An object of class `bmmmodel`\n",
"#' @export\n",
"#' @examples\n",
"#' \\dontrun{\n",
"#' # put a full example here (see 'R/bmm_model_mixture3p.R' for an example)\n",
"#' }\n",
"<<model_name>> <- function(resp_var1, required_arg1, required_arg2, ...) {\n",
"<<model_name>> <- function(resp_var1, required_arg1, required_arg2, links = NULL, ...) {\n",
" stop_missing_args()\n",
" .model_<<model_name>>(resp_var1 = resp_var1, required_arg1 = required_arg1,",
" required_arg2 = required_arg2, ...)\n",
" required_arg2 = required_arg2, links = links, ...)\n",
"}\n\n",
.open = "<<", .close = ">>")

Expand All @@ -479,8 +480,7 @@ use_model_template <- function(model_name,
" # check the data (required)\n\n\n",
" # compute any necessary transformations (optional)\n\n",
" # save some variables as attributes of the data for later use (optional)\n\n",
" data = NextMethod('check_data')\n\n",
" return(data)\n",
" NextMethod('check_data')\n\n",
"}\n\n",
.open = "<<", .close = ">>")

Expand All @@ -505,7 +505,7 @@ use_model_template <- function(model_name,
" brms_formula <- brms_formula + brms::lf(pform)\n",
" }\n",
" }\n\n",
" return(brms_formula)\n",
" brms_formula\n",
"}\n\n",
.open = "<<", .close = ">>")

Expand Down Expand Up @@ -551,9 +551,9 @@ use_model_template <- function(model_name,
}

if (custom_family) {
out_template <- " out <- nlist(formula, data, family, prior, stanvars)\n"
out_template <- " nlist(formula, data, family, prior, stanvars)\n"
} else {
out_template <- " out <- nlist(formula, data, family, prior)\n"
out_template <- " nlist(formula, data, family, prior)\n"
}


Expand All @@ -574,14 +574,13 @@ use_model_template <- function(model_name,
" prior <- NULL\n\n",
" # return the list\n",
out_template,
" return(out)\n",
"}\n\n",
.open = "<<", .close = ">>")

postprocess_brm_method <- glue::glue("#' @export\n",
"postprocess_brm.<<model_name>> <- function(model, fit) {\n",
" # any required postprocessing (if none, delete this section)\n\n",
" return(fit)\n",
" fit\n",
"}\n\n",
.open = "<<", .close = ">>")

Expand Down
10 changes: 4 additions & 6 deletions R/helpers-postprocess.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@ postprocess_brm <- function(model, fit, ...) {
postprocess_brm.bmmmodel <- function(model, fit, ...) {
dots <- list(...)
class(fit) <- c('bmmfit','brmsfit')
fit$bmm$fit_args <- dots$fit_args
fit$version$bmm <- utils::packageVersion('bmm')
fit$bmm$model <- model
fit$bmm$user_formula <- dots$user_formula
fit$bmm$configure_opts <- dots$configure_opts
fit$bmm <- nlist(model, user_formula = dots$user_formula, configure_opts = dots$configure_opts)
attr(fit$data, 'data_name') <- attr(dots$fit_args$data, 'data_name')

NextMethod('postprocess_brm')
fit <- NextMethod('postprocess_brm')
# clean up environments stored in the fit object
reset_env(fit)
}

#' @export
Expand Down
13 changes: 8 additions & 5 deletions R/helpers-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#'
#' @name get_model_prior
#'
#' @seealso [supported_models()], \code{\link[brms:get_prior]{brms::get_prior()}}.
#' @seealso [supported_models()], \code{\link[brms:get_prior]{brms::get_prior()}}.
#'
#' @keywords extract_info
#'
Expand All @@ -44,10 +44,13 @@
#' }
#' @export
get_model_prior <- function(object, data, model, formula = object, ...) {
if (utils::packageVersion('brms') >= "2.20.14") {
message("get_model_prior is deprecated. Please use get_prior() or default_prior()")
} else {
message("get_model_prior is deprecated. Please use get_prior() instead.")
fcall <- as.character(match.call()[1])
if (fcall == "get_model_prior") {
if (utils::packageVersion('brms') >= "2.20.14") {
message("get_model_prior is deprecated. Please use get_prior() or default_prior()")
} else {
message("get_model_prior is deprecated. Please use get_prior() instead.")
}
}
if (missing(object) && !missing(formula)) {
warning2("The 'formula' argument is deprecated for consistency with brms (>= 2.20.14).",
Expand Down
64 changes: 57 additions & 7 deletions R/restructure.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
#' @importFrom assertthat assert_that
restructure.bmm <- function(x) {
assert_that(is_bmmfit(x) | !is.null(x$version$bmm), msg = "Please provide a bmmfit object")
#' Restructure Old \code{bmmfit} Objects
#'
#' Restructure old \code{bmmfit} objects to work with
#' the latest \pkg{bmm} version. This function is called
#' internally when applying post-processing methods.
#'
#' @param x An object of class \code{bmmfit}.
#' @param ... Currently ignored.
#'
#' @return A \code{bmmfit} object compatible with the latest version
#' of \pkg{bmm} and \pkg{brms}.
#' @keywords transform
#' @export
#' @importFrom utils packageVersion
restructure_bmm <- function(x, ...) {
version <- x$version$bmm
if (is.null(version)) {
version <- as.package_version('0.1.1')
version <- as.package_version('0.2.1')
x$version$bmm <- version
}
if (!inherits(x, 'bmmfit')) {
class(x) <- c('bmmfit', class(x))
}
current_version <- utils::packageVersion('bmm')
current_version <- packageVersion('bmm')
restr_version <- restructure_version.bmm(x)

if (restr_version >= current_version) {
if (packageVersion("brms") >= "2.20.15") {
x <- NextMethod('restructure')
} else {
x <- brms::restructure(x)
}
return(x)
}

Expand All @@ -25,8 +46,17 @@ restructure.bmm <- function(x) {
x$bmm$user_formula <- assign_nl(x$bmm$user_formula)
}

if (restr_version < "0.4.4") {
x$bmm$fit_args <- NULL
}

x$version$bmm_restructure <- current_version
brms::restructure(x)
if (packageVersion("brms") >= "2.20.15") {
x <- NextMethod('restructure')
} else {
x <- brms::restructure(x)
}
x
}

restructure_version.bmm <- function(x) {
Expand Down Expand Up @@ -56,6 +86,26 @@ add_links.bmmmodel <- function(x) {
}

add_bmm_info <- function(x) {
# TODO:
env <- x$family$env
if (is.null(env)) {
stop2("Unable to restructure the object for use with the latest version of bmm. Please refit.")
}
pforms <- env$formula$pforms
names(pforms) <- NULL
user_formula <- brms::do_call("bmf", pforms)
model = env$model
model$resp_vars <- list(resp_err = env$formula$resp)
model$other_vars <- list()
if (inherits(model, 'sdmSimple')) {
model$info$parameters$mu <- glue('Location parameter of the SDM distribution \\
(in radians; by default fixed internally to 0)')
} else {
model$info$parameters$mu1 = glue(
"Location parameter of the von Mises distribution for memory responses \\
(in radians). Fixed internally to 0 by default."
)
}

x$bmm <- nlist(model, user_formula)
x
}
7 changes: 5 additions & 2 deletions R/summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
#' options(bmm.color_summary = FALSE) or bmm_options(color_summary = FALSE)
#' @export
summary.bmmfit <- function(object, priors = FALSE, prob = 0.95, robust = FALSE, mc_se = FALSE, ..., backend = 'bmm') {
object <- restructure.bmm(object)
if (packageVersion('brms') < '2.20.15') {
object <- restructure_bmm(object)
} else {
object <- brms::restructure(object)
}
backend <- match.arg(backend, c('bmm', 'brms'))

# get summary object from brms, since it contains a lot of necessary information:
Expand All @@ -20,7 +24,6 @@ summary.bmmfit <- function(object, priors = FALSE, prob = 0.95, robust = FALSE,
out <- rename_mu_smry(out, get_mu_pars(object))

# get the bmm specific information
bmmargs <- object$bmm$fit_args
bmmmodel <- object$bmm$model
bmmform <- object$bmm$user_formula

Expand Down
9 changes: 6 additions & 3 deletions R/update.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ update.bmmfit <- function(object, formula., newdata = NULL, recompile = NULL, ..
stop2("You cannot update with a different model.\n",
"If you want to use a different model, please use `fit_model()` instead.")
}
object <- restructure.bmm(object)
if (packageVersion('brms') < '2.20.15') {
object <- restructure_bmm(object)
} else {
object <- brms::restructure(object)
}

fit_args <- object$bmm$fit_args
model <- object$bmm$model
old_user_formula <- object$bmm$user_formula
olddata <- object$data
Expand Down Expand Up @@ -75,7 +78,7 @@ update.bmmfit <- function(object, formula., newdata = NULL, recompile = NULL, ..
new_fit_args <- combine_args(nlist(config_args, dots))

# construct the new formula and data only if they have changed
if (!identical(new_fit_args$formula, fit_args$formula)) {
if (!identical(new_fit_args$formula, object$formula)) {
formula. <- new_fit_args$formula
}
if (!identical(new_fit_args$data, olddata)) {
Expand Down
Loading