Skip to content

Commit

Permalink
Edge case
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Dec 4, 2024
1 parent ab00f47 commit cc7fc33
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,7 @@ def fit(

achans = occupied_chans(features, self.n_channels, neighborhoods=neighborhoods)
je_suis = achans.numel()
do_pca = self.cov_kind == "ppca" and self.ppca_rank

active_mean = active_W = None
if hasattr(self, "mean"):
Expand Down Expand Up @@ -2093,15 +2094,15 @@ def fit(
if hasattr(self, "W"):
W_full = self.W
W_full.fill_(0.0)
elif je_suis and res.get("W", None) is not None:
elif do_pca:
W_full = new_zeros((self.noise.rank, self.noise.n_channels, self.ppca_rank))

if je_suis:
mean_full[:, achans] = res["mu"]
if res.get("W", None) is not None:
W_full[:, achans] = res["W"]
self.register_buffer("mean", mean_full)
if je_suis and res.get("W", None) is not None:
if do_pca:
self.register_buffer("W", W_full)
nobs = res["nobs"] if je_suis else None
self.pick_channels(achans, nobs)
Expand Down

0 comments on commit cc7fc33

Please sign in to comment.