From 691179be4340ec3f9757faf3c73dfbab6200490e Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Fri, 5 Feb 2021 20:58:14 +0100 Subject: [PATCH] Fix tests --- pycave/bayes/_internal/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pycave/bayes/_internal/utils.py b/pycave/bayes/_internal/utils.py index 6f567e4..ecd0914 100644 --- a/pycave/bayes/_internal/utils.py +++ b/pycave/bayes/_internal/utils.py @@ -26,15 +26,15 @@ def log_normal(x, means, covars, covariance_type): num_features = x.size(1) precisions = 1 / covars - if covariance_type == 'diag': # diagonal covariance + if covariance_type in ('diag', 'diag-shared'): # diagonal covariance + if covariance_type == 'diag-shared': # shared diagonal covariance + num_means = means.size(0) + precisions = precisions.view(1, num_features).expand(num_means, num_features) cov_det = (-precisions.log()).sum(1) x_prob = torch.matmul(x * x, precisions.t()) m_prob = torch.einsum('ij,ij,ij->i', means, means, precisions) xm_prob = torch.matmul(x, (means * precisions).t()) - else: # spherical or shared diagonal covariance - if covariance_type == 'diag-shared': # shared diagonal covariance - num_means = means.size(0) - precisions = precisions.view(1, num_features).expand(num_means, num_features) + else: # spherical covariance cov_det = -precisions.log() * num_features x_prob = torch.ger(torch.einsum('ij,ij->i', x, x), precisions) m_prob = torch.einsum('ij,ij->i', means, means) * precisions