From 1e481f3708fe6e342b7fcf81841a44336c387b2c Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Tue, 14 Jan 2025 12:25:24 -0500
Subject: [PATCH] Stricter channel strategies and corresponding changes to let
 ppca take advantage

---
 src/dartsort/cluster/gaussian_mixture.py      | 153 ++++++++++++-----
 src/dartsort/cluster/ppcalib.py               | 100 ++++++++---
 src/dartsort/cluster/refine.py                |   1 +
 src/dartsort/cluster/stable_features.py       |  35 +++-
 src/dartsort/config.py                        |   1 +
 src/dartsort/util/internal_config.py          |  11 +-
 .../util/testing_util/mixture_testing_util.py |   9 +-
 tests/test_dartsort.py                        |  18 ++
 tests/test_modeling.py                        | 158 ++++++++++--------
 9 files changed, 349 insertions(+), 137 deletions(-)

diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py
index 505f32e7..72bed64c 100644
--- a/src/dartsort/cluster/gaussian_mixture.py
+++ b/src/dartsort/cluster/gaussian_mixture.py
@@ -27,7 +27,12 @@
 from .kmeans import kmeans
 from .modes import smoothed_dipscore_at
 from .ppcalib import ppca_em
-from .stable_features import SpikeFeatures, StableSpikeDataset, occupied_chans
+from .stable_features import (
+    SpikeFeatures,
+    SpikeNeighborhoods,
+    StableSpikeDataset,
+    occupied_chans,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +60,7 @@ def __init__(
         use_proportions: bool = True,
         proportions_sample_size: int = 2**16,
         likelihood_batch_size: int = 2**16,
-        channels_strategy: Literal["all", "snr", "count"] = "count",
+        channels_strategy: Literal["all", "snr", "count", "count_core"] = "count",
         channels_count_min: float = 25.0,
         channels_snr_amp: float = 1.0,
         with_noise_unit: bool = True,
@@ -82,7 +87,14 @@ def __init__(
         merge_bimodality_threshold: float = 0.1,
         merge_criterion_threshold: float | None = 1.0,
         merge_criterion: Literal[
-            "heldout_loglik", "heldout_ccl", "loglik", "ccl", "aic", "bic", "icl"
+            "heldout_loglik",
+            "heldout_ccl",
+            "loglik",
+            "ccl",
+            "aic",
+            "bic",
+            "icl",
+            "bimodality",
         ] = "heldout_ccl",
         split_bimodality_threshold: float = 0.1,
         merge_bimodality_cut: float = 0.0,
@@ -150,6 +162,7 @@ def __init__(
 
         # store arguments to the unit constructor in a dict
         self.ppca_rank = ppca_rank
+        self.channels_strategy = channels_strategy
        self.unit_args = dict(
             noise=noise,
             mean_kind=mean_kind,
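
The widened channels_strategy vocabulary composes a base criterion ("all", "snr",
"count") with an optional core suffix, and the patch dispatches on the string with
startswith()/endswith() rather than exact matches (see pick_channels() and
GaussianUnit.fit() further down). A hypothetical helper, not part of dartsort,
that just spells out the naming scheme these strings imply:

    # Illustrative only: decompose suffix-style strategy strings such as
    # "count_fuzzcore" into (base criterion, whether core neighborhoods are
    # used, how many fuzz/dilation steps to apply).
    def parse_channels_strategy(strategy: str):
        if strategy.endswith("fuzzcore"):
            return strategy[: -len("_fuzzcore")], True, 1
        if strategy.endswith("core"):
            return strategy[: -len("_core")], True, 0
        return strategy, False, 0

    assert parse_channels_strategy("count") == ("count", False, 0)
    assert parse_channels_strategy("count_core") == ("count", True, 0)
    assert parse_channels_strategy("count_fuzzcore") == ("count", True, 1)
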
@@ -347,7 +360,7 @@ def em(
                 log_liks, clean_props=convergence_props
             )
             assert convergence_props is not None  # for typing.
-            meanlogpx = spike_logliks[train_ix].mean()
+            meanlogpx = spike_logliks.mean()
             self.train_meanlogpxs.append(meanlogpx.item())
 
             # M step: fit units based on responsibilities
@@ -643,7 +656,7 @@ def update_proportions(self, log_liks):
 
     def reassign(self, log_liks):
         n_units = log_liks.shape[0] - self.with_noise_unit
-        assignments, spike_logliks, log_liks_csc = loglik_reassign(
+        spike_ix, assignments, spike_logliks, log_liks_csc = loglik_reassign(
             log_liks,
             has_noise_unit=self.with_noise_unit,
             log_proportions=self.log_proportions,
         )
         assignments = torch.from_numpy(assignments).to(self.labels)
 
         # track reassignments
+        original = self.labels[spike_ix]
         same = torch.zeros_like(assignments)
-        torch.eq(self.labels, assignments, out=same)
+        torch.eq(original, assignments, out=same)
 
         # total number of reassigned spikes
         reassign_count = len(same) - same.sum()
 
         # helpers for intersection over union
         (kept,) = (assignments >= 0).nonzero(as_tuple=True)
-        (kept0,) = (self.labels >= 0).nonzero(as_tuple=True)
+        (kept0,) = (original >= 0).nonzero(as_tuple=True)
 
         # intersection
         intersection = torch.zeros(n_units, dtype=int)
         spiketorch.add_at_(intersection, assignments[kept], same[kept])
 
         # union by include/exclude
         union = torch.zeros_like(intersection)
         _1 = union.new_ones((1,))
         union -= intersection
         spiketorch.add_at_(union, assignments[kept], _1.broadcast_to(kept.shape))
-        spiketorch.add_at_(union, self.labels[kept0], _1.broadcast_to(kept0.shape))
+        spiketorch.add_at_(union, original[kept0], _1.broadcast_to(kept0.shape))
 
         # define "churn" as 1-iou
         iou = intersection / union
         unit_churn = 1.0 - iou
 
         # update labels
-        self.labels.copy_(assignments)
+        self.labels[spike_ix] = assignments
 
         return unit_churn, reassign_count, spike_logliks, log_liks_csc
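
The rewritten reassign() scores per-unit stability as one minus the
intersection-over-union between each unit's old and new spike sets, counting a
spike toward the intersection only when both labels agree. A self-contained
sketch of that churn computation on toy tensors, using torch's index_add_ in
place of dartsort's spiketorch.add_at_:

    import torch

    def unit_churn(original, assignments, n_units):
        # agreement indicator per spike; label -1 means "unassigned"
        same = (original == assignments).long()
        (kept,) = (assignments >= 0).nonzero(as_tuple=True)
        (kept0,) = (original >= 0).nonzero(as_tuple=True)
        intersection = torch.zeros(n_units, dtype=torch.long)
        intersection.index_add_(0, assignments[kept], same[kept])
        # |A u B| = |A| + |B| - |A n B|, accumulated per unit
        union = -intersection
        union.index_add_(0, assignments[kept], torch.ones_like(kept))
        union.index_add_(0, original[kept0], torch.ones_like(kept0))
        return 1.0 - intersection / union

    old = torch.tensor([0, 0, 1, 1, -1])
    new = torch.tensor([0, 1, 1, 1, 1])
    print(unit_churn(old, new, n_units=2))  # tensor([0.5000, 0.5000])
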
@@ -975,15 +989,14 @@ def random_indices(
             n_full = indices_full.numel()
 
         split_indices = split_indices_full
+        indices = indices_full
         if max_size and n_full > max_size:
-            indices = self.rg.choice(n_full, size=max_size, replace=False)
-            indices.sort()
-            indices = torch.asarray(indices, device=indices_full.device)
+            choices = self.rg.choice(n_full, size=max_size, replace=False)
+            choices.sort()
+            choices = torch.asarray(choices, device=indices_full.device)
             if split_name is not None:
-                split_indices = split_indices_full[indices]
-                indices = indices_full[indices]
-            else:
-                indices = indices_full
+                split_indices = split_indices[choices]
+            indices = indices[choices]
 
         return indices_full, indices, split_indices
 
@@ -1066,14 +1079,28 @@ def fit_unit(
         unit_args = self.unit_args | unit_args
         _, train_extract_neighborhoods = self.data.neighborhoods(neighborhood="extract")
+        core_neighborhoods = core_ids = None
+        if self.channels_strategy.endswith("core"):
+            assert features.split_indices is not None
+            _, core_neighborhoods = self.data.neighborhoods()
+            core_ids = core_neighborhoods.neighborhood_ids[features.split_indices]
+
         if warm_start and unit_id in self:
             unit = self[unit_id]
-            unit.fit(features, weights, neighborhoods=train_extract_neighborhoods)
+            unit.fit(
+                features,
+                weights,
+                neighborhoods=train_extract_neighborhoods,
+                core_neighborhood_ids=core_ids,
+                core_neighborhoods=core_neighborhoods,
+            )
         else:
             unit = GaussianUnit.from_features(
                 features,
                 weights,
                 neighborhoods=train_extract_neighborhoods,
+                core_neighborhood_ids=core_ids,
+                core_neighborhoods=core_neighborhoods,
                 **unit_args,
             )
         return unit
@@ -1382,10 +1409,18 @@ def mini_merge(
         for j, label in enumerate(unique_labels[big_enough]):
             (in_label,) = torch.nonzero(labels == label, as_tuple=True)
             features = spike_data[in_label.to(spike_data.indices.device)]
+            core_neighborhoods = core_neighborhood_ids = None
+            if self.channels_strategy.endswith("core"):
+                _, core_neighborhoods = self.data.neighborhoods()
+                core_neighborhood_ids = core_neighborhoods.neighborhood_ids[
+                    spike_data.split_indices
+                ]
             unit = GaussianUnit.from_features(
                 features,
                 weights=weights[j][in_label],
                 neighborhoods=train_extract_neighborhoods,
+                core_neighborhoods=core_neighborhoods,
+                core_neighborhood_ids=core_neighborhood_ids,
                 **self.split_unit_args,
             )
             if unit.channels.numel():
@@ -2257,6 +2292,8 @@ def from_features(
         ppca_inner_em_iter=1,
         ppca_atol=0.05,
         scale_mean: float = 0.1,
+        core_neighborhoods=None,
+        core_neighborhood_ids=None,
         **annotations,
     ):
         self = cls(
@@ -2276,7 +2313,13 @@ def from_features(
             ppca_atol=ppca_atol,
             **annotations,
         )
-        self.fit(features, weights, neighborhoods=neighborhoods)
+        self.fit(
+            features,
+            weights,
+            neighborhoods=neighborhoods,
+            core_neighborhoods=core_neighborhoods,
+            core_neighborhood_ids=core_neighborhood_ids,
+        )
         self = self.to(features.features.device)
         return self
 
@@ -2306,6 +2349,8 @@ def fit(
         weights: Optional[torch.Tensor] = None,
         neighborhoods: Optional["SpikeNeighborhoods"] = None,
         show_progress: bool = False,
+        core_neighborhood_ids: Optional[torch.Tensor] = None,
+        core_neighborhoods: Optional["SpikeNeighborhoods"] = None,
     ):
         if features is None or len(features) < self.channels_count_min:
             self.pick_channels(None, None)
@@ -2317,9 +2362,35 @@ def fit(
             features = features[kept]
             weights = weights[kept]
 
-        achans = occupied_chans(features, self.n_channels, neighborhoods=neighborhoods)
+        if self.channels_strategy.endswith("fuzzcore"):
+            achans_full = occupied_chans(
+                features, self.n_channels, neighborhoods=neighborhoods
+            )
+            achans = occupied_chans(
+                features,
+                neighborhood_ids=core_neighborhood_ids,
+                n_channels=self.n_channels,
+                neighborhoods=core_neighborhoods,
+                fuzz=1,
+            )
+            achans = achans[spiketorch.isin_sorted(achans, achans_full)]
+            needs_direct = True
+        elif self.channels_strategy.endswith("core"):
+            achans = occupied_chans(
+                features,
+                neighborhood_ids=core_neighborhood_ids,
+                n_channels=self.n_channels,
+                neighborhoods=core_neighborhoods,
+            )
+            needs_direct = True
+        else:
+            achans = occupied_chans(
+                features, self.n_channels, neighborhoods=neighborhoods
+            )
+            needs_direct = False
+
         # achans = achans.cpu()
-        je_suis = achans.numel()
+        je_suis = bool(achans.numel())
         do_pca = self.cov_kind == "ppca" and self.ppca_rank
 
         active_mean = active_W = None
@@ -2344,6 +2415,7 @@ def fit(
                 mean_prior_pseudocount=self.prior_pseudocount,
                 show_progress=show_progress,
                 W_initialization="zeros",
+                cache_local_direct=needs_direct,
             )
 
         if hasattr(self, "mean"):
@@ -2369,7 +2441,7 @@ def fit(
         self.pick_channels(achans, nobs)
 
     def pick_channels(self, active_chans, nobs=None):
-        if self.channels_strategy == "all":
+        if self.channels_strategy.startswith("all"):
             self.register_buffer("channels", torch.arange(self.n_channels))
             return
 
@@ -2384,12 +2456,12 @@ def pick_channels(self, active_chans, nobs=None):
             full_snr[active_chans] = snr
             self.snr = full_snr.cpu()
 
-        if self.channels_strategy == "snr":
+        if self.channels_strategy.startswith("snr"):
             snr_min = np.sqrt(self.channels_count_min) * self.channels_snr_amp
             strong = snr >= snr_min
             self.register_buffer("channels", active_chans[strong.cpu()])
             return
-        if self.channels_strategy == "count":
+        if self.channels_strategy.startswith("count"):
             strong = nobs >= self.channels_count_min
             self.register_buffer("channels", active_chans[strong.cpu()])
             return
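
For the "fuzzcore" branch, GaussianUnit.fit() intersects the fuzzed
core-occupied channels with the channels actually covered by the extract
neighborhoods via spiketorch.isin_sorted. A minimal stand-in, assuming both
inputs are sorted, unique channel id vectors (the real helper may handle edge
cases differently):

    import torch

    def isin_sorted(query, sorted_set):
        # searchsorted gives each query's insertion point; a hit means the
        # element is actually present in the (sorted, unique) set.
        ix = torch.searchsorted(sorted_set, query).clamp(max=len(sorted_set) - 1)
        return sorted_set[ix] == query

    achans_fuzz = torch.tensor([1, 2, 3, 7, 9])   # fuzzed core channels
    achans_full = torch.tensor([2, 3, 4, 5, 9])   # extract-covered channels
    achans = achans_fuzz[isin_sorted(achans_fuzz, achans_full)]
    print(achans)  # tensor([2, 3, 9])
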
@@ -2726,16 +2798,15 @@ def marginal_loglik(
 def loglik_reassign(
     log_liks, has_noise_unit=False, proportions=None, log_proportions=None
 ):
-    log_liks_csc, assignments, spike_logliks = sparse_reassign(
+    nz_lines, log_liks_csc, assignments, spike_logliks = sparse_reassign(
         log_liks,
-        return_csc=True,
         proportions=proportions,
         log_proportions=log_proportions,
     )
     n_units = log_liks.shape[0] - has_noise_unit
     if has_noise_unit:
         assignments[assignments >= n_units] = -1
-    return assignments, spike_logliks, log_liks_csc
+    return nz_lines, assignments, spike_logliks, log_liks_csc
 
 
 def logmeanexp(x_csr):
@@ -2753,9 +2824,7 @@ def logmeanexp(x_csr):
     return log_mean_exp
 
 
-def sparse_reassign(
-    liks, match_threshold=None, return_csc=False, proportions=None, log_proportions=None
-):
+def sparse_reassign(liks, match_threshold=None, proportions=None, log_proportions=None):
     """Reassign spikes to units with largest likelihood
 
     liks is (n_units, n_spikes). This computes the argmax for each column,
     this uses a numba replacement, but I'd like to upstream a cython version.
     """
     if not liks.nnz:
-        return np.full(liks.shape[1], -1), np.full(liks.shape[1], -np.inf)
+        return (
+            np.arange(0),
+            liks,
+            np.full(liks.shape[1], -1),
+            np.full(liks.shape[1], -np.inf),
+        )
 
     # csc is needed here for this to be fast
     liks = liks.tocsc()
     nz_lines = np.flatnonzero(np.diff(liks.indptr))
+    nnz = len(nz_lines)
 
     # see scipy csc argmin/argmax for reference here. this is just numba-ing
     # a special case of that code which has a python hot loop.
-    assignments = np.full(liks.shape[1], -1)
+    assignments = np.full(nnz, -1)
     # these will be filled with logsumexps
-    likelihoods = np.full(liks.shape[1], -np.inf, dtype=np.float32)
+    likelihoods = np.full(nnz, -np.inf, dtype=np.float32)
 
     # get log proportions, either given logs or otherwise...
     if log_proportions is None:
         if proportions is None:
             log_proportions = np.full(
                 liks.shape[0], -np.log(liks.shape[0]), dtype=np.float32
             )
         elif torch.is_tensor(proportions):
             log_proportions = proportions.log().numpy(force=True)
         else:
@@ -2803,9 +2876,7 @@ def sparse_reassign(
         log_proportions,
     )
 
-    if return_csc:
-        return liks, assignments, likelihoods
-    return assignments, likelihoods
+    return nz_lines, liks, assignments, likelihoods
 
 
 # csc can have int32 or 64 coos on dif platforms? is this an intp?  :P
@@ -2832,9 +2903,9 @@ def hot_argmax_loop(
         ix = indices[p:q]
         dx = data[p:q] + log_proportions[ix]
         best = dx.argmax()
-        assignments[i] = ix[best]
+        assignments[j] = ix[best]
         mx = dx.max()
-        scores[i] = mx + np.log(np.exp(dx - mx).sum())
+        scores[j] = mx + np.log(np.exp(dx - mx).sum())
 
 
 def bimodalities_dense(
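
sparse_reassign() now reports results only for the columns that hold any
likelihood mass (nz_lines); for each such spike it takes an argmax and a
numerically stable logsumexp of the stored log likelihoods plus the per-unit
log proportions. A plain numpy/scipy sketch of the same per-column loop
(dartsort runs it under numba):

    import numpy as np
    from scipy.sparse import csc_matrix

    def reassign(liks, log_proportions):
        # columns (spikes) with at least one stored likelihood
        nz_lines = np.flatnonzero(np.diff(liks.indptr))
        assignments = np.full(len(nz_lines), -1)
        scores = np.full(len(nz_lines), -np.inf, dtype=np.float32)
        for j, i in enumerate(nz_lines):
            p, q = liks.indptr[i], liks.indptr[i + 1]
            ix = liks.indices[p:q]                     # unit ids with mass here
            dx = liks.data[p:q] + log_proportions[ix]  # weighted log liks
            assignments[j] = ix[dx.argmax()]
            mx = dx.max()
            scores[j] = mx + np.log(np.exp(dx - mx).sum())  # stable logsumexp
        return nz_lines, assignments, scores

    rows = np.array([0, 1, 0])
    cols = np.array([0, 0, 2])
    vals = np.array([-1.0, -3.0, -0.5], dtype=np.float32)
    liks = csc_matrix((vals, (rows, cols)), shape=(2, 3))
    logp = np.full(2, -np.log(2), dtype=np.float32)  # one entry per unit
    print(reassign(liks, logp))  # spike 1 has no entries and is skipped
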
diff --git a/src/dartsort/cluster/ppcalib.py b/src/dartsort/cluster/ppcalib.py
index 3252a02e..f01f5d17 100644
--- a/src/dartsort/cluster/ppcalib.py
+++ b/src/dartsort/cluster/ppcalib.py
@@ -32,7 +32,8 @@ def ppca_em(
     normalize=False,
     em_converged_atol=0.1,
     prior_var=1.0,
-    cache_direct=True,
+    cache_global_direct=True,
+    cache_local_direct=False,
 ):
     new_zeros = sp.features.new_zeros
     if active_W is not None:
@@ -52,11 +53,19 @@ def ppca_em(
     ess = weights.sum()
     assert torch.isfinite(ess)
     neighb_data = get_neighborhood_data(
-        sp, neighborhoods, active_channels, rank, weights, D, noise, cache_prefix
+        sp,
+        neighborhoods,
+        active_channels,
+        rank,
+        weights,
+        D,
+        noise,
+        cache_prefix,
+        cache_direct=cache_local_direct,
     )
     any_missing = any(nd.have_missing for nd in neighb_data)
     cache_kw = {}
-    if cache_direct:
+    if cache_global_direct:
         cache_kw = dict(
             cache_prefix="direct", cache_key=tuple(active_channels.tolist())
         )
@@ -421,7 +430,6 @@ def embed(
 
 @dataclass(kw_only=True, frozen=True, slots=True)
 class NeighborhoodPPCAData:
-    neighb_id: int
     neighb_nc: int
     neighb_n_spikes: int
     D_neighb: int
@@ -440,32 +448,91 @@ class NeighborhoodPPCAData:
 
 
 def get_neighborhood_data(
-    sp, neighborhoods, active_channels, rank, weights, D, noise, cache_prefix
+    sp,
+    neighborhoods,
+    active_channels,
+    rank,
+    weights,
+    D,
+    noise,
+    cache_prefix,
+    cache_direct=False,
 ):
     neighborhood_info, ns = neighborhoods.spike_neighborhoods(
         channels=active_channels,
         neighborhood_ids=sp.neighborhood_ids,
         min_coverage=0,
     )
-    neighborhood_data = []
-    ess = weights.sum()
+
+    # two passes: first is deduplication
+    dedup_data = {}
     for nid, neighb_chans, neighb_members, _ in neighborhood_info:
-        n_neighb = neighb_members.numel()
         # -- neighborhood channels
         neighb_valid = neighborhoods.valid_mask(nid)
         # subset of neighborhood's chans which are active
-        # needs to be subset of full nchans set, not just the valid ones
-        # neighb_subset = spiketorch.isin_sorted(neighb_chans, active_channels)
-        neighb_subset = neighb_valid  # those are the same. tested by assert below.
+        # needs to be subset of full neighborhood channel set, not just the valid ones
+        if len(ws) > 1:
+            w = torch.concatenate(ws)
+            x = torch.concatenate(xs)
+            neighb_members = torch.concatenate(mems)
+            nid = None
+        else:
+            w = ws[0]
+            x = xs[0]
+            neighb_members = mems[0]
+
+        n_neighb = neighb_members.numel()
+        cache_kw = {}
+        if cache_direct:
+            cache_kw = dict(
+                cache_prefix="direct",
+                cache_key=tuple(active_channels[active_subset].tolist()),
+            )
+        elif can_cache_by_neighborhood:
+            cache_kw = dict(
+                cache_prefix=cache_prefix,
+                cache_key=nid,
+            )
+
         # -- missing channels
-        have_missing = not active_subset.all()
         missing_subset = missing_chans = None
         if have_missing:
             (missing_subset,) = torch.logical_not(active_subset).nonzero(as_tuple=True)
@@ -477,10 +544,7 @@ def get_neighborhood_data(
 
         # -- neighborhood data
         device = sp.features.device
         C_oo = noise.marginal_covariance(
-            channels=neighb_chans,
-            cache_prefix=cache_prefix,
-            cache_key=nid,
-            device=device,
+            channels=neighb_chans, device=device, **cache_kw
         )
         assert C_oo.shape == (D_neighb, D_neighb)
         C_oo_chol = CholLinearOperator(C_oo.cholesky())
@@ -492,11 +556,9 @@ def get_neighborhood_data(
             channels_right=neighb_chans,
         )
         C_mo = C_mo.to_dense().to(device)
-        x = sp.features[neighb_members][:, :, neighb_subset]
         x = x.view(n_neighb, D_neighb)
 
         nd = NeighborhoodPPCAData(
-            neighb_id=nid,
             neighb_nc=neighb_nc,
             neighb_n_spikes=n_neighb,
             D_neighb=D_neighb,
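
The rewritten get_neighborhood_data() makes two passes; the first (whose body
is partly elided above) deduplicates neighborhoods whose channel sets restrict
to the same subset of the unit's active channels, so that each distinct subset's
noise covariance is factored and cached only once. A hypothetical toy version of
that grouping step, assuming the grouping key is the tuple of covered active
channels (the real pass also carries weights and features along):

    import torch

    def group_by_active_subset(neighb_chans_list, members_list, active_channels):
        # neighborhoods covering the same active-channel subset share a group
        groups = {}
        for chans, members in zip(neighb_chans_list, members_list):
            covered = chans[torch.isin(chans, active_channels)]
            key = tuple(covered.tolist())
            groups.setdefault(key, []).append(members)
        return {k: torch.concatenate(v) for k, v in groups.items()}

    active = torch.tensor([1, 2, 3])
    neighbs = [torch.tensor([0, 1, 2]), torch.tensor([1, 2, 9]), torch.tensor([2, 3])]
    members = [torch.tensor([0, 4]), torch.tensor([1]), torch.tensor([2, 3])]
    print(group_by_active_subset(neighbs, members, active))
    # {(1, 2): tensor([0, 4, 1]), (2, 3): tensor([2, 3])}
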
diff --git a/src/dartsort/cluster/refine.py b/src/dartsort/cluster/refine.py
index c3e9d540..2c0db529 100644
--- a/src/dartsort/cluster/refine.py
+++ b/src/dartsort/cluster/refine.py
@@ -66,6 +66,7 @@ def refine_clustering(
         em_converged_prop=refinement_config.em_converged_prop,
         em_converged_churn=refinement_config.em_converged_churn,
         em_converged_atol=refinement_config.em_converged_atol,
+        channels_strategy=refinement_config.channels_strategy,
     )
     gmm.cleanup()
     for it in range(refinement_config.n_total_iters):
diff --git a/src/dartsort/cluster/stable_features.py b/src/dartsort/cluster/stable_features.py
index 533e203d..0299dd5c 100644
--- a/src/dartsort/cluster/stable_features.py
+++ b/src/dartsort/cluster/stable_features.py
@@ -142,6 +142,9 @@ def __init__(
             device=device,
         )
         self._train_extract_channels = extract_channels.cpu()[train_ixs]
+        core_channel_index = waveform_util.make_channel_index(
+            prgeom, core_radius, to_torch=True
+        )
         _core_neighborhoods = {
             f"key_{k}": SpikeNeighborhoods.from_channels(
                 core_channels[ix],
@@ -152,11 +155,11 @@ def __init__(
                 neighborhoods=core_channels[neighborhood_ix],
                 features=core_features[ix] if k in _core_feature_splits else None,
                 device=device,
+                channel_index=core_channel_index,
             )
             for k, ix in self.split_indices.items()
         }
         self._core_neighborhoods = torch.nn.ModuleDict(_core_neighborhoods)
-
         self.core_channels = core_channels.cpu()
 
         # channel neighborhoods and features
@@ -496,6 +499,7 @@ def __init__(
         features=None,
         neighborhood_members=None,
         store_on_device: bool = False,
+        channel_index=None,
         device=None,
     ):
         """SpikeNeighborhoods
@@ -521,6 +525,8 @@ def __init__(
         self.register_buffer("neighborhoods", neighborhoods)
         # self.neighborhoods = neighborhoods.cpu()
         self.n_neighborhoods = len(neighborhoods)
+        if channel_index is not None:
+            self.register_buffer("channel_index", channel_index)
 
         # store neighborhoods as an indicator matrix
         # also store nonzero-d masks
@@ -591,6 +597,7 @@ def from_channels(
         device=None,
         deduplicate=False,
         features=None,
+        channel_index=None,
     ):
         if neighborhood_ids is not None:
             assert neighborhoods is not None
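
SpikeNeighborhoods now carries a channel_index buffer, which dartsort builds
with waveform_util.make_channel_index(prgeom, core_radius). A minimal numpy
equivalent, under the assumed convention that row c lists the channels within
the radius of channel c, padded to equal width with the sentinel n_channels:

    import numpy as np

    def make_channel_index(geom, radius):
        # pairwise distances between channel positions
        n = len(geom)
        d = np.linalg.norm(geom[:, None, :] - geom[None, :, :], axis=2)
        neighbors = [np.flatnonzero(d[c] <= radius) for c in range(n)]
        width = max(len(nb) for nb in neighbors)
        index = np.full((n, width), n)  # n acts as the "no channel" sentinel
        for c, nb in enumerate(neighbors):
            index[c, : len(nb)] = nb
        return index

    geom = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0], [30.0, 0.0]])
    print(make_channel_index(geom, radius=25.0))
    # [[0 1 4]
    #  [0 1 2]
    #  [1 2 4]
    #  [3 4 4]]
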
                 device=device,
                 deduplicate=deduplicate,
                 features=features,
+                channel_index=channel_index,
             )
         if device is not None:
             channels = channels.to(device)
@@ -614,6 +622,7 @@ def from_channels(
             neighborhood_ids=neighborhood_ids,
             features=features,
             device=channels.device,
+            channel_index=channel_index,
         )
 
     @classmethod
@@ -626,6 +635,7 @@ def from_known_ids(
         device=None,
         deduplicate=False,
         features=None,
+        channel_index=None,
     ):
         neighborhoods = torch.asarray(neighborhoods)
         if device is not None:
@@ -645,6 +655,7 @@ def from_known_ids(
             neighborhood_ids=neighborhood_ids,
             features=features,
             device=device,
+            channel_index=channel_index,
         )
 
     def has_feature_cache(self):
@@ -742,15 +753,24 @@ def spike_neighborhoods(
 
 # -- helpers
 
 
-def occupied_chans(spike_data, n_channels, neighborhoods=None):
+def occupied_chans(
+    spike_data, n_channels, neighborhood_ids=None, neighborhoods=None, fuzz=0
+):
     if spike_data.neighborhood_ids is None:
         chans = torch.unique(spike_data.channels)
         return chans[chans < n_channels]
 
     assert neighborhoods is not None
-    ids = torch.unique(spike_data.neighborhood_ids)
+    if neighborhood_ids is None:
+        neighborhood_ids = spike_data.neighborhood_ids
+    ids = torch.unique(neighborhood_ids)
     chans = neighborhoods.neighborhoods[ids]
     chans = torch.unique(chans)
-    return chans[chans < n_channels]
+    chans = chans[chans < n_channels]
+    for _ in range(fuzz):
+        chans = neighborhoods.channel_index[chans]
+        chans = torch.unique(chans)
+        chans = chans[chans < n_channels]
+    return chans
 
 
 def interp_to_chans(
@@ -928,7 +948,12 @@ def get_stable_channels(
         workers=workers,
     )
 
-    return extract_channels, core_channels, neighborhood_ids, neighborhood_ix
+    return (
+        extract_channels,
+        core_channels,
+        neighborhood_ids,
+        neighborhood_ix,
+    )
 
 
 def unique_with_index(x, dim=0):
diff --git a/src/dartsort/config.py b/src/dartsort/config.py
index 3c4c6706..83a59a9e 100644
--- a/src/dartsort/config.py
+++ b/src/dartsort/config.py
@@ -166,6 +166,7 @@ class DeveloperConfig(DARTsortUserConfig):
     ] = "heldout_ccl"
     merge_bimodality_threshold: float = 0.05
     n_refinement_iters: int = 3
+    channels_strategy: str = "count_core"
     gmm_max_spikes: Annotated[int, Field(gt=0)] = 4_000_000
     gmm_val_proportion: Annotated[float, Field(gt=0)] = 0.25
 
diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py
index 54bb8abc..143295c4 100644
--- a/src/dartsort/util/internal_config.py
+++ b/src/dartsort/util/internal_config.py
@@ -327,6 +327,7 @@ class RefinementConfig:
     max_n_spikes: float | int = argfield(default=4_000_000, arg_type=int_or_inf)
 
     # model params
+    channels_strategy: str = "count"
     min_count: int = 50
     signal_rank: int = 0
     n_spikes_fit: int = 4096
@@ -338,7 +339,14 @@ class RefinementConfig:
     # if None, switches to bimodality
     merge_criterion_threshold: float | None = 0.0
     merge_criterion: Literal[
-        "heldout_loglik", "heldout_ccl", "loglik", "ccl", "aic", "bic", "icl", "bimodality"
+        "heldout_loglik",
+        "heldout_ccl",
+        "loglik",
+        "ccl",
+        "aic",
+        "bic",
+        "icl",
+        "bimodality",
     ] = "heldout_ccl"
     merge_bimodality_threshold: float = 0.05
     em_converged_prop: float = 0.02
@@ -491,6 +499,7 @@ def to_internal_config(cfg):
         n_total_iters=cfg.n_refinement_iters,
         max_n_spikes=cfg.gmm_max_spikes,
         val_proportion=cfg.gmm_val_proportion,
+        channels_strategy=cfg.channels_strategy,
     )
     motion_estimation_config = MotionEstimationConfig(
         **{k.name: getattr(cfg, k.name) for k in fields(MotionEstimationConfig)}
     )
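
With the channel index in place, the fuzz loop in occupied_chans() is a set
dilation: gather every occupied channel's neighbors, uniquify, and drop the
padding sentinel. Toy usage with the small index from the earlier sketch:

    import torch

    channel_index = torch.tensor([[0, 1, 4], [0, 1, 2], [1, 2, 4], [3, 4, 4]])
    n_channels = 4
    chans = torch.tensor([0])            # occupied core channels
    for _ in range(1):                   # fuzz=1
        # one dilation step; padding entries equal n_channels and are dropped
        chans = torch.unique(channel_index[chans])
        chans = chans[chans < n_channels]
    print(chans)  # tensor([0, 1]): channel 0 plus its in-radius neighbor
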
diff --git a/src/dartsort/util/testing_util/mixture_testing_util.py b/src/dartsort/util/testing_util/mixture_testing_util.py
index e730e7c1..22813d5f 100644
--- a/src/dartsort/util/testing_util/mixture_testing_util.py
+++ b/src/dartsort/util/testing_util/mixture_testing_util.py
@@ -111,7 +111,6 @@ def simulate_moppca(
     if init_label_corruption:
         to_corrupt = rg.binomial(1, init_label_corruption, size=N)
         init_labels[to_corrupt] = rg.integers(K, size=to_corrupt.sum())
-    print(f"{init_labels=}")
 
     init_sorting = dartsort.DARTsortSorting(
         times_samples=torch.arange(N),
@@ -160,6 +159,7 @@ def fit_moppcas(
     em_converged_atol=0.05,
     with_noise_unit=True,
     return_before_fit=False,
+    channels_strategy="count",
 ):
     import dartsort
 
@@ -179,6 +179,7 @@ def fit_moppcas(
         em_converged_prop=1e-6,
         n_threads=1,
         with_noise_unit=with_noise_unit,
+        channels_strategy=channels_strategy,
     )
     torch.manual_seed(0)
     if not return_before_fit:
@@ -196,6 +197,7 @@ def fit_ppca(
     show_progress=True,
     normalize=True,
     em_converged_atol=1e-6,
+    cache_local_direct=False,
 ):
     from dartsort.cluster import ppcalib
 
@@ -213,6 +215,7 @@ def fit_ppca(
         show_progress=show_progress,
         normalize=normalize,
         em_converged_atol=em_converged_atol,
+        cache_local_direct=cache_local_direct,
     )
     return res
 
@@ -324,6 +327,7 @@ def test_ppca(
     show_vis=False,
     figsize=(4, 3),
     normalize=True,
+    cache_local=False,
     rg=0,
 ):
     rg = np.random.default_rng(rg)
@@ -352,6 +356,7 @@ def test_ppca(
         show_progress=make_vis,
         normalize=normalize,
         em_converged_atol=em_converged_atol,
+        cache_local_direct=cache_local,
     )
 
     muerr, werr, panel = compare_subspaces(
@@ -392,6 +397,7 @@ def test_moppcas(
     make_vis=True,
     figsize=(4, 3),
     return_before_fit=False,
+    channels_strategy="count",
     snr=10.0,
     rg=0,
 ):
@@ -420,6 +426,7 @@ def test_moppcas(
         em_converged_atol=em_converged_atol,
         with_noise_unit=with_noise_unit,
         return_before_fit=return_before_fit,
+        channels_strategy=channels_strategy,
     )
     if return_before_fit:
         return dict(sim_res=sim_res, gmm=mm)
diff --git a/tests/test_dartsort.py b/tests/test_dartsort.py
index e791ab1a..cf99443d 100644
--- a/tests/test_dartsort.py
+++ b/tests/test_dartsort.py
@@ -80,6 +80,24 @@ def test_fakedata():
     rec = sc.NumpyRecording(rec, fs)
     rec.set_dummy_probe_from_locations(geom)
 
+    with tempfile.TemporaryDirectory() as tempdir:
+        cfg = dartsort.DARTsortInternalConfig(
+            subtraction_config=dartsort.SubtractionConfig(
+                subtraction_denoising_config=dartsort.FeaturizationConfig(
+                    denoise_only=True, do_nn_denoise=False
+                )
+            ),
+            refinement_config=dartsort.RefinementConfig(
+                min_count=10, channels_strategy="count_fuzzcore"
+            ),
+            featurization_config=dartsort.FeaturizationConfig(n_residual_snips=512),
+            motion_estimation_config=dartsort.MotionEstimationConfig(
+                do_motion_estimation=False
+            ),
+            matching_iterations=0,
+        )
+        res = dartsort.dartsort(rec, output_directory=tempdir, cfg=cfg)
+
     with tempfile.TemporaryDirectory() as tempdir:
         cfg = dartsort.DARTsortInternalConfig(
             subtraction_config=dartsort.SubtractionConfig(
diff --git a/tests/test_modeling.py b/tests/test_modeling.py
index d83035f7..c312fc72 100644
--- a/tests/test_modeling.py
+++ b/tests/test_modeling.py
@@ -1,4 +1,5 @@
 import numpy as np
+import torch
 
 from dartsort.util.testing_util import mixture_testing_util
 
@@ -9,47 +10,50 @@
 t_mu_test = ("zero", "random")
 t_cov_test = ("eye", "random")
 t_w_test = ("zero", "hot", "random")
+t_channels_strategy_test = ("count", "count_core")
 
 
 def test_ppca():
-    for t_missing in t_missing_test:
-        for t_mu in t_mu_test:
-            for t_cov in t_cov_test:
-                for t_w in t_w_test:
-                    print(f"{t_mu=} {t_cov=} {t_w=} {t_missing=}")
-                    res = mixture_testing_util.test_ppca(
-                        t_mu=t_mu,
-                        t_cov=t_cov,
-                        t_w=t_w,
-                        t_missing=t_missing,
-                        em_converged_atol=1e-1,
-                        em_iter=100,
-                        figsize=(3, 2.5),
-                        make_vis=False,
-                        show_vis=False,
-                        normalize=False,
-                    )
+    for t_channels_strategy in t_channels_strategy_test:
+        for t_missing in t_missing_test:
+            for t_mu in t_mu_test:
+                for t_cov in t_cov_test:
+                    for t_w in t_w_test:
+                        print(f"{t_mu=} {t_cov=} {t_w=} {t_missing=}")
+                        res = mixture_testing_util.test_ppca(
+                            t_mu=t_mu,
+                            t_cov=t_cov,
+                            t_w=t_w,
+                            t_missing=t_missing,
+                            em_converged_atol=1e-1,
+                            em_iter=100,
+                            figsize=(3, 2.5),
+                            make_vis=False,
+                            show_vis=False,
+                            normalize=False,
+                            cache_local=t_channels_strategy.endswith("core"),
+                        )
 
-                    mumse = np.square(res["muerr"]).mean()
-                    mugood = mumse < mu_atol
-                    assert mugood
-                    Wgood = True
-                    wmse = 0
-                    if "W" in res:
-                        W = res["W"]
-                        rank, nc, M = W.shape
-                        WTW = W.reshape(rank * nc, M)
-                        wmse = np.square(res["Werr"]).mean()
-                        Wgood = wmse / np.square(WTW).mean() < wtw_rtol
-                    assert Wgood
+                        mumse = np.square(res["muerr"]).mean()
+                        mugood = mumse < mu_atol
+                        assert mugood
+                        Wgood = True
+                        wmse = 0
+                        if "W" in res:
+                            W = res["W"]
+                            rank, nc, M = W.shape
+                            WTW = W.reshape(rank * nc, M)
+                            wmse = np.square(res["Werr"]).mean()
+                            Wgood = wmse / np.square(WTW).mean() < wtw_rtol
+                        assert Wgood
 
+                        # print(f"{mumse=} {wmse=}")
+                        # if not (mugood and Wgood):
+                        #     print(f"{mugood=} {Wgood=}")
+                        #     plt.show()
+                        #     plt.close(res['panel'])
+                        #     assert False
+                        # plt.close(res['panel'])
 
 
 def test_mixture():
     for t_mu in t_mu_test:
         for t_cov in t_cov_test:
             for t_w in t_w_test:
                 # for t_w in ("zero", "random"):
                 for t_missing in t_missing_test:
-                    print(f"{t_mu=} {t_cov=} {t_w=} {t_missing=}")
-                    kw = dict(
-                        t_mu=t_mu,
-                        t_cov=t_cov,
-                        t_w=t_w,
-                        t_missing=t_missing,
-                        em_converged_atol=1e-3,
-                        inner_em_iter=100,
-                        figsize=(3, 2.5),
-                        make_vis=False,
-                        with_noise_unit=False,
-                        snr=10.0,
-                    )
-                    res = mixture_testing_util.test_moppcas(
-                        **kw, return_before_fit=False
-                    )
+                    for t_channels_strategy in t_channels_strategy_test:
+                        print(f"{t_mu=} {t_cov=} {t_w=} {t_missing=}")
+                        kw = dict(
+                            t_mu=t_mu,
+                            t_cov=t_cov,
+                            t_w=t_w,
+                            t_missing=t_missing,
+                            em_converged_atol=1e-3,
+                            inner_em_iter=100,
+                            figsize=(3, 2.5),
+                            make_vis=False,
+                            with_noise_unit=False,
+                            channels_strategy=t_channels_strategy,
+                            snr=10.0,
+                        )
+                        res = mixture_testing_util.test_moppcas(
+                            **kw, return_before_fit=False
+                        )
 
-                    mugood = np.square(res["muerrs"]).mean() < mu_atol
-                    assert mugood
-                    agood = res["acc"] >= 1.0
-                    assert agood
+                        sf = res["sim_res"]["data"]
+                        train = sf.split_indices["train"]
+                        corechans1 = sf.core_channels[train]
+                        assert torch.vmap(torch.isin)(
+                            corechans1, sf._train_extract_channels
+                        ).all()
+                        _, coretrainneighb = sf.neighborhoods()
+                        corechans2 = coretrainneighb.neighborhoods[
+                            coretrainneighb.neighborhood_ids
+                        ]
+                        assert torch.equal(corechans1, corechans2)
 
-                    Wgood = True
-                    if "W" in res:
-                        W = res["W"]
-                        k, rank, nc, M = W.shape
-                        mss = 0.0
-                        for ww in W:
-                            WTW = ww.reshape(rank * nc, M)
-                            mss = max(mss, np.square(WTW).mean())
-                        Wgood = np.square(res["Werrs"]).mean() / mss < wtw_rtol
-                    assert Wgood
+                        mugood = np.square(res["muerrs"]).mean() < mu_atol
+                        assert mugood
+                        agood = res["acc"] >= 1.0
+                        assert agood
+
+                        Wgood = True
+                        if "W" in res:
+                            W = res["W"]
+                            k, rank, nc, M = W.shape
+                            mss = 0.0
+                            for ww in W:
+                                WTW = ww.reshape(rank * nc, M)
+                                mss = max(mss, np.square(WTW).mean())
+                            Wgood = np.square(res["Werrs"]).mean() / mss < wtw_rtol
+                        assert Wgood
+                        # if not (mugood and Wgood and agood):
+                        #     print(f"rerun. {mugood=} {Wgood=} {agood=}")
+                        #     mixture_testing_util.test_moppcas(**kw)
 
-                    # assert False
+                        # assert False
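
The new assertions in test_mixture() check a per-spike subset invariant: each
spike's core channels must lie inside that same spike's extract channels. The
test maps torch.isin over the batch dimension with torch.vmap rather than
testing membership against the flattened set. The same pattern on toy tensors:

    import torch

    # row i of `core` must be a subset of row i of `extract`
    core = torch.tensor([[1, 2], [5, 6]])
    extract = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
    assert torch.vmap(torch.isin)(core, extract).all()
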