Skip to content

Commit

Permalink
Needs chol wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Jan 10, 2025
1 parent 7bcf40e commit 1b11ab1
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/dartsort/cluster/ppcalib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings

import linear_operator
from linear_operator.operators import CholLinearOperator
import torch
import torch.nn.functional as F
from tqdm.auto import trange
Expand Down Expand Up @@ -251,8 +252,8 @@ def ppca_e_step(
C_oo = noise.marginal_covariance(
channels=neighb_chans, cache_prefix=cache_prefix, cache_key=nid, device=y.device
)
# TODO: genuinely confused about the need for this.
C_oochol = C_oo.cholesky()
# TODO: genuinely confused about the need for this. why doesn't solve() use this cached object?
C_oochol = CholLinearOperator(C_oo.cholesky())
nu = active_mean[:, active_subset].reshape(D_neighb)
if have_missing:
C_mo = noise.offdiag_covariance(
Expand Down Expand Up @@ -283,9 +284,9 @@ def ppca_e_step(
# pca-centered data
if yes_pca and have_missing:
CooinvWo = C_oochol.solve(W_o)
# xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
# Cooinvxcc = C_oochol.solve(xcc.T).T
Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
Cooinvxcc = C_oochol.solve(xcc.T).T
# Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
else:
Cooinvxcc = Cooinvxc

Expand Down

0 comments on commit 1b11ab1

Please sign in to comment.