From 932a8b68b71809c7b8ad3555cec8a9eeb59690e8 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 17 Jan 2025 08:27:10 -0800 Subject: [PATCH] PPCA parens --- src/dartsort/cluster/gaussian_mixture.py | 4 -- src/dartsort/cluster/ppcalib.py | 82 +++++++++++++++--------- 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py index 6ed7c050..611f08cc 100644 --- a/src/dartsort/cluster/gaussian_mixture.py +++ b/src/dartsort/cluster/gaussian_mixture.py @@ -453,10 +453,6 @@ def m_step( if self.use_proportions and likelihoods is not None: self.update_proportions(likelihoods) - if self.log_proportions is not None: - assert ( - len(self.log_proportions) == unit_ids.max() + 1 + self.with_noise_unit - ) fit_full_indices, fit_split_indices = quick_indices( self.rg, diff --git a/src/dartsort/cluster/ppcalib.py b/src/dartsort/cluster/ppcalib.py index 0838c1cc..21428b62 100644 --- a/src/dartsort/cluster/ppcalib.py +++ b/src/dartsort/cluster/ppcalib.py @@ -2,6 +2,7 @@ from typing import Optional import linear_operator +from linear_operator import operators from linear_operator.operators import CholLinearOperator import torch import torch.nn.functional as F @@ -10,7 +11,7 @@ from ..util.noise_util import EmbeddedNoise from .stable_features import SpikeFeatures, SpikeNeighborhoods -from ..util import spiketorch +from ..util import spiketorch, more_operators vecdot = torch.linalg.vecdot @@ -29,7 +30,7 @@ def ppca_em( mean_prior_pseudocount=0.0, show_progress=False, W_initialization="svd", - normalize=False, + normalize=True, em_converged_atol=0.1, prior_var=1.0, cache_global_direct=True, @@ -145,6 +146,7 @@ def ppca_em( break if show_progress: iters.set_description(f"PPCA[{dmu=:.2g}, {dW=:.2g}]") + print(i, dmu, dW) if normalize and any_missing and state["W"] is not None: _, _, state["W"], state["mu"] = embed( @@ -156,6 +158,7 @@ def ppca_em( state["W"], state["mu"], active_channels=active_channels, + ess=ess, active_cov_chol_factor=active_cov_chol_factor, prior_var=prior_var, normalize=normalize, @@ -216,18 +219,18 @@ def ppca_e_step( D = rank * nc # get normalized weights - y = sp.features + new_zeros = sp.features.new_zeros # we will build our outputs by iterating over the unique # neighborhoods and adding weighted sums of moments in each - e_y = y.new_zeros((rank, nc)) + e_y = new_zeros((rank, nc)) yc = e_u = e_ycu = e_uu = None if return_yc: - yc = y.new_zeros((n, rank, nc)) + yc = new_zeros((n, rank, nc)) if yes_pca: - e_u = y.new_zeros((M,)) - e_ycu = y.new_zeros((rank, nc, M)) - e_uu = y.new_zeros((M, M)) + e_u = new_zeros((M,)) + e_ycu = new_zeros((rank, nc, M)) + e_uu = new_zeros((M, M)) # helpful tensors to keep around if yes_pca: @@ -257,8 +260,8 @@ def ppca_e_step( W_m = active_W[:, nd.missing_subset].reshape(D - nd.D_neighb, M) if yes_pca: - ubar = full_ubar[nd.neighb_members] - uubar = full_uubar[nd.neighb_members] + ubar = full_ubar[nd.u_slice] + uubar = full_uubar[nd.u_slice] # actual data in neighborhood xcc = xc = nd.x - nu @@ -312,7 +315,7 @@ def ppca_e_step( wx = nd.w_norm @ nd.x if nd.have_missing: wxbar_m = nd.w_norm @ xbar_m - ybar = y.new_zeros((rank, nc)) + ybar = new_zeros((rank, nc)) ybar[:, nd.active_subset] = wx.view(rank, nd.neighb_nc) ybar[:, nd.missing_subset] = wxbar_m.view(rank, nc - nd.neighb_nc) else: @@ -324,7 +327,7 @@ def ppca_e_step( if nd.have_missing and yes_pca: wmxcu = nd.w_norm @ e_mxcu.reshape(nd.neighb_n_spikes, -1) wmxcu = wmxcu.view(e_mxcu.shape[1:]) - ycubar = y.new_zeros((rank, nc, M)) + ycubar = new_zeros((rank, nc, M)) ycubar[:, nd.active_subset] = wxcu.view(rank, nd.neighb_nc, M) ycubar[:, nd.missing_subset] = wmxcu.view(rank, nc - nd.neighb_nc, M) elif yes_pca: @@ -334,12 +337,12 @@ def ppca_e_step( if return_yc: if nd.have_missing: xc = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc).mT - yc[nd.neighb_members[:, None], :, nd.active_subset[None, :]] = xc + yc[nd.u_slice][:, :, nd.active_subset[None, :]] = xc xbar_m -= tnu txc = xbar_m.view(nd.neighb_n_spikes, rank, nc - nd.neighb_nc).mT - yc[nd.neighb_members[:, None], :, nd.missing_subset[None, :]] = txc + yc[nd.u_slice][:, :, nd.missing_subset[None, :]] = txc else: - yc[nd.neighb_members] = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc) + yc[nd.u_slice] = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc) # accumulate results e_y += ybar @@ -367,14 +370,18 @@ def embed( scratch=None, ): N = len(sp) + new_zeros = sp.features.new_zeros + device = sp.features.device + dtype = sp.features.dtype + if scratch is not None: _ubar, _uubar = scratch else: - _ubar = sp.features.new_zeros((N, M)) + _ubar = features.new_zeros((N, M)) # if not normalize: - _uubar = sp.features.new_zeros(N, M, M) - # _T = sp.features.new_zeros((N, M, M)) - eye_M = prior_var * torch.eye(M, device=sp.features.device, dtype=sp.features.dtype) + _uubar = features.new_zeros(N, M, M) + eye_M_ = torch.eye(M, device=device, dtype=dtype) + eye_M = prior_var * eye_M_ for nd in neighb_data: nu = active_mean[:, nd.active_subset].reshape(nd.D_neighb) @@ -388,14 +395,22 @@ def embed( # moments of embeddings # T_inv = eye_M + W_o.T @ nd.C_oo_chol.solve(W_o) T_inv = eye_M + W_o.T @ nd.C_oo_inv @ W_o - T = torch.linalg.inv(T_inv) - u_proj = nd.C_oo_inv @ W_o @ T + # root = operators.LowRankRootLinearOperator(W_o.T @ nd.C_oo_cholinv) + # print(f"{root.shape=} {I_M.shape=}") + # helper = root + I_M + # helper = operators.LowRankRootSumLinearOperator(I_M + # print(f"{T_inv.shape=}") + # T = helper.solve(eye_M_) + T, info = torch.linalg.inv_ex(T_inv) + u_proj = nd.C_oo_inv @ (W_o @ T) # ubar = Cooinvxc @ (W_o @ T) - ubar = xc @ u_proj - uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :]) + # ubar = xc @ u_proj + # uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :]) - _ubar[nd.neighb_members] = ubar - _uubar[nd.neighb_members] = uubar + # _ubar[nd.u_slice] = ubar + # _uubar[nd.u_slice] = uubar + torch.mm(xc, u_proj, out=_ubar[nd.u_slice]) + torch.baddbmm(T, _ubar[nd.u_slice].unsqueeze(2), _ubar[nd.u_slice].unsqueeze(1), out=_uubar[nd.u_slice]) if normalize: if active_cov_chol_factor is None: @@ -412,14 +427,14 @@ def embed( # active_mean = active_mean + W @ um # whitening. need to do a GEVP to start... - S = (weights @ _uubar.view(N, M * M)).view(N, M, M) + S = (weights @ _uubar.view(N, M * M)).view(M, M) Dx, U = torch.linalg.eigh(S) Dx = Dx.flip(dims=(0,)) U = U.flip(dims=(1,)) U.mul_(sgn(U[0])) UDxrt = U * Dx.sqrt() rhs = Wflat @ UDxrt.T - gevp_W_right = torch.linalg.solve_triangular(active_cov_chol_factor, rhs) + gevp_W_right = torch.linalg.solve_triangular(active_cov_chol_factor, rhs, upper=False) gevp_W = gevp_W_right.T @ gevp_W_right # gevp_W = linear_operator.solve(lhs=rhs.T, input=active_cov, rhs=rhs) Dw, V = torch.linalg.eigh(gevp_W) @@ -435,7 +450,8 @@ def embed( W @= W_tf _ubar @= u_tf _uubar = torch.einsum("nij,ip,jq->npq", _uubar, u_tf, u_tf) - active_mean.addmm_(W, um) + active_mean += W @ um + # .addmm_(W.view(-1, M), um.unsqueeze(1)) return _ubar, _uubar, W, active_mean @@ -449,11 +465,13 @@ class NeighborhoodPPCAData: C_oo: linear_operator.LinearOperator C_oo_chol: CholLinearOperator + C_oo_cholinv: torch.Tensor C_oo_inv: CholLinearOperator w: torch.Tensor w_norm: torch.Tensor x: torch.Tensor neighb_members: torch.Tensor + u_slice: torch.Tensor C_mo: Optional[torch.Tensor] active_subset: Optional[torch.Tensor] @@ -516,14 +534,15 @@ def get_neighborhood_data( neighborhood_data = [] ess = weights.sum() + n_start = 0 for chans_tuple, chans_data in dedup_data.items(): *info, xs, mems = chans_data nid, neighb_chans, active_subset, can_cache_by_neighborhood, have_missing = info if len(mems) > 1: x = torch.concatenate(xs) neighb_members = torch.concatenate(mems) - neighb_members, order = neighb_members.sort() - x = x[order] + # neighb_members, order = neighb_members.sort() + # x = x[order] nid = None else: x = xs[0] @@ -578,16 +597,19 @@ def get_neighborhood_data( have_missing=have_missing, C_oo=C_oo, C_oo_chol=C_oo_chol, + C_oo_cholinv=Linv, C_oo_inv=C_oo_inv, w=w, w_norm=w / ess, x=x, neighb_members=neighb_members, + u_slice=slice(n_start, n_start + n_neighb), C_mo=C_mo, active_subset=active_subset, missing_subset=missing_subset, ) neighborhood_data.append(nd) + n_start += n_neighb return neighborhood_data