Commit

Stricter channel strategies and corresponding changes to let ppca take advantage
cwindolf committed Jan 14, 2025
1 parent 4fa23f4 commit 1e481f3
Showing 9 changed files with 349 additions and 137 deletions.
153 changes: 112 additions & 41 deletions src/dartsort/cluster/gaussian_mixture.py
@@ -27,7 +27,12 @@
from .kmeans import kmeans
from .modes import smoothed_dipscore_at
from .ppcalib import ppca_em
-from .stable_features import SpikeFeatures, StableSpikeDataset, occupied_chans
+from .stable_features import (
+    SpikeFeatures,
+    SpikeNeighborhoods,
+    StableSpikeDataset,
+    occupied_chans,
+)

logger = logging.getLogger(__name__)

@@ -55,7 +60,7 @@ def __init__(
use_proportions: bool = True,
proportions_sample_size: int = 2**16,
likelihood_batch_size: int = 2**16,
-        channels_strategy: Literal["all", "snr", "count"] = "count",
+        channels_strategy: Literal["all", "snr", "count", "count_core"] = "count",
channels_count_min: float = 25.0,
channels_snr_amp: float = 1.0,
with_noise_unit: bool = True,
@@ -82,7 +87,14 @@ def __init__(
merge_bimodality_threshold: float = 0.1,
merge_criterion_threshold: float | None = 1.0,
merge_criterion: Literal[
"heldout_loglik", "heldout_ccl", "loglik", "ccl", "aic", "bic", "icl"
"heldout_loglik",
"heldout_ccl",
"loglik",
"ccl",
"aic",
"bic",
"icl",
"bimodality",
] = "heldout_ccl",
split_bimodality_threshold: float = 0.1,
merge_bimodality_cut: float = 0.0,
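As context for the new "count_core" option (and the endswith("core") / endswith("fuzzcore") checks further down in this diff), here is a minimal sketch, not from the commit, of suffix-based dispatch on a Literal-typed strategy flag. The "count_fuzzcore" value is an assumption inferred from the fuzzcore branch below.

from typing import Literal

Strategy = Literal["all", "snr", "count", "count_core", "count_fuzzcore"]

def describe(strategy: Strategy) -> str:
    # The prefix picks the base rule; the suffix only changes which
    # neighborhoods (extract vs. core) supply the channel statistics.
    if strategy.startswith("count"):
        base = "keep channels with enough weighted observations"
    elif strategy.startswith("snr"):
        base = "keep channels whose snr clears a threshold"
    else:
        base = "keep all channels"
    if strategy.endswith("core"):
        base += " (computed on core neighborhoods)"
    return base

print(describe("count_core"))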
@@ -150,6 +162,7 @@ def __init__(

# store arguments to the unit constructor in a dict
self.ppca_rank = ppca_rank
+        self.channels_strategy = channels_strategy
self.unit_args = dict(
noise=noise,
mean_kind=mean_kind,
@@ -347,7 +360,7 @@ def em(
log_liks, clean_props=convergence_props
)
assert convergence_props is not None # for typing.
-            meanlogpx = spike_logliks[train_ix].mean()
+            meanlogpx = spike_logliks.mean()
self.train_meanlogpxs.append(meanlogpx.item())

# M step: fit units based on responsibilities
@@ -643,41 +656,42 @@ def update_proportions(self, log_liks):

def reassign(self, log_liks):
n_units = log_liks.shape[0] - self.with_noise_unit
-        assignments, spike_logliks, log_liks_csc = loglik_reassign(
+        spike_ix, assignments, spike_logliks, log_liks_csc = loglik_reassign(
log_liks,
has_noise_unit=self.with_noise_unit,
log_proportions=self.log_proportions,
)
assignments = torch.from_numpy(assignments).to(self.labels)

# track reassignments
+        original = self.labels[spike_ix]
same = torch.zeros_like(assignments)
-        torch.eq(self.labels, assignments, out=same)
+        torch.eq(original, assignments, out=same)

# total number of reassigned spikes
reassign_count = len(same) - same.sum()

# helpers for intersection over union
(kept,) = (assignments >= 0).nonzero(as_tuple=True)
-        (kept0,) = (self.labels >= 0).nonzero(as_tuple=True)
+        (kept0,) = (original >= 0).nonzero(as_tuple=True)

# intersection
intersection = torch.zeros(n_units, dtype=int)
-        spiketorch.add_at_(intersection, assignments[kept], same[kept])
+        spiketorch.add_at_(intersection, assignments[kept], original[kept])

# union by include/exclude
union = torch.zeros_like(intersection)
_1 = union.new_ones((1,))
union -= intersection
spiketorch.add_at_(union, assignments[kept], _1.broadcast_to(kept.shape))
-        spiketorch.add_at_(union, self.labels[kept0], _1.broadcast_to(kept0.shape))
+        spiketorch.add_at_(union, original[kept0], _1.broadcast_to(kept0.shape))

# define "churn" as 1-iou
iou = intersection / union
unit_churn = 1.0 - iou

# update labels
-        self.labels.copy_(assignments)
+        self.labels[spike_ix] = assignments

return unit_churn, reassign_count, spike_logliks, log_liks_csc
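A hedged sketch, not the library code, of the churn statistic computed above: per-unit intersection over union between the previous and new assignments, with churn = 1 - IoU. Here torch's index_add_ stands in for spiketorch.add_at_.

import torch

def churn(original, assignments, n_units):
    # spikes assigned to a real unit after / before reassignment
    kept = (assignments >= 0).nonzero(as_tuple=True)[0]
    kept0 = (original >= 0).nonzero(as_tuple=True)[0]
    same = (original == assignments).long()
    intersection = torch.zeros(n_units, dtype=torch.long)
    intersection.index_add_(0, assignments[kept], same[kept])
    # union by inclusion-exclusion: |A| + |B| - |A intersect B|
    union = -intersection.clone()
    union.index_add_(0, assignments[kept], torch.ones_like(kept))
    union.index_add_(0, original[kept0], torch.ones_like(kept0))
    return 1.0 - intersection / union

old = torch.tensor([0, 0, 1, 1, -1])
new = torch.tensor([0, 1, 1, 1, 0])
print(churn(old, new, n_units=2))  # tensor([0.6667, 0.3333])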

@@ -975,15 +989,14 @@ def random_indices(

n_full = indices_full.numel()
split_indices = split_indices_full
+        indices = indices_full
if max_size and n_full > max_size:
-            indices = self.rg.choice(n_full, size=max_size, replace=False)
-            indices.sort()
-            indices = torch.asarray(indices, device=indices_full.device)
+            choices = self.rg.choice(n_full, size=max_size, replace=False)
+            choices.sort()
+            choices = torch.asarray(choices, device=indices_full.device)
if split_name is not None:
-                split_indices = split_indices_full[indices]
-            indices = indices_full[indices]
-        else:
-            indices = indices_full
+                split_indices = split_indices[choices]
+            indices = indices[choices]

return indices_full, indices, split_indices
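A quick sketch of the subsampling pattern above: draw one sorted set of positions and apply it to the parallel index arrays. The array names mirror the diff; the data here is made up.

import numpy as np
import torch

rg = np.random.default_rng(0)
indices = torch.arange(100)
split_indices = indices.clone()  # stand-in for a parallel array
max_size = 10

n_full = indices.numel()
if max_size and n_full > max_size:
    choices = rg.choice(n_full, size=max_size, replace=False)
    choices.sort()  # keep the subsample in original order
    choices = torch.asarray(choices)
    split_indices = split_indices[choices]
    indices = indices[choices]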

@@ -1066,14 +1079,28 @@ def fit_unit(
unit_args = self.unit_args | unit_args

_, train_extract_neighborhoods = self.data.neighborhoods(neighborhood="extract")
+        core_neighborhoods = core_ids = None
+        if self.channels_strategy.endswith("core"):
+            assert features.split_indices is not None
+            _, core_neighborhoods = self.data.neighborhoods()
+            core_ids = core_neighborhoods.neighborhood_ids[features.split_indices]

if warm_start and unit_id in self:
unit = self[unit_id]
-            unit.fit(features, weights, neighborhoods=train_extract_neighborhoods)
+            unit.fit(
+                features,
+                weights,
+                neighborhoods=train_extract_neighborhoods,
+                core_neighborhood_ids=core_ids,
+                core_neighborhoods=core_neighborhoods,
+            )
else:
unit = GaussianUnit.from_features(
features,
weights,
neighborhoods=train_extract_neighborhoods,
+                core_neighborhood_ids=core_ids,
+                core_neighborhoods=core_neighborhoods,
**unit_args,
)
return unit
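To make the core-channels idea concrete, a toy sketch with assumed shapes (this is not the occupied_chans API): each spike sees the channels of its neighborhood, and a unit keeps the channels seen by enough spikes.

import torch

n_channels = 8
# channel list per neighborhood (assumed layout)
neighborhood_channels = torch.tensor([[0, 1, 2], [1, 2, 3], [3, 4, 5]])
spike_neighborhood_ids = torch.tensor([0, 0, 0, 1, 2])

counts = torch.zeros(n_channels)
for nid in spike_neighborhood_ids:
    counts[neighborhood_channels[nid]] += 1

channels_count_min = 3
channels = (counts >= channels_count_min).nonzero(as_tuple=True)[0]
print(channels)  # tensor([0, 1, 2]): channels seen by at least 3 spikes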
@@ -1382,10 +1409,18 @@ def mini_merge(
for j, label in enumerate(unique_labels[big_enough]):
(in_label,) = torch.nonzero(labels == label, as_tuple=True)
features = spike_data[in_label.to(spike_data.indices.device)]
+            core_neighborhoods = core_neighborhood_ids = None
+            if self.channels_strategy.endswith("core"):
+                _, core_neighborhoods = self.data.neighborhoods()
+                core_neighborhood_ids = core_neighborhoods.neighborhood_ids[
+                    spike_data.split_indices
+                ]
unit = GaussianUnit.from_features(
features,
weights=weights[j][in_label],
neighborhoods=train_extract_neighborhoods,
+                core_neighborhoods=core_neighborhoods,
+                core_neighborhood_ids=core_neighborhood_ids,
**self.split_unit_args,
)
if unit.channels.numel():
@@ -2257,6 +2292,8 @@ def from_features(
ppca_inner_em_iter=1,
ppca_atol=0.05,
scale_mean: float = 0.1,
+        core_neighborhoods=None,
+        core_neighborhood_ids=None,
**annotations,
):
self = cls(
@@ -2276,7 +2313,13 @@
ppca_atol=ppca_atol,
**annotations,
)
-        self.fit(features, weights, neighborhoods=neighborhoods)
+        self.fit(
+            features,
+            weights,
+            neighborhoods=neighborhoods,
+            core_neighborhoods=core_neighborhoods,
+            core_neighborhood_ids=core_neighborhood_ids,
+        )
self = self.to(features.features.device)
return self

@@ -2306,6 +2349,8 @@ def fit(
weights: Optional[torch.Tensor] = None,
neighborhoods: Optional["SpikeNeighborhoods"] = None,
show_progress: bool = False,
+        core_neighborhood_ids: Optional[torch.Tensor] = None,
+        core_neighborhoods: Optional["SpikeNeighborhoods"] = None,
):
if features is None or len(features) < self.channels_count_min:
self.pick_channels(None, None)
@@ -2317,9 +2362,35 @@
features = features[kept]
weights = weights[kept]

-        achans = occupied_chans(features, self.n_channels, neighborhoods=neighborhoods)
+        if self.channels_strategy.endswith("fuzzcore"):
+            achans_full = occupied_chans(
+                features, self.n_channels, neighborhoods=neighborhoods
+            )
+            achans = occupied_chans(
+                features,
+                neighborhood_ids=core_neighborhood_ids,
+                n_channels=self.n_channels,
+                neighborhoods=core_neighborhoods,
+                fuzz=1,
+            )
+            achans = achans[spiketorch.isin_sorted(achans, achans_full)]
+            needs_direct = True
+        elif self.channels_strategy.endswith("core"):
+            achans = occupied_chans(
+                features,
+                neighborhood_ids=core_neighborhood_ids,
+                n_channels=self.n_channels,
+                neighborhoods=core_neighborhoods,
+            )
+            needs_direct = True
+        else:
+            achans = occupied_chans(
+                features, self.n_channels, neighborhoods=neighborhoods
+            )
+            needs_direct = False

# achans = achans.cpu()
-        je_suis = achans.numel()
+        je_suis = bool(achans.numel())
do_pca = self.cov_kind == "ppca" and self.ppca_rank

active_mean = active_W = None
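The fuzzcore branch above intersects the fuzzed core-occupied channels with the extract-occupied ones. A minimal stand-in for spiketorch.isin_sorted, using torch.isin and assuming both inputs are sorted channel lists:

import torch

achans_full = torch.tensor([0, 1, 2, 3, 4, 5])  # extract-occupied, sorted
achans_core = torch.tensor([2, 3, 4, 6])        # core-occupied after fuzz, sorted
achans = achans_core[torch.isin(achans_core, achans_full)]
print(achans)  # tensor([2, 3, 4])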
@@ -2344,6 +2415,7 @@
mean_prior_pseudocount=self.prior_pseudocount,
show_progress=show_progress,
W_initialization="zeros",
+                cache_local_direct=needs_direct,
)

if hasattr(self, "mean"):
@@ -2369,7 +2441,7 @@
self.pick_channels(achans, nobs)

def pick_channels(self, active_chans, nobs=None):
if self.channels_strategy == "all":
if self.channels_strategy.startswith("all"):
self.register_buffer("channels", torch.arange(self.n_channels))
return

@@ -2384,12 +2456,12 @@ def pick_channels(self, active_chans, nobs=None):
full_snr[active_chans] = snr
self.snr = full_snr.cpu()

if self.channels_strategy == "snr":
if self.channels_strategy.startswith("snr"):
snr_min = np.sqrt(self.channels_count_min) * self.channels_snr_amp
strong = snr >= snr_min
self.register_buffer("channels", active_chans[strong.cpu()])
return
if self.channels_strategy == "count":
if self.channels_strategy.startswith("count"):
strong = nobs >= self.channels_count_min
self.register_buffer("channels", active_chans[strong.cpu()])
return
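A sketch of the two thresholds, assuming (as the form of snr_min suggests, an inference rather than something this diff states) that a channel's snr is its mean amplitude times the square root of its observation count:

import torch

amp = torch.tensor([6.0, 1.1, 0.4])     # mean amplitude per active channel
nobs = torch.tensor([50.0, 30.0, 4.0])  # weighted observation counts
channels_count_min, channels_snr_amp = 25.0, 1.0

snr = amp * nobs.sqrt()
snr_min = channels_count_min ** 0.5 * channels_snr_amp
print(snr >= snr_min)              # "snr" rule: tensor([ True,  True, False])
print(nobs >= channels_count_min)  # "count" rule: tensor([ True,  True, False])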
@@ -2726,16 +2798,15 @@ def marginal_loglik(
def loglik_reassign(
log_liks, has_noise_unit=False, proportions=None, log_proportions=None
):
-    log_liks_csc, assignments, spike_logliks = sparse_reassign(
+    nz_lines, log_liks_csc, assignments, spike_logliks = sparse_reassign(
log_liks,
-        return_csc=True,
proportions=proportions,
log_proportions=log_proportions,
)
n_units = log_liks.shape[0] - has_noise_unit
if has_noise_unit:
assignments[assignments >= n_units] = -1
-    return assignments, spike_logliks, log_liks_csc
+    return nz_lines, assignments, spike_logliks, log_liks_csc
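When the noise unit is present it occupies the last line of log_liks, so winners at or past n_units become label -1. A tiny worked example:

import numpy as np

n_units = 4  # real units; one extra noise row appended to the likelihoods
assignments = np.array([0, 3, 4, 2, 4])
assignments[assignments >= n_units] = -1  # noise wins -> unlabeled
print(assignments)  # [ 0  3 -1  2 -1]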


def logmeanexp(x_csr):
Expand All @@ -2753,9 +2824,7 @@ def logmeanexp(x_csr):
return log_mean_exp
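The body of logmeanexp is elided by the hunk above. As a reading aid only, a hedged guess at a row-wise log-mean-exp over a CSR matrix's stored entries (the name and shapes are assumptions, not the library code):

import numpy as np
from scipy import sparse

def logmeanexp_rows(x_csr: sparse.csr_matrix) -> np.ndarray:
    out = np.full(x_csr.shape[0], -np.inf, dtype=np.float64)
    for i in range(x_csr.shape[0]):
        p, q = x_csr.indptr[i], x_csr.indptr[i + 1]
        if p == q:
            continue  # empty row
        row = x_csr.data[p:q]
        m = row.max()
        out[i] = m + np.log(np.exp(row - m).mean())  # stable log-mean-exp
    return out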


-def sparse_reassign(
-    liks, match_threshold=None, return_csc=False, proportions=None, log_proportions=None
-):
+def sparse_reassign(liks, match_threshold=None, proportions=None, log_proportions=None):
"""Reassign spikes to units with largest likelihood
liks is (n_units, n_spikes). This computes the argmax for each column,
@@ -2765,24 +2834,28 @@
this uses a numba replacement, but I'd like to upstream a cython version.
"""
if not liks.nnz:
-        return np.full(liks.shape[1], -1), np.full(liks.shape[1], -np.inf)
+        return (
+            np.arange(0),
+            liks,
+            np.full(liks.shape[1], -1),
+            np.full(liks.shape[1], -np.inf),
+        )

# csc is needed here for this to be fast
liks = liks.tocsc()
+    nz_lines = np.flatnonzero(np.diff(liks.indptr))
+    nnz = len(nz_lines)

# see scipy csc argmin/argmax for reference here. this is just numba-ing
# a special case of that code which has a python hot loop.
-    assignments = np.full(liks.shape[1], -1)
+    assignments = np.full(nnz, -1)
# these will be filled with logsumexps
-    likelihoods = np.full(liks.shape[1], -np.inf, dtype=np.float32)
+    likelihoods = np.full(nnz, -np.inf, dtype=np.float32)

# get log proportions, either given logs or otherwise...
if log_proportions is None:
if proportions is None:
-            log_proportions = np.full(
-                liks.shape[0], -np.log(liks.shape[0]), dtype=np.float32
-            )
+            log_proportions = np.full(nnz, -np.log(liks.shape[0]), dtype=np.float32)
elif torch.is_tensor(proportions):
log_proportions = proportions.log().numpy(force=True)
else:
@@ -2803,9 +2876,7 @@ def sparse_reassign(
log_proportions,
)

-    if return_csc:
-        return liks, assignments, likelihoods
-    return assignments, likelihoods
+    return nz_lines, liks, assignments, likelihoods
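Since results are now aligned to nz_lines (the nonempty CSC columns), callers scatter them back into full per-spike arrays, as reassign does with self.labels[spike_ix] = assignments. Schematically:

import numpy as np

n_spikes = 8
nz_lines = np.array([1, 4, 6])        # spikes with any nonzero likelihood
assignments_nz = np.array([2, 0, 5])  # results aligned to nz_lines
labels = np.full(n_spikes, -1)
labels[nz_lines] = assignments_nz
print(labels)  # [-1  2 -1 -1  0 -1  5 -1]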


# csc can have int32 or 64 coos on dif platforms? is this an intp? :P
@@ -2832,9 +2903,9 @@ def hot_argmax_loop(
ix = indices[p:q]
dx = data[p:q] + log_proportions[ix]
best = dx.argmax()
-        assignments[i] = ix[best]
+        assignments[j] = ix[best]
mx = dx.max()
-        scores[i] = mx + np.log(np.exp(dx - mx).sum())
+        scores[j] = mx + np.log(np.exp(dx - mx).sum())
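A self-contained sketch of the same column-wise argmax and logsumexp over a CSC matrix, in plain numpy without numba, to show what the hot loop computes (random data, not from the commit):

import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
liks = sparse.random(5, 12, density=0.4, format="csc", random_state=rng)
log_props = np.full(5, -np.log(5.0))

nz_lines = np.flatnonzero(np.diff(liks.indptr))  # nonempty columns
assignments = np.full(nz_lines.size, -1)
scores = np.full(nz_lines.size, -np.inf)
for j, c in enumerate(nz_lines):
    p, q = liks.indptr[c], liks.indptr[c + 1]
    ix = liks.indices[p:q]              # unit ids with support on spike c
    dx = liks.data[p:q] + log_props[ix]
    assignments[j] = ix[dx.argmax()]
    mx = dx.max()
    scores[j] = mx + np.log(np.exp(dx - mx).sum())  # stable logsumexp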


def bimodalities_dense(