Skip to content

Commit

Permalink
Merge pull request #256 from venpopov/m3-priors-and-cleanup
Browse files Browse the repository at this point in the history
Add M3 example & adapt priors for custom version
  • Loading branch information
GidonFrischkorn authored Feb 7, 2025
2 parents 80b93eb + 66cf73a commit 188804f
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 71 deletions.
127 changes: 92 additions & 35 deletions R/model_m3.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@
list(
resp_vars = nlist(resp_cats),
other_vars = nlist(num_options, choice_rule),
domain = "Working Memory (categorical)",
domain = "Working Memory (categorical), Categorical Decision Making",
task = "n-AFC retrieval",
name = "The Memory Measurement Model by Oberauer & Lewandowsky (2019)",
name = "The Multinomial / Memory Measurement Model",
citation = glue(
"Oberauer, K., & Lewandowsky, S. (2019). Simple measurement models \\
for complex working-memory tasks. Psychological Review, 126."
Expand Down Expand Up @@ -99,10 +99,11 @@
#' @name m3
#'
#' @description
#' The Memory Measurement Model (M3) is a measurement model for working memory
#' tasks with categorical responses. It assumes that each candidate in each response
#' category is activated by a combination of sources of activation. The probability
#' of choosing a response category is determined by the activation of the candidates.
#' The Multinomial / Memory Measurement Model (M3) is a measurement model that was originally introduced
#' for working memory tasks with categorical responses. It assumes that each candidate in each response
#' category is activated by a combination of sources of activation. The probability of choosing a response
#' category is determined by the activation of the candidates. The model can be used for any n-AFC categorical
#' decision task.
#'
#' @param resp_cats The variable names that contain the number of responses for each of the
#' response categories used for the M3.
Expand Down Expand Up @@ -132,10 +133,47 @@
#'
#' @keywords bmmodel
#'
#' @examples
#' \dontrun{
#' # put a full example here (see 'R/bmm_model_mixture3p.R' for an example)
#' }
#' @examplesIf isTRUE(Sys.getenv("BMM_EXAMPLES"))
#' data <- oberauer_lewandowsky_2019_e1
#'
#' # initiate the model object
#' m3_model <- m3(
#' resp_cats = c("corr", "other", "dist", "npl"),
#' num_options = c("n_corr", "n_other", "n_dist", "n_npl"),
#' choice_rule = "simple"
#' )
#'
#' # specify the model formula including the activation formulas for each response category
#' m3_formula <- bmf(
#' corr ~ b + a + c,
#' other ~ b + a,
#' dist ~ b + d,
#' npl ~ b,
#' c ~ 1 + cond + (1 + cond | ID),
#' a ~ 1 + cond + (1 + cond | ID),
#' d ~ 1 + (1 | ID)
#' )
#'
#' # specify links for the model parameters
#' m3_model$links <- list(
#' c = "log",
#' a = "log",
#' d = "log"
#' )
#'
#' # check if the default priors are applied correctly
#' default_prior(m3_formula, data = data, model = m3_model)
#'
#' # fit the model
#' m3_fit <- bmm(
#' formula = m3_formula,
#' data = data,
#' model = m3_model,
#' cores = 4
#' )
#'
#' # print summary of the model
#' summary(m3_fit)
#'
#' @export
m3 <- function(resp_cats, num_options, choice_rule = "softmax", version = "custom", ...) {
Expand Down Expand Up @@ -185,17 +223,25 @@ check_model.m3_custom <- function(model, data = NULL, formula = NULL) {
missing_priors <- setdiff(missing_priors, names(model$fixed_parameters))
warnif(
length(missing_priors) > 0 && getOption("bmm.default_priors"),
"You have not provided default_priors for at least one parameter in the model.
Default priors will be specified internally based on the provided link function.
"Default priors for each parameter will be specified internally based on the provided link function.
Please check if the used priors are reasonable for your application"
)
additional_priors <- lapply(missing_priors, function(m) {
switch(model$links[[m]],
log = list(main = "normal(1,1)", effect = "normal(0,0.5)"),
identity = list(main = "normal(0,1)", effect = "normal(0,1)"),
logit = list(main = "normal(0,1)", effect = "normal(0,1)"),
stop2("Invalid link function provided! Please use one of the following link functions: identity, log, logit")
)
if (model$other_vars$choice_rule == "simple") {
switch(model$links[[m]],
log = list(main = "normal(1, 1)", effects = "normal(0, 0.5)"),
identity = list(main = "normal(10, 4)", effects = "normal(0, 1)"),
logit = list(main = "logistic(0, 1)", effects = "normal(0, 0.5)"),
stop2("Invalid link function provided! Please use one of the following link functions: identity, log, logit")
)
} else if (model$other_vars$choice_rule == "softmax") {
switch(model$links[[m]],
log = list(main = "normal(0, 1)", effects = "normal(0, 0.5)"),
identity = list(main = "normal(1, 1)", effects = "normal(0, 1)"),
logit = list(main = "logistic(0, 1)", effects = "normal(0, 0.5)"),
stop2("Invalid link function provided! Please use one of the following link functions: identity, log, logit")
)
}
})
model$default_priors <- c(model$default_priors, setNames(additional_priors, missing_priors))

Expand Down Expand Up @@ -252,7 +298,6 @@ check_data.m3 <- function(model, data, formula) {
NextMethod("check_data")
}


############################################################################# !
# CHECK_Formula S3 methods ####
############################################################################# !
Expand Down Expand Up @@ -310,38 +355,40 @@ bmf2bf.m3 <- function(model, formula) {
names(n_opt_idx_vars) <- resp_cats
names(options_vars) <- resp_cats

# add transformation to activation according to choice rules
choice_rule <- tolower(model$other_vars$choice_rule)
open <- ifelse(choice_rule == "simple", "log(", "")
close <- ifelse(choice_rule == "simple", ")", "")
zero_opt <- ifelse(choice_rule == "softmax", "(-100)", "exp(-100)")
operator <- ifelse(choice_rule == "softmax", "+", "*")
open_n_opts <- ifelse(choice_rule == "softmax", "log(", "")
close_n_opts <- ifelse(choice_rule == "softmax", ")", "")

# set the base brmsformula based
cat <- resp_cats[1]
brms_formula <- brms::bf(glue(
"Y | trials(nTrials) ~
{open}
{n_opt_idx_vars[cat]} * ({cat} {operator} {open_n_opts}{options_vars[cat]}{close_n_opts}) + (1 - {n_opt_idx_vars[cat]}) * {zero_opt}
{close}"
{n_opt_idx_vars[cat]} *", glue_choice_rule_functions(model$other_vars$choice_rule, cat, options_vars),
"+ (1 - {n_opt_idx_vars[cat]}) * (-100)"
), nl = TRUE)

# for each dependent parameter, check if it is used as a non-linear predictor of
# another parameter and add the corresponding brms function
for (cat in resp_cats[-1]) {
brms_formula <- brms_formula + glue_nlf(
"mu{cat} ~
{open}
{n_opt_idx_vars[cat]} * ({cat} {operator} {open_n_opts}{options_vars[cat]}{close_n_opts}) + (1 - {n_opt_idx_vars[cat]}) * {zero_opt}
{close}"
{n_opt_idx_vars[cat]} *", glue_choice_rule_functions(model$other_vars$choice_rule, cat, options_vars),
"+ (1 - {n_opt_idx_vars[cat]}) * (-100)"
)
}

brms_formula
}

#' @title glue the activation functions for the different choice rules
#'
#' @param choice_rule The choice rule that should be used for the M3. The options are "softmax" and "simple"
#' @param cat The name of the response category for which the activation function should be generated
#' @param options_vars The variable names that contain the number of candidates in each response category
#' @noRd
glue_choice_rule_functions <- function(choice_rule, cat, options_vars) {
switch(
choice_rule,
simple = glue("log({cat} * {options_vars[cat]})"),
softmax = glue("({cat} + log({options_vars[cat]}))")
)
}

############################################################################# !
# CONFIGURE_MODEL S3 METHODS ####
Expand All @@ -359,7 +406,15 @@ configure_model.m3 <- function(model, data, formula) {
formula$family$cats <- model$resp_vars$resp_cats
formula$family$dpars <- paste0("mu", model$resp_vars$resp_cats)

nlist(formula, data)
# set initial values to be set to zero if the choice rule is "simple" and "identity"
# link functions are used
if(model$other_vars$choice_rule == "simple" && any(model$links == "identity")){
init <- 0
} else {
init <- NULL
}

nlist(formula, data, init)
}


Expand Down Expand Up @@ -442,3 +497,5 @@ construct_m3_act_funs <- function(model = NULL, warnings = TRUE) {

act_funs
}


59 changes: 49 additions & 10 deletions man/m3.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 188804f

Please sign in to comment.