Gaussian funsor variable elimination #559
Comments
@eb8680 it looks like @fehiepsi how much effort do you think it would take for us to port your Pyro PR #2019 to funsor (where it would also be available in NumPyro 😉)?
Here is the optimized GFVE schedule for my pyro-cov model. It fits in main memory but runs out of GPU memory.
The crux is this pair of Gaussian contractions with over 1e9 elements
I believe we can work around this using a combination of @fehiepsi's
My impression is that most of the details can be preserved (e.g. block vector, block matrix, align gaussian). Back then, one issue was that batched QR was very slow on GPU, but torch.linalg seems to have improved a lot since then.
@fehiepsi do you recall whether Cholesky was sufficient instead of QR? IIRC there was a PyTorch discussion about cheaply testing for positive definiteness or condition number using torch.linalg.cholesky_ex().
Looking at the code, I guess we need to triangularize a non-positive-definite precision matrix (e.g. a zeros matrix), but I can't recall when we need such triangularization. :( It is probably unnecessary. (Anyway, we can switch to QR if we face the positive-definiteness issue.)
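For reference, the Cholesky-versus-QR trade-off discussed above can be sketched as follows. This is only an illustration under stated assumptions, not the code from Pyro PR #2019: it uses NumPy instead of torch.linalg, it is unbatched, and the helper names triangularize and precision_factor are hypothetical. The idea is that if A is a square-root factor of the precision (precision = A Aᵀ), then a QR decomposition of Aᵀ yields a triangular factor even when the precision is singular, which is where Cholesky fails.

```python
import numpy as np


def triangularize(prec_sqrt):
    """Return L with L @ L.T == prec_sqrt @ prec_sqrt.T (hypothetical helper).

    L is lower triangular when prec_sqrt has at least as many columns as rows.
    This works even if the implied precision is singular, e.g. all zeros.
    """
    # QR of the transpose: prec_sqrt.T = Q @ R, hence
    # prec_sqrt @ prec_sqrt.T = R.T @ Q.T @ Q @ R = R.T @ R.
    r = np.linalg.qr(prec_sqrt.T, mode="r")
    return r.T


def precision_factor(precision):
    """Try Cholesky first; fall back to an eigendecomposition plus QR."""
    try:
        return np.linalg.cholesky(precision)
    except np.linalg.LinAlgError:
        # Not positive definite: build a square-root factor from the
        # symmetric eigendecomposition, clipping tiny negative eigenvalues.
        w, v = np.linalg.eigh(precision)
        prec_sqrt = v * np.sqrt(np.clip(w, 0.0, None))
        return triangularize(prec_sqrt)
```

The fallback is more expensive than Cholesky, so it only makes sense if singular precision matrices actually occur, which is the open question in the comment above.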
@eb8680 want to pair code next week on the high-level algorithm for variable elimination, continuing our work from https://github.com/pyro-ppl/funsor/compare/tractable-for-gaussians ?
Sure!
Addresses pyro-ppl/pyro#2929
See design doc
This issue tracks changes needed to efficiently perform variable elimination in Gaussian graphical models with plates. While funsor.sum_product.sum_product() is a partial solution, we'd like to generalize to a complete solution.

Tasks
Introduce a new Funsor ConditionalGaussian(info_vec, precision, conditional, inputs) representing the batched conditional distribution of the rightmost real input variable, conditioned on other real input variables. This could be (i) a new Funsor in addition to Gaussian, (ii) a replacement or generalization of Gaussian, or (iii) a special case of Gaussian where the input info_vec and precision are structured (requires "Refactor Gaussian info_vec,precision from backend arrays to Funsors" #556). This may allow cheaper linear algebra. Alternatively, see "Switch to sqrt(prescision) representation in Gaussian?" #567.
Temporary workaround: naively scatter the three parameters (info_vec, precision, conditional) into a dense Gaussian. This can be much more computationally expensive.

Handle collider variables where a latent variable outside a plate depends on an upstream latent variable inside a plate, thereby coupling the upstream variables via moralization. Currently such problems cannot even be specified in the plated-einsum DSL.
Temporary workaround: globally break all plates out of which any arrow leads; equivalent to .to_event().

Handle complete bipartite graphs resulting from the RBM motif (x_i --> y_ij <-- z_j). Currently sum_product() and the TVE algorithm give up in this case with "intractable!".
Temporary workaround: no known workaround.
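
To make the "naively scatter into a dense Gaussian" workaround from the first task concrete, here is a minimal unbatched sketch in NumPy. The parameter semantics are an assumption, since the issue does not define them: conditional is taken to be a matrix A, precision the conditional precision Λ of x given y, and info_vec = Λ b, so that x | y ~ N(A y + b, Λ⁻¹). The function name scatter_conditional is hypothetical and is not part of funsor.

```python
import numpy as np


def scatter_conditional(info_vec, precision, conditional):
    """Scatter assumed ConditionalGaussian parameters into dense ones.

    Assumed (hypothetical) semantics: x | y ~ N(A @ y + b, inv(precision))
    with A = conditional and info_vec = precision @ b. Returns the
    information-form parameters of the dense Gaussian over z = (y, x), i.e.
    log density = -0.5 * z @ P @ z + z @ i + const.
    """
    a = conditional
    dx, dy = a.shape
    dense_precision = np.zeros((dy + dx, dy + dx))
    dense_precision[:dy, :dy] = a.T @ precision @ a   # y-y block
    dense_precision[:dy, dy:] = -a.T @ precision      # y-x block
    dense_precision[dy:, :dy] = -precision @ a        # x-y block
    dense_precision[dy:, dy:] = precision             # x-x block
    dense_info_vec = np.concatenate([-a.T @ info_vec, info_vec])
    return dense_info_vec, dense_precision
```

Under these assumptions the structured form stores on the order of dx² + dx·dy numbers, while the dense form stores (dx + dy)², which is one reason the workaround can be much more expensive. The dense precision also has rank at most dx, so it is singular as a joint precision over (y, x).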