Skip to content

Commit

Permalink
Merge branch 'main' of github.com:cwindolf/dartsort
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Dec 4, 2024
2 parents 0daf0bd + 461eea2 commit cdd5298
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2102,14 +2102,15 @@ def fit(
self.register_buffer("mean", mean_full)
if je_suis and res.get("W", None) is not None:
self.register_buffer("W", W_full)
self.pick_channels(achans, res["nobs"])
nobs = res["nobs"] if je_suis else None
self.pick_channels(achans, nobs)

def pick_channels(self, active_chans, nobs):
def pick_channels(self, active_chans, nobs=None):
if self.channels_strategy == "all":
self.register_buffer("channels", torch.arange(self.n_channels))
return

if not active_chans.numel():
if not active_chans.numel() or nobs is None:
return

amp = torch.linalg.vector_norm(self.mean[:, active_chans], dim=0)
Expand Down
8 changes: 6 additions & 2 deletions src/dartsort/util/noise_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import numpy as np
import pandas as pd
import torch
from linear_operator import operators
from linear_operator import operators, settings
from scipy.fftpack import next_fast_len
from tqdm.auto import trange

from ..util import drift_util, spiketorch

settings.


class FullNoise(torch.nn.Module):
"""Do not use this, it's just for comparison to the others."""
Expand Down Expand Up @@ -407,7 +409,9 @@ def marginal_covariance(
return res
if channels == slice(None):
if self._full_cov is None:
self._full_cov = self._marginal_covariance()
fcov = self._marginal_covariance()
fcov = operators.CholLinearOperator(fcov.cholesky())
self._full_cov = fcov
self._logdet = self._full_cov.logdet()
if device is not None and device != self._full_cov.device:
self._full_cov = self._full_cov.to(device)
Expand Down

0 comments on commit cdd5298

Please sign in to comment.