diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py
index f7f15a4e..f3a46119 100644
--- a/src/dartsort/cluster/gaussian_mixture.py
+++ b/src/dartsort/cluster/gaussian_mixture.py
@@ -249,16 +249,16 @@ def m_step(self, likelihoods=None, show_progress=False, prev_means=None):
         for j, unit in enumerate(zip(unit_ids, results)):
             assert unit.annotations["unit_id"] == j
             self.units.append(unit)
-        if self.log_proportions is not None:
-            # this is the index of the noise unit. it's got to be larger than
-            # the largest unit index
-            maxix = self.log_proportions.numel() - 1
-            assert (unit_ids < maxix).all()
-            ixs = torch.cat((unit_ids, torch.tensor([maxix])))
-            self.log_proportions = self.log_proportions[ixs]
+        # if self.log_proportions is not None:
+        #     # this is the index of the noise unit. it's got to be larger than
+        #     # the largest unit index
+        #     maxix = self.log_proportions.numel() - 1
+        #     assert (unit_ids < maxix).all()
+        #     ixs = torch.cat((unit_ids, torch.tensor([maxix])))
+        #     self.log_proportions = self.log_proportions[ixs]
         if prev_means is not None:
             nu = len(unit_ids)
-            prev_means = prev_means[unit_ids]
+            # prev_means = prev_means[unit_ids]
             new_means, *_ = self.stack_units(mean_only=True)
             dmu = (prev_means - new_means).abs_().view(nu, -1)
             adif = torch.max(dmu)
@@ -1091,19 +1091,32 @@ def unit_group_criterion(
 
         if fit_type == "refit_all":
             units = []
-            subunit_logliks = spikes_core.features.new_full((len(unit_ids), len(in_any)), -torch.inf)
+            subunit_logliks = spikes_core.features.new_full(
+                (len(unit_ids), len(in_any)), -torch.inf
+            )
             full_loglik = 0.0
             for i, k in enumerate(unit_ids):
-                u = self.fit_unit(unit_id=k, indices=in_any, likelihoods=likelihoods, features=spikes_extract)
+                u = self.fit_unit(
+                    unit_id=k,
+                    indices=in_any,
+                    likelihoods=likelihoods,
+                    features=spikes_extract,
+                )
                 units.append(u)
-                _, subunit_logliks[i] = self.unit_log_likelihoods(unit=u, spikes=spikes_core)
+                _, subunit_logliks[i] = self.unit_log_likelihoods(
+                    unit=u, spikes=spikes_core
+                )
             subunit_log_props = F.softmax(subunit_logliks, dim=0).mean(1).log_()
             # loglik per spik
-            full_loglik = torch.logsumexp(subunit_logliks.T + subunit_log_props, dim=1).mean()
+            full_loglik = torch.logsumexp(
+                subunit_logliks.T + subunit_log_props, dim=1
+            ).mean()
             unit = self.fit_unit(indices=in_any, features=spikes_extract)
             likelihoods = None
         elif fit_type == "avg_preexisting":
-            unit = self.units[unit_ids[0]].avg_with(*[self.units[u] for u in unit_ids[1:]])
+            unit = self.units[unit_ids[0]].avg_with(
+                *[self.units[u] for u in unit_ids[1:]]
+            )
             if debug:
                 subunit_logliks = likelihoods[:, in_any][unit_ids]
         full_loglik = marginal_loglik(
@@ -1377,8 +1390,13 @@ def avg_with(self, *others):
             n_channels=self.n_channels,
             noise=self.noise,
         )
-        new.register_buffer("mean", (self.mean + sum(o.mean for o in others)) / (1 + len(others)))
-        new.register_buffer("channels", torch.cat([self.channels, *[o.channels for o in others]]).unique())
+        new.register_buffer(
+            "mean", (self.mean + sum(o.mean for o in others)) / (1 + len(others))
+        )
+        new.register_buffer(
+            "channels",
+            torch.cat([self.channels, *[o.channels for o in others]]).unique(),
+        )
         assert self.cov_kind == "zero"
         return new
 
@@ -1726,7 +1744,9 @@ def coo_to_scipy(coo_tensor):
     return coo_array((data, coords), shape=coo_tensor.shape)
 
 
-def marginal_loglik(indices, log_proportions, log_likelihoods, unit_ids=None, reduce="mean"):
+def marginal_loglik(
+    indices, log_proportions, log_likelihoods, unit_ids=None, reduce="mean"
+):
     if unit_ids is not None:
         # renormalize log props
         log_proportions = log_proportions[unit_ids]
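For readers skimming the patch, the "refit_all" branch above boils down to a standard mixture marginal likelihood: per-unit log-likelihoods are turned into mixing proportions via a softmax averaged over spikes, then combined with a logsumexp to get a per-spike marginal log-likelihood. The sketch below reproduces that arithmetic on toy tensors; the variable names mirror the diff, but the shapes and random data are invented for illustration and this is not dartsort API.

import torch
import torch.nn.functional as F

# Toy stand-ins: 3 candidate subunits, 100 spikes. Each row holds one
# subunit's log-likelihood for every spike (the role of subunit_logliks).
n_units, n_spikes = 3, 100
subunit_logliks = torch.randn(n_units, n_spikes)

# Responsibility-style proportion estimate: softmax over units for each
# spike, averaged across spikes, then logged -- same pattern as the hunk.
subunit_log_props = F.softmax(subunit_logliks, dim=0).mean(1).log_()

# Marginal log-likelihood of each spike under the mixture, averaged:
# logsumexp over units of (log p(x | unit k) + log pi_k).
full_loglik = torch.logsumexp(subunit_logliks.T + subunit_log_props, dim=1).mean()
print(full_loglik)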