Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Feb 5, 2021
1 parent b0297e1 commit 691179b
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions pycave/bayes/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def log_normal(x, means, covars, covariance_type):
num_features = x.size(1)
precisions = 1 / covars

if covariance_type == 'diag': # diagonal covariance
if covariance_type in ('diag', 'diag-shared'): # diagonal covariance
if covariance_type == 'diag-shared': # shared diagonal covariance
num_means = means.size(0)
precisions = precisions.view(1, num_features).expand(num_means, num_features)
cov_det = (-precisions.log()).sum(1)
x_prob = torch.matmul(x * x, precisions.t())
m_prob = torch.einsum('ij,ij,ij->i', means, means, precisions)
xm_prob = torch.matmul(x, (means * precisions).t())
else: # spherical or shared diagonal covariance
if covariance_type == 'diag-shared': # shared diagonal covariance
num_means = means.size(0)
precisions = precisions.view(1, num_features).expand(num_means, num_features)
else: # spherical covariance
cov_det = -precisions.log() * num_features
x_prob = torch.ger(torch.einsum('ij,ij->i', x, x), precisions)
m_prob = torch.einsum('ij,ij->i', means, means) * precisions
Expand Down

0 comments on commit 691179b

Please sign in to comment.