Skip to content

Commit

Permalink
Split defaults; device stuff
Browse files: browse the repository at this point in the history
  • Loading branch information
cwindolf committed Dec 8, 2024
1 parent 9b849cc commit 0ae17d4
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ def __init__(
kmeans_drop_prop: float = 0.025,
kmeans_with_proportions: bool = False,
kmeans_kmeanspp_initial: str = "mean",
- split_em_iter: int = 1,
+ split_em_iter: int = 0,
split_whiten: bool = True,
+ ppca_in_split: bool = False,
distance_metric: Literal["noise_metric", "kl", "reverse_kl", "js"] = "js",
distance_normalization_kind: Literal["none", "noise", "channels"] = "channels",
criterion_normalization_kind: Literal["none", "noise", "channels"] = "none",
Expand Down Expand Up @@ -147,7 +148,10 @@ def __init__(
ppca_inner_em_iter=ppca_inner_em_iter,
ppca_atol=ppca_atol,
)
- self.split_unit_args = self.unit_args | dict(cov_kind="zero", ppca_rank=0)
+ if ppca_in_split:
+     self.split_unit_args = self.unit_args
+ else:
+     self.split_unit_args = self.unit_args | dict(cov_kind="zero", ppca_rank=0)

# clustering with noise unit to hopefully grab false positives
self.with_noise_unit = with_noise_unit
Expand Down Expand Up @@ -798,7 +802,7 @@ def dist_job(j, unit):
denom = self.noise_unit.divergence(
means, other_covs=covs, other_logdets=logdets, kind=kind
)
- denom = denom.sqrt()
+ denom = denom.sqrt_().numpy(force=True)
dists[:, ids] /= denom[None, :]
dists[ids, :] /= denom[:, None]
elif normalization_kind == "channels":
Expand Down Expand Up @@ -1107,6 +1111,7 @@ def kmeans_split_unit(self, unit_id, debug=False):
if debug:
result["split_labels"] = split_labels
result["responsibilities"] = responsibilities
+ split_labels = split_labels.cpu()
split_ids, split_labels = split_labels.unique(return_inverse=True)
assert split_ids.min() >= 0
if split_labels.unique().numel() <= 1:
Expand Down Expand Up @@ -1241,6 +1246,7 @@ def mini_merge(
lls[j] = lls_
best_liks, labels = lls.max(dim=0)
labels[torch.isinf(best_liks)] = -1
+ labels = labels.cpu()
weights = F.softmax(lls, dim=0)

labels = labels.numpy(force=True)
Expand Down

0 comments on commit 0ae17d4

Please sign in to comment.