Skip to content

Commit

Permalink
Merge pull request #85 from venpopov/feature/issue-81-create-wrapper-…
Browse files Browse the repository at this point in the history
…functions-for-make_standata-and-make_stancode

Feature/issue 81 create wrapper functions for make standata and make stancode
  • Loading branch information
GidonFrischkorn authored Feb 10, 2024
2 parents 87dfd82 + 015d292 commit 269f421
Show file tree
Hide file tree
Showing 18 changed files with 536 additions and 96 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
^vignettes/articles$
^doc$
^Meta$
^dev_utils$
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ Suggests:
ggplot2,
mixtur,
ggthemes,
cowplot
cowplot,
stringr
Config/testthat/edition: 3
Imports:
brms,
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ export(dmixture3p)
export(dsdm)
export(fit_model)
export(get_model_prior)
export(get_stancode)
export(get_stancode_parblock)
export(get_standata)
export(k2sd)
export(mixture2p)
export(mixture3p)
Expand Down
2 changes: 1 addition & 1 deletion R/fit_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#' @param chains Numeric. Number of Markov chains (defaults to 4)
#' @param prior One or more `brmsprior` objects created by [brms::set_prior()]
#' or related functions and combined using the c method or the + operator. See
#' also [brms::get_prior()] for more help. Not necessary for the default model
#' also [get_model_prior()] for more help. Not necessary for the default model
#' fitting, but you can provide prior constraints to model parameters
#' @param ... Further arguments passed to [brms::brm()] or Stan. See the
#' description of [brms::brm()] for more details
Expand Down
72 changes: 0 additions & 72 deletions R/get_model_prior.R

This file was deleted.

71 changes: 71 additions & 0 deletions R/helpers-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,74 @@ deg2rad <- function(deg){
rad2deg <- function(rad){
rad * 180 / pi
}


#' @title Generate data for `bmm` models to be passed to Stan
#' @description A wrapper around `brms::make_standata()` for models specified
#' with `bmm`. Given the `model`, the `data` and the `formula` for the model,
#' this function will return the combined stan data generated by `bmm` and
#' `brms`
#' @param formula An object of class `brmsformula`. A symbolic description of
#' the model to be fitted.
#' @param data An object of class data.frame, containing data of all variables
#' used in the model. The names of the variables must match the variable names
#' passed to the `bmmmodel` object for required argurments.
#' @param model A description of the model to be fitted. This is a call to a
#' `bmmmodel` such as `mixture3p()` function. Every model function has a
#' number of required arguments which need to be specified within the function
#' call. Call [supported_models()] to see the list of supported models and
#' their required arguments
#' @param prior One or more `brmsprior` objects created by [brms::set_prior()]
#' or related functions and combined using the c method or the + operator. See
#' also [get_model_prior()] for more help. Not necessary for the default model
#' fitting, but you can provide prior constraints to model parameters
#' @param ... Further arguments passed to [brms::make_standata()]. See the
#' description of [brms::make_standata()] for more details
#'
#' @returns A named list of objects containing the required data to fit a bmm
#' model with Stan.
#'
#'
#' @seealso [supported_models()], [brms::make_standata()]
#'
#' @export
#'
#' @keywords extract_stan
#'
#' @examples
#' \dontrun{
#' # generate artificial data from the Signal Discrimination Model
#' dat <- data.frame(y=rsdm(n=2000))
#'
#' # define formula
#' ff <- brms::bf(y ~ 1,
#' c ~ 1,
#' kappa ~ 1)
#'
#' # fit the model
#' get_standata(formula = ff,
#' data = dat,
#' model = sdmSimple()
#' )
#' }
#'
get_standata <- function(formula, data, model, prior=NULL, ...) {

# check model, formula and data, and transform data if necessary
model <- check_model(model)
formula <- check_formula(model, formula)
data <- check_data(model, data, formula)

# generate the model specification to pass to brms later
config_args <- configure_model(model, data, formula)

# combine the default prior plus user given prior
config_args$prior <- combine_prior(config_args$prior, prior)

# extract stan code
dots <- list(...)
fit_args <- c(config_args, dots)
standata <- brms::do_call(brms::make_standata, fit_args)

return(standata)
}
123 changes: 123 additions & 0 deletions R/helpers-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,126 @@ use_model_template <- function(model_name,
cat(file_content)
}
}



#' @title Generate Stan code for bmm models
#' @description A wrapper around `brms::make_stancode()` for models specified with
#' `bmm`. Given the `model`, the `data` and the `formula` for the model, this
#' function will return the combined stan code generated by `bmm` and `brms`
#' @param formula An object of class `brmsformula`. A symbolic description of
#' the model to be fitted.
#' @param data An object of class data.frame, containing data of all variables
#' used in the model. The names of the variables must match the variable names
#' passed to the `bmmmodel` object for required argurments.
#' @param model A description of the model to be fitted. This is a call to a
#' `bmmmodel` such as `mixture3p()` function. Every model function has a
#' number of required arguments which need to be specified within the function
#' call. Call [supported_models()] to see the list of supported models and
#' their required arguments
#' @param prior One or more `brmsprior` objects created by [brms::set_prior()]
#' or related functions and combined using the c method or the + operator. See
#' also [get_model_prior()] for more help. Not necessary for the default model
#' fitting, but you can provide prior constraints to model parameters
#' @param ... Further arguments passed to [brms::make_stancode()]. See the
#' description of [brms::make_stancode()] for more details
#'
#' @returns A character string containing the fully commented Stan code to fit a
#' bmm model.
#'
#'
#' @seealso [supported_models()], [brms::make_stancode()]
#'
#' @export
#'
#' @keywords extract_stan
#'
#' @examples
#' \dontrun{
#' # generate artificial data from the Signal Discrimination Model
#' dat <- data.frame(y=rsdm(n=2000))
#'
#' # define formula
#' ff <- brms::bf(y ~ 1,
#' c ~ 1,
#' kappa ~ 1)
#'
#' # fit the model
#' get_stancode(formula = ff,
#' data = dat,
#' model = sdmSimple()
#' )
#' }
#'
get_stancode <- function(formula, data, model, prior=NULL, ...) {

# check model, formula and data, and transform data if necessary
model <- check_model(model)
formula <- check_formula(model, formula)
data <- check_data(model, data, formula)

# generate the model specification to pass to brms later
config_args <- configure_model(model, data, formula)

# combine the default prior plus user given prior
config_args$prior <- combine_prior(config_args$prior, prior)

# extract stan code
dots <- list(...)
fit_args <- c(config_args, dots)
stancode <- brms::do_call(brms::make_stancode, fit_args)

return(stancode)
}



#' @title Get the parameter block from a generated Stan code for bmm models
#' @description A wrapper around `get_stancode()` for models specified with
#' `bmm`. Given the `model`, the `data` and the `formula` for the model, this
#' function will return just the parameters block. Useful for figuring out
#' which paramters you can set initial values on
#' @param formula An object of class `brmsformula`. A symbolic description of
#' the model to be fitted.
#' @param data An object of class data.frame, containing data of all variables
#' used in the model. The names of the variables must match the variable names
#' passed to the `bmmmodel` object for required argurments.
#' @param model A description of the model to be fitted. This is a call to a
#' `bmmmodel` such as `mixture3p()` function. Every model function has a
#' number of required arguments which need to be specified within the function
#' call. Call [supported_models()] to see the list of supported models and
#' their required arguments
#' @param prior One or more `brmsprior` objects created by [brms::set_prior()]
#' or related functions and combined using the c method or the + operator. See
#' also [get_model_prior()] for more help. Not necessary for the default model
#' fitting, but you can provide prior constraints to model parameters
#' @param ... Further arguments passed to [brms::make_stancode()]. See the
#' description of [brms::make_stancode()] for more details
#'
#' @keywords extract_stan
#'
#' @returns A character string containing the parameter block of fully commented
#' Stan code to fit a bmm model.
#'
#'
#' @seealso [supported_models()], [get_stancode()]
#'
#' @export
get_stancode_parblock <- function(formula, data, model, prior=NULL, ...) {
stancode <- get_stancode(formula, data, model, prior, ...)
parblock <- .extract_parblock(stancode)
return(parblock)
}


#' @title Extract the parameter block from the Stan code
#' @description Given the Stan code for a model, this function will extract the
#' parameter block from the Stan code
#' @param stancode A character string containing the fully commented Stan code
#' @noRd
.extract_parblock <- function(stancode) {
parblock <- stringr::str_match(as.character(stancode),
"(?s)parameters \\{\\n(.*?)\\}\\ntransformed")[,2]
class(parblock) <- class(stancode)
return(parblock)
}
76 changes: 76 additions & 0 deletions R/helpers-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,79 @@ combine_prior <- function(prior1, prior2) {
}
return(prior)
}


#' @title Get Default priors for Measurement Models specified in BRMS
#' @description Obtain the default priors for a Bayesian multilevel measurement model,
#' as well as information for which parameters priors can be specified.
#' Given the `model`, the `data` and the `formula` for the model, this function will return
#' the default priors that would be used to estimate the model. Additionally, it will
#' return all model parameters that have no prior specified (flat priors). This can help to
#' get an idea about which priors need to be specified and also know which priors were
#' used if no user-specified priors were passed to the [fit_model()] function.
#' @param formula An object of class `brmsformula`. A symbolic description of
#' the model to be fitted.
#' @param data An object of class data.frame, containing data of all variables
#' used in the model. The names of the variables must match the variable names
#' passed to the `bmmmodel` object for required argurments.
#' @param model A description of the model to be fitted. This is a call to a
#' `bmmmodel` such as `mixture3p()` function. Every model function has a
#' number of required arguments which need to be specified within the function
#' call. Call [supported_models()] to see the list of supported models and
#' their required arguments
#' @param ... Further arguments passed to [brms::get_prior()]. See the
#' description of [brms::get_prior()] for more details
#'
#' @details `r a= supported_models(); a`
#'
#' Type `help(package=bmm)` for a full list of available help topics.
#'
#' @returns A data.frame with columns specifying the `prior`, the `class`, the `coef` and `group`
#' for each of the priors specified. Separate rows contain the information on the
#' parameters (or parameter classes) for which priors can be specified.
#'
#'
#' @seealso [supported_models()], [brms::get_prior()]
#'
#' @keywords extract_stan
#'
#' @export
#'
#' @examples
#' \dontrun{
#' # generate artificial data from the Signal Discrimination Model
#' dat <- data.frame(y = rsdm(n=2000))
#'
#' # define formula
#' ff <- brms::bf(y ~ 1,
#' c ~ 1,
#' kappa ~ 1)
#'
#' # fit the model
#' get_model_prior(formula = ff,
#' data = dat,
#' model = sdmSimple()
#' )
#' }
#'
get_model_prior <- function(formula, data, model, ...) {

# check model, formula and data, and transform data if necessary
model <- check_model(model)
formula <- check_formula(model, formula)
data <- check_data(model, data, formula)

# generate the model specification to pass to brms later
config_args <- configure_model(model, data, formula)

# get priors for the model
dots <- list(...)
prior_args <- c(config_args, dots)
brms_priors <- brms::do_call(brms::get_prior, prior_args)

# combine the brms prior with the model default prior
combined_prior <- combine_prior(brms_priors, prior_args$prior)

return(combined_prior)
}

Loading

0 comments on commit 269f421

Please sign in to comment.