diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py index 7cfa9dc5..f8b9cc9c 100644 --- a/src/dartsort/cluster/gaussian_mixture.py +++ b/src/dartsort/cluster/gaussian_mixture.py @@ -2102,14 +2102,15 @@ def fit( self.register_buffer("mean", mean_full) if je_suis and res.get("W", None) is not None: self.register_buffer("W", W_full) - self.pick_channels(achans, res["nobs"]) + nobs = res["nobs"] if je_suis else None + self.pick_channels(achans, nobs) - def pick_channels(self, active_chans, nobs): + def pick_channels(self, active_chans, nobs=None): if self.channels_strategy == "all": self.register_buffer("channels", torch.arange(self.n_channels)) return - if not active_chans.numel(): + if not active_chans.numel() or nobs is None: return amp = torch.linalg.vector_norm(self.mean[:, active_chans], dim=0) diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index 6b5627ee..8af79869 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -3,12 +3,14 @@ import numpy as np import pandas as pd import torch -from linear_operator import operators +from linear_operator import operators, settings from scipy.fftpack import next_fast_len from tqdm.auto import trange from ..util import drift_util, spiketorch +settings. + class FullNoise(torch.nn.Module): """Do not use this, it's just for comparison to the others.""" @@ -407,7 +409,9 @@ def marginal_covariance( return res if channels == slice(None): if self._full_cov is None: - self._full_cov = self._marginal_covariance() + fcov = self._marginal_covariance() + fcov = operators.CholLinearOperator(fcov.cholesky()) + self._full_cov = fcov self._logdet = self._full_cov.logdet() if device is not None and device != self._full_cov.device: self._full_cov = self._full_cov.to(device)