Seeking opinions on inference time marginalization in BRMS #1741

Open
lll6924 opened this issue Feb 6, 2025 · 3 comments

lll6924 commented Feb 6, 2025

Hi everyone. I am a PhD student at the University of Massachusetts Amherst working on Bayesian inference and probabilistic programming with my advisor, Professor Dan Sheldon. We have recently been working on inference-time marginalization inside HMC (https://arxiv.org/pdf/2302.00564, https://arxiv.org/pdf/2410.24079). In particular, in the second paper we find that in many linear mixed-effects models it can be beneficial to integrate out one set of random effects during HMC sampling. The core technique exploits the block-diagonal structure of a transformed model, akin to what lme4 uses (Section 2.3 of https://cran.r-project.org/web/packages/lme4/vignettes/lmer.pdf).
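
To make the marginalization concrete, here is a minimal sketch (not our actual implementation, and ignoring the block-diagonal speedups) of the closed-form integral we exploit: in a Gaussian linear mixed model y = X beta + Z b + e with b ~ N(0, Sigma_b) and e ~ N(0, sigma^2 I), integrating out b gives y ~ N(X beta, Z Sigma_b Z' + sigma^2 I).

marginal_loglik <- function(y, X, beta, Z, Sigma_b, sigma) {
  # Marginal density of y after integrating out the random effects b:
  # y ~ N(X %*% beta, Z %*% Sigma_b %*% t(Z) + sigma^2 * I)
  mu <- as.vector(X %*% beta)
  V  <- Z %*% Sigma_b %*% t(Z) + diag(sigma^2, length(y))
  L  <- chol(V)                                # V = t(L) %*% L
  z  <- backsolve(L, y - mu, transpose = TRUE) # whitened residuals
  -0.5 * sum(z^2) - sum(log(diag(L))) - 0.5 * length(y) * log(2 * pi)
}
# For a single grouping factor, V becomes block-diagonal after sorting the
# observations by group, so the dense n x n factorization above can be replaced
# by one small factorization per group; that is the structure we exploit.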

In our own (early) implementation inside BRMS, we see similar results. We added an argument to a forked BRMS to control the marginalization and implemented the corresponding Stan functions. We tried a simple model on the kidney dataset as follows:

fit1 <- brm(time ~ age + (age+1|disease*sex) + (1|patient), iter = 20000, data = kidney, family = "gaussian", prior = set_prior("cauchy(0,2)", class = "sd"), marginalize = NULL)

fit2 <- brm(time ~ age + (age+1|disease*sex) + (1|patient), iter = 20000, data = kidney, family = "gaussian", prior = set_prior("cauchy(0,2)", class = "sd"), marginalize = 'patient')

The running times are similar, but HMC produces more effective samples when we marginalize out the random effects of "patient". Without marginalization, the output from BRMS is

Multilevel Hyperparameters:
~disease (Number of levels: 4)
                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)          4.37      8.67     0.08    25.13 1.00    15243    13718
sd(age)                0.70      0.65     0.02     2.44 1.00    10332    13142
cor(Intercept,age)    -0.02      0.57    -0.95     0.95 1.00    21959    21506

~disease:sex (Number of levels: 8)
                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)          6.56     12.69     0.08    42.32 1.01      683     9957
sd(age)                0.85      0.66     0.03     2.42 1.00     1938     9676
cor(Intercept,age)     0.02      0.57    -0.95     0.95 1.00    11406    21189

~patient (Number of levels: 38)
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     4.64      8.46     0.08    30.52 1.00     1777    15509

~sex (Number of levels: 2)
                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)          5.78     14.30     0.09    40.43 1.01      620     2165
sd(age)                1.71      2.86     0.04    14.10 1.01      229       50
cor(Intercept,age)     0.03      0.59    -0.95     0.95 1.00     1040    20248

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept   133.87     51.83    30.08   232.29 1.00     2642    19720
age          -1.45      1.40    -4.19     1.31 1.00     3403    13995

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma   127.27     10.89   108.75   150.72 1.00     2353    24372

With marginalization, the output from BRMS is

Multilevel Hyperparameters:
~disease (Number of levels: 4)
                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)          4.42      8.60     0.07    25.89 1.00    23748    18372
sd(age)                0.71      0.69     0.02     2.45 1.00    15289    14304
cor(Intercept,age)    -0.02      0.58    -0.95     0.95 1.00    20953    21206

~disease:sex (Number of levels: 8)
                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)          6.03     12.12     0.08    41.98 1.00    21783    17672
sd(age)                0.83      0.64     0.04     2.36 1.00    12449    15323
cor(Intercept,age)     0.01      0.58    -0.95     0.95 1.00     9659    10498

~patient (Number of levels: 38)
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     4.81      8.66     0.08    31.97 1.00    24663     8272

~sex (Number of levels: 2)
                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)          6.20     17.21     0.08    41.08 1.00    16713    12188
sd(age)                1.28      1.35     0.04     4.80 1.00     3885     1680
cor(Intercept,age)    -0.01      0.58    -0.95     0.95 1.00    20439    23690

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept   132.99     52.07    29.77   234.22 1.00    23444    25411
age          -1.44      1.42    -4.28     1.30 1.00    10568     3705

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma   127.48     10.83   108.10   150.46 1.00    20993     9319

Before taking the next step, we would like to seek expert opinions on the following questions.

  • The whole procedure assumes conjugacy, usually in the form of a normal-(log)normal relationship. We are aware that this family of models is limited. In your experience, how widely is this type of model used in applied settings?
  • Do you think it would be beneficial to have an inference-time marginalization feature in BRMS while keeping the interface as simple as above? We are thinking of two types of marginalization: one is to marginalize random effects using some algorithm (ours, or possibly INLA); the other is to marginalize conjugate hyperparameters automatically (with minimal user specification) in the backend.

Many thanks for your attention.

@paul-buerkner (Owner)

It is exciting to hear that you are working on such features on top of brms!

About the questions you raised, my current thoughts are as follows:

(1) In the context of brms, this is indeed a very special case, albeit an important one. To avoid too much special-case coding, I would currently prefer not to implement it, even though it targets an important subclass of models.

(2) Having algorithms available that automatically marginalize at inference time could indeed be cool. But wouldn't such a feature rather have to go into Stan than into brms? Perhaps I am misunderstanding what your concrete plans are.

lll6924 commented Feb 13, 2025

Hi Paul,

Thank you so much for your feedback! Currently we are developing a modular approach to marginalization in a forked version of BRMS. We hope that users who need such an optimization can use it without much trouble.

Regarding marginalization in BRMS or in Stan, here are our thoughts. Marginalization operates on an abstraction of the probabilistic model, which both BRMS and Stan can provide. However, doing it efficiently also requires a lot of structural information, which is easier to obtain from BRMS than from Stan. In theory, one could trace Stan programs and match, for example, linear mixed-effects models from there, but we find it more direct to work on the formulas in BRMS.
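
As a rough illustration of what we mean (a sketch, not our interface): the grouping structure and group-level design matrices are already exposed by brms through its formula parsing and Stan data generation, with the J_* / Z_* names below following brms' current Stan data conventions as we understand them.

library(brms)

form   <- bf(time ~ age + (age + 1 | disease * sex) + (1 | patient))
bterms <- brmsterms(form)   # parsed formula terms, including the grouping factors
sdata  <- make_standata(form, data = kidney, family = gaussian())
# The group membership indices (J_*) and group-level design matrices (Z_*_*)
# in the generated Stan data are exactly the structure a marginalization
# pass needs to identify.
str(sdata[grep("^(J|Z)_", names(sdata))])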

@paul-buerkner (Owner)

That makes sense, thank you. I would love to see the forked version of brms once you feel it is in a good state. Then we can discuss further whether this fork should be merged into brms or whether it should stay stand-alone for now.
