Skip to content

Commit

Permalink
PPCA parens
Browse files: browse the repository at this point in the history
  • Loading branch information
cwindolf committed Jan 17, 2025
1 parent 0caf553 commit 932a8b6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 34 deletions.
4 changes: 0 additions & 4 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -453,10 +453,6 @@ def m_step(

if self.use_proportions and likelihoods is not None:
self.update_proportions(likelihoods)
if self.log_proportions is not None:
assert (
len(self.log_proportions) == unit_ids.max() + 1 + self.with_noise_unit
)

fit_full_indices, fit_split_indices = quick_indices(
self.rg,
Expand Down
82 changes: 52 additions & 30 deletions src/dartsort/cluster/ppcalib.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional

import linear_operator
from linear_operator import operators
from linear_operator.operators import CholLinearOperator
import torch
import torch.nn.functional as F
Expand All @@ -10,7 +11,7 @@

from ..util.noise_util import EmbeddedNoise
from .stable_features import SpikeFeatures, SpikeNeighborhoods
from ..util import spiketorch
from ..util import spiketorch, more_operators

vecdot = torch.linalg.vecdot

Expand All @@ -29,7 +30,7 @@ def ppca_em(
mean_prior_pseudocount=0.0,
show_progress=False,
W_initialization="svd",
normalize=False,
normalize=True,
em_converged_atol=0.1,
prior_var=1.0,
cache_global_direct=True,
Expand Down Expand Up @@ -145,6 +146,7 @@ def ppca_em(
break
if show_progress:
iters.set_description(f"PPCA[{dmu=:.2g}, {dW=:.2g}]")
print(i, dmu, dW)

if normalize and any_missing and state["W"] is not None:
_, _, state["W"], state["mu"] = embed(
Expand All @@ -156,6 +158,7 @@ def ppca_em(
state["W"],
state["mu"],
active_channels=active_channels,
ess=ess,
active_cov_chol_factor=active_cov_chol_factor,
prior_var=prior_var,
normalize=normalize,
Expand Down Expand Up @@ -216,18 +219,18 @@ def ppca_e_step(
D = rank * nc

# get normalized weights
y = sp.features
new_zeros = sp.features.new_zeros

# we will build our outputs by iterating over the unique
# neighborhoods and adding weighted sums of moments in each
e_y = y.new_zeros((rank, nc))
e_y = new_zeros((rank, nc))
yc = e_u = e_ycu = e_uu = None
if return_yc:
yc = y.new_zeros((n, rank, nc))
yc = new_zeros((n, rank, nc))
if yes_pca:
e_u = y.new_zeros((M,))
e_ycu = y.new_zeros((rank, nc, M))
e_uu = y.new_zeros((M, M))
e_u = new_zeros((M,))
e_ycu = new_zeros((rank, nc, M))
e_uu = new_zeros((M, M))

# helpful tensors to keep around
if yes_pca:
Expand Down Expand Up @@ -257,8 +260,8 @@ def ppca_e_step(
W_m = active_W[:, nd.missing_subset].reshape(D - nd.D_neighb, M)

if yes_pca:
ubar = full_ubar[nd.neighb_members]
uubar = full_uubar[nd.neighb_members]
ubar = full_ubar[nd.u_slice]
uubar = full_uubar[nd.u_slice]

# actual data in neighborhood
xcc = xc = nd.x - nu
Expand Down Expand Up @@ -312,7 +315,7 @@ def ppca_e_step(
wx = nd.w_norm @ nd.x
if nd.have_missing:
wxbar_m = nd.w_norm @ xbar_m
ybar = y.new_zeros((rank, nc))
ybar = new_zeros((rank, nc))
ybar[:, nd.active_subset] = wx.view(rank, nd.neighb_nc)
ybar[:, nd.missing_subset] = wxbar_m.view(rank, nc - nd.neighb_nc)
else:
Expand All @@ -324,7 +327,7 @@ def ppca_e_step(
if nd.have_missing and yes_pca:
wmxcu = nd.w_norm @ e_mxcu.reshape(nd.neighb_n_spikes, -1)
wmxcu = wmxcu.view(e_mxcu.shape[1:])
ycubar = y.new_zeros((rank, nc, M))
ycubar = new_zeros((rank, nc, M))
ycubar[:, nd.active_subset] = wxcu.view(rank, nd.neighb_nc, M)
ycubar[:, nd.missing_subset] = wmxcu.view(rank, nc - nd.neighb_nc, M)
elif yes_pca:
Expand All @@ -334,12 +337,12 @@ def ppca_e_step(
if return_yc:
if nd.have_missing:
xc = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc).mT
yc[nd.neighb_members[:, None], :, nd.active_subset[None, :]] = xc
yc[nd.u_slice][:, :, nd.active_subset[None, :]] = xc
xbar_m -= tnu
txc = xbar_m.view(nd.neighb_n_spikes, rank, nc - nd.neighb_nc).mT
yc[nd.neighb_members[:, None], :, nd.missing_subset[None, :]] = txc
yc[nd.u_slice][:, :, nd.missing_subset[None, :]] = txc
else:
yc[nd.neighb_members] = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc)
yc[nd.u_slice] = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc)

# accumulate results
e_y += ybar
Expand Down Expand Up @@ -367,14 +370,18 @@ def embed(
scratch=None,
):
N = len(sp)
new_zeros = sp.features.new_zeros
device = sp.features.device
dtype = sp.features.dtype

if scratch is not None:
_ubar, _uubar = scratch
else:
_ubar = sp.features.new_zeros((N, M))
_ubar = features.new_zeros((N, M))
# if not normalize:
_uubar = sp.features.new_zeros(N, M, M)
# _T = sp.features.new_zeros((N, M, M))
eye_M = prior_var * torch.eye(M, device=sp.features.device, dtype=sp.features.dtype)
_uubar = features.new_zeros(N, M, M)
eye_M_ = torch.eye(M, device=device, dtype=dtype)
eye_M = prior_var * eye_M_

for nd in neighb_data:
nu = active_mean[:, nd.active_subset].reshape(nd.D_neighb)
Expand All @@ -388,14 +395,22 @@ def embed(
# moments of embeddings
# T_inv = eye_M + W_o.T @ nd.C_oo_chol.solve(W_o)
T_inv = eye_M + W_o.T @ nd.C_oo_inv @ W_o
T = torch.linalg.inv(T_inv)
u_proj = nd.C_oo_inv @ W_o @ T
# root = operators.LowRankRootLinearOperator(W_o.T @ nd.C_oo_cholinv)
# print(f"{root.shape=} {I_M.shape=}")
# helper = root + I_M
# helper = operators.LowRankRootSumLinearOperator(I_M
# print(f"{T_inv.shape=}")
# T = helper.solve(eye_M_)
T, info = torch.linalg.inv_ex(T_inv)
u_proj = nd.C_oo_inv @ (W_o @ T)
# ubar = Cooinvxc @ (W_o @ T)
ubar = xc @ u_proj
uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :])
# ubar = xc @ u_proj
# uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :])

_ubar[nd.neighb_members] = ubar
_uubar[nd.neighb_members] = uubar
# _ubar[nd.u_slice] = ubar
# _uubar[nd.u_slice] = uubar
torch.mm(xc, u_proj, out=_ubar[nd.u_slice])
torch.baddbmm(T, _ubar[nd.u_slice].unsqueeze(2), _ubar[nd.u_slice].unsqueeze(1), out=_uubar[nd.u_slice])

if normalize:
if active_cov_chol_factor is None:
Expand All @@ -412,14 +427,14 @@ def embed(
# active_mean = active_mean + W @ um

# whitening. need to do a GEVP to start...
S = (weights @ _uubar.view(N, M * M)).view(N, M, M)
S = (weights @ _uubar.view(N, M * M)).view(M, M)
Dx, U = torch.linalg.eigh(S)
Dx = Dx.flip(dims=(0,))
U = U.flip(dims=(1,))
U.mul_(sgn(U[0]))
UDxrt = U * Dx.sqrt()
rhs = Wflat @ UDxrt.T
gevp_W_right = torch.linalg.solve_triangular(active_cov_chol_factor, rhs)
gevp_W_right = torch.linalg.solve_triangular(active_cov_chol_factor, rhs, upper=False)
gevp_W = gevp_W_right.T @ gevp_W_right
# gevp_W = linear_operator.solve(lhs=rhs.T, input=active_cov, rhs=rhs)
Dw, V = torch.linalg.eigh(gevp_W)
Expand All @@ -435,7 +450,8 @@ def embed(
W @= W_tf
_ubar @= u_tf
_uubar = torch.einsum("nij,ip,jq->npq", _uubar, u_tf, u_tf)
active_mean.addmm_(W, um)
active_mean += W @ um
# .addmm_(W.view(-1, M), um.unsqueeze(1))

return _ubar, _uubar, W, active_mean

Expand All @@ -449,11 +465,13 @@ class NeighborhoodPPCAData:

C_oo: linear_operator.LinearOperator
C_oo_chol: CholLinearOperator
C_oo_cholinv: torch.Tensor
C_oo_inv: CholLinearOperator
w: torch.Tensor
w_norm: torch.Tensor
x: torch.Tensor
neighb_members: torch.Tensor
u_slice: torch.Tensor

C_mo: Optional[torch.Tensor]
active_subset: Optional[torch.Tensor]
Expand Down Expand Up @@ -516,14 +534,15 @@ def get_neighborhood_data(

neighborhood_data = []
ess = weights.sum()
n_start = 0
for chans_tuple, chans_data in dedup_data.items():
*info, xs, mems = chans_data
nid, neighb_chans, active_subset, can_cache_by_neighborhood, have_missing = info
if len(mems) > 1:
x = torch.concatenate(xs)
neighb_members = torch.concatenate(mems)
neighb_members, order = neighb_members.sort()
x = x[order]
# neighb_members, order = neighb_members.sort()
# x = x[order]
nid = None
else:
x = xs[0]
Expand Down Expand Up @@ -578,16 +597,19 @@ def get_neighborhood_data(
have_missing=have_missing,
C_oo=C_oo,
C_oo_chol=C_oo_chol,
C_oo_cholinv=Linv,
C_oo_inv=C_oo_inv,
w=w,
w_norm=w / ess,
x=x,
neighb_members=neighb_members,
u_slice=slice(n_start, n_start + n_neighb),
C_mo=C_mo,
active_subset=active_subset,
missing_subset=missing_subset,
)
neighborhood_data.append(nd)
n_start += n_neighb

return neighborhood_data

Expand Down

0 comments on commit 932a8b6

Please sign in to comment.