Skip to content

Commit

Permalink
Rescale after each iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 21, 2024
1 parent b3d55eb commit 977b769
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/dartsort/cluster/ppcalib.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,24 @@ def ppca_e_step(
)


def ppca_m_step(e_y, e_u, e_ycu, e_uu, ess, W_old, mean_prior_pseudocount):
def ppca_m_step(
e_y, e_u, e_ycu, e_uu, ess, W_old, mean_prior_pseudocount=10.0, rescale=True
):
"""Lightweight PPCA M step"""
rank, nc, M = e_ycu.shape
mu = e_y - W_old @ e_u
if mean_prior_pseudocount:
mu *= ess / (ess + mean_prior_pseudocount)
if e_u is None:
return dict(mu=mu, W=None)
if rescale:
sigma_u = e_uu - e_u[:, None] * e_u[None]
scales = sigma_u.diagonal().sqrt()
e_uu = e_uu / scales
e_ycu = e_ycu / scales
W = torch.linalg.solve(e_uu, e_ycu.view(rank * nc, M), left=False)
if rescale:
W.mul_(scales)
W = W.view(rank, nc, M)
return dict(mu=mu, W=W)

Expand Down

0 comments on commit 977b769

Please sign in to comment.