diff --git a/pycave/bayes/_internal/utils.py b/pycave/bayes/_internal/utils.py index 71ce117..6f567e4 100644 --- a/pycave/bayes/_internal/utils.py +++ b/pycave/bayes/_internal/utils.py @@ -26,15 +26,15 @@ def log_normal(x, means, covars, covariance_type): num_features = x.size(1) precisions = 1 / covars - if covariance_type == 'diag-shared': # shared diagonal covariance - num_means = means.size(0) - precisions = precisions.view(1, num_features).expand(num_means, num_features) - elif covariance_type == 'diag': # diagonal covariance + if covariance_type == 'diag': # diagonal covariance cov_det = (-precisions.log()).sum(1) x_prob = torch.matmul(x * x, precisions.t()) m_prob = torch.einsum('ij,ij,ij->i', means, means, precisions) xm_prob = torch.matmul(x, (means * precisions).t()) - else: # spherical covariance + else: # spherical or shared diagonal covariance + if covariance_type == 'diag-shared': # shared diagonal covariance + num_means = means.size(0) + precisions = precisions.view(1, num_features).expand(num_means, num_features) cov_det = -precisions.log() * num_features x_prob = torch.ger(torch.einsum('ij,ij->i', x, x), precisions) m_prob = torch.einsum('ij,ij->i', means, means) * precisions