Skip to content

Commit

Permalink
Merge branch 'main' of github.com:cwindolf/dartsort
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Dec 4, 2024
2 parents fb9244c + eb6988e commit ab00f47
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(
ppca_inner_em_iter=ppca_inner_em_iter,
ppca_atol=ppca_atol,
)
self.split_unit_args = self.unit_args | dict(cov_kind="zero", ppca_rank=0)

# clustering with noise unit to hopefully grab false positives
self.with_noise_unit = with_noise_unit
Expand Down Expand Up @@ -1181,7 +1182,7 @@ def mini_merge(
features,
weights=w,
neighborhoods=self.data.extract_neighborhoods,
**self.unit_args,
**self.split_unit_args,
)
units.append(unit)

Expand Down Expand Up @@ -2222,14 +2223,16 @@ def kl_divergence(self, other_means, other_covs, other_logdets):
# other covs
k = my_cov.shape[0]
tr = k
ld = 0.0
if self.cov_kind == "ppca" and self.ppca_rank:
oW = other_covs.reshape(n, k, self.ppca_rank)
solve = my_cov.solve(oW)
ncov = self.noise.marginal_covariance().to_dense()
solve = solve @ oW.mT + my_cov.solve(ncov)
ncov = self.noise.full_dense_cov()
solve = solve @ oW.mT
tr = solve.diagonal(dim1=-2, dim2=-1).sum(dim=1)
ld = self_logdet - other_logdets
return 0.5 * (tr + inv_quad - k + ld)
tr += torch.trace(my_cov.solve(ncov))
ld = self_logdet - other_logdets
return 0.5 * (inv_quad + ((tr - k) + ld))


# -- utilities
Expand Down
6 changes: 6 additions & 0 deletions src/dartsort/util/noise_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,11 @@ def whiten(self, data, channels=slice(None)):
assert res.ndim == 3 and res.shape == (*data.shape, 1)
return res

def full_dense_cov(self):
if self._full_cov is None:
self.marginal_covariance()
return self._full_cov_dense

def marginal_covariance(
self, channels=slice(None), cache_prefix=None, cache_key=None, device=None
):
Expand All @@ -408,6 +413,7 @@ def marginal_covariance(
if channels == slice(None):
if self._full_cov is None:
fcov = self._marginal_covariance()
self._full_cov_dense = fcov.to_dense()
fcov = operators.CholLinearOperator(fcov.cholesky())
self._full_cov = fcov
self._logdet = self._full_cov.logdet()
Expand Down

0 comments on commit ab00f47

Please sign in to comment.