Skip to content

Commit

Permalink
Merge pull request #400 from rsagroup/bures
Browse files Browse the repository at this point in the history
Bures similarity and distance
  • Loading branch information
JasperVanDenBosch authored Jul 8, 2024
2 parents be201fc + 5d1fdf8 commit 4331bf5
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 8 deletions.
20 changes: 20 additions & 0 deletions docs/source/comparing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,24 @@ Thus, we generally recommend using this :math:`\rho_a` measure now.

This comparison measure can be accessed using ``method='rho-a'`` or using ``rsatoolbox.rdm.compare_rho_a``.

Bures's similarity
--------------
These are a related similarity measure and distance introduced by harvey_2024_ , based on double centered kernel matrices :math:`K_1` and :math:`K_2`.
The normalized Bures similarity (NBS) is defined as:

.. math::
NBS(K_1, K_2) = \frac{\mathcal{F}(K_1, K_2)}{\sqrt{\operatorname{Tr}[K_1] \operatorname{Tr}[K_2]}}
\mathcal{F}(K_1, K_2) = \operatorname{Tr}[(K_1^{1/2}K_2K_1^{1/2})^{1/2}]
and :math:`\mathcal{F}` is known as the fidelity.

and relatedly the Bures distance :math:`\mathcal{B}`, a proper metric is defined as:

.. math::
\mathcal{B}^2(K_1, K_2) = \operatorname{Tr}[K_1] \operatorname{Tr}[K_2] - 2 \operatorname{Tr}[(K_1^{1/2}K_2K_1^{1/2})^{1/2}]
.. _Diedrichsen_2021: https://arxiv.org/abs/2007.02789
.. _harvey_2024: https://proceedings.mlr.press/v243/harvey24a
106 changes: 100 additions & 6 deletions src/rsatoolbox/rdm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rsatoolbox.util.rdm_utils import _get_n_from_reduced_vectors
from rsatoolbox.util.rdm_utils import _get_n_from_length
from rsatoolbox.util.matrix import row_col_indicator_g
from rsatoolbox.util.rdm_utils import batch_to_matrices


def compare(rdm1, rdm2, method='cosine', sigma_k=None):
Expand Down Expand Up @@ -74,6 +75,10 @@ def compare(rdm1, rdm2, method='cosine', sigma_k=None):
sim = compare_cosine_cov_weighted(rdm1, rdm2, sigma_k=sigma_k)
elif method == 'neg_riem_dist':
sim = compare_neg_riemannian_distance(rdm1, rdm2, sigma_k=sigma_k)
elif method == 'bures':
sim = compare_bures_similarity(rdm1, rdm2)
elif method == 'bures_metric':
sim = compare_bures_metric(rdm1, rdm2)
else:
raise ValueError('Unknown RDM comparison method requested!')
return sim
Expand Down Expand Up @@ -274,6 +279,52 @@ def compare_neg_riemannian_distance(rdm1, rdm2, sigma_k=None):
return sim


def compare_bures_similarity(rdm1, rdm2):
"""calculates the Bures similarity between two RDMs objects.
Args:
rdm1 (rsatoolbox.rdm.RDMs):
first set of RDMs
rdm2 (rsatoolbox.rdm.RDMs):
second set of RDMs
Returns:
numpy.ndarray: dist:
Bures similarity between the two RDMs
"""
vector1, vector2, _ = _parse_input_rdms(rdm1, rdm2)
G1, _, _ = batch_to_matrices(-vector1 / 2)
G2, _, _ = batch_to_matrices(-vector2 / 2)
s1 = np.mean(G1, 1, keepdims=True)
G1 = G1 - s1 - np.transpose(s1, (0, 2, 1)) + np.mean(s1, 2, keepdims=True)
s2 = np.mean(G2, 1, keepdims=True)
G2 = G2 - s2 - np.transpose(s2, (0, 2, 1)) + np.mean(s2, 2, keepdims=True)
sim = _all_combinations(G1, G2, _bures_similarity_first_way)
return sim


def compare_bures_metric(rdm1, rdm2):
"""calculates the squared Bures metric between two RDMs objects.
Args:
rdm1 (rsatoolbox.rdm.RDMs):
first set of RDMs
rdm2 (rsatoolbox.rdm.RDMs):
second set of RDMs
Returns:
numpy.ndarray: dist:
squared Bures metric between the two RDMs
"""
vector1, vector2, _ = _parse_input_rdms(rdm1, rdm2)
G1, _, _ = batch_to_matrices(-vector1 / 2)
G2, _, _ = batch_to_matrices(-vector2 / 2)
s1 = np.mean(G1, 1, keepdims=True)
G1 = G1 - s1 - np.transpose(s1, (0, 2, 1)) + np.mean(s1, 2, keepdims=True)
s2 = np.mean(G2, 1, keepdims=True)
G2 = G2 - s2 - np.transpose(s2, (0, 2, 1)) + np.mean(s2, 2, keepdims=True)
sim = _all_combinations(G1, G2, _sq_bures_metric_first_way)
return sim


def _all_combinations(vectors1, vectors2, func, *args, **kwargs):
"""runs a function func on all combinations of v1 in vectors1
and v2 in vectors2 and puts the results into an array
Expand All @@ -291,13 +342,9 @@ def _all_combinations(vectors1, vectors2, func, *args, **kwargs):
"""
value = np.empty((len(vectors1), len(vectors2)))
k1 = 0
for v1 in vectors1:
k2 = 0
for v2 in vectors2:
for k1, v1 in enumerate(vectors1):
for k2, v2 in enumerate(vectors2):
value[k1, k2] = func(v1, v2, *args, **kwargs)
k2 += 1
k1 += 1
return value


Expand Down Expand Up @@ -622,3 +669,50 @@ def _parse_input_rdms(rdm1, rdm2):
if not vector1_no_nan.shape[1] == vector2_no_nan.shape[1]:
raise ValueError('rdm1 and rdm2 have different nan positions')
return vector1_no_nan, vector2_no_nan, nan_idx[0]


def _sq_bures_metric_first_way(A, B):
va, ua = np.linalg.eigh(A)
Asq = ua @ (np.sqrt(np.maximum(va[:, None], 0.0)) * ua.T)
return (
np.trace(A) + np.trace(B)
- 2 * np.sum(np.sqrt(np.maximum(0.0, np.linalg.eigvalsh(Asq @ B @ Asq))))
)


def _sq_bures_metric_second_way(A, B):
va, ua = np.linalg.eigh(A)
vb, ub = np.linalg.eigh(B)
sva = np.sqrt(np.maximum(va, 0.0))
svb = np.sqrt(np.maximum(vb, 0.0))
return (
np.sum(va) + np.sum(vb) - 2 * np.sum(
np.linalg.svd(
(sva[:, None] * ua.T) @ (ub * svb[None, :]),
compute_uv=False
)
)
)


def _bures_similarity_first_way(A, B):
va, ua = np.linalg.eigh(A)
Asq = ua @ (np.sqrt(np.maximum(va[:, None], 0.0)) * ua.T)
num = np.sum(np.sqrt(np.maximum(np.linalg.eigvalsh(Asq @ B @ Asq), 0.0)))
denom = np.sqrt(np.trace(A) * np.trace(B))
return num / denom


def _bures_similarity_second_way(A, B):
va, ua = np.linalg.eigh(A)
vb, ub = np.linalg.eigh(B)
sva = np.sqrt(np.maximum(va, 0.0))
svb = np.sqrt(np.maximum(vb, 0.0))
num = np.sum(
np.linalg.svd(
(sva[:, None] * ua.T) @ (ub * svb[None, :]),
compute_uv=False
)
)
denom = np.sqrt(np.sum(va) * np.sum(vb))
return num / denom
55 changes: 53 additions & 2 deletions tests/test_compare_rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,40 @@
import unittest
import numpy as np
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_almost_equal
import rsatoolbox as rsa


class TestCompareRDM(unittest.TestCase):

def setUp(self):
self.rng = np.random.default_rng(0)
dissimilarities1 = self.rng.random((1, 15))
x = self.rng.random((20, 6))
x -= np.mean(x, 1, keepdims=True)
self.k1 = x.T @ x
diag = np.diag(self.k1)
dist = (
np.expand_dims(diag, 0)
+ np.expand_dims(diag, 1)
- 2 * self.k1)
dissimilarities1 = dist[np.triu_indices(6, 1)]
des1 = {'session': 0, 'subj': 0}
self.test_rdm1 = rsa.rdm.RDMs(
dissimilarities=dissimilarities1,
dissimilarity_measure='test',
descriptors=des1)
dissimilarities2 = self.rng.random((3, 15))
x = self.rng.random((3, 20, 6))
x -= np.mean(x, 2, keepdims=True)
self.k2 = np.zeros((3, 6, 6))
dissimilarities2 = np.zeros((3, 15))
for i in range(3):
self.k2[i] = x[i].T @ x[i]
diag = np.diag(self.k2[i])
dist = (
np.expand_dims(diag, 0)
+ np.expand_dims(diag, 1)
- 2 * self.k2[i])
dissimilarities2[i] = dist[np.triu_indices(6, 1)]
des2 = {'session': 0, 'subj': 0}
self.test_rdm2 = rsa.rdm.RDMs(
dissimilarities=dissimilarities2,
Expand Down Expand Up @@ -196,6 +216,35 @@ def test_compare_kendall_tau_a(self):
result = compare_kendall_tau_a(self.test_rdm1, self.test_rdm2)
assert np.all(result < 1)

def test_compare_bures_similarity(self):
from rsatoolbox.rdm.compare import compare_bures_similarity
result = compare_bures_similarity(self.test_rdm1, self.test_rdm1)
assert_array_almost_equal(result, 1)
result = compare_bures_similarity(self.test_rdm1, self.test_rdm2)
assert np.all(result < 1)
# check that Kernel transform is ok
from rsatoolbox.rdm.compare import _bures_similarity_first_way
from rsatoolbox.rdm.compare import _bures_similarity_second_way
d_right1 = _bures_similarity_first_way(self.k1, self.k2[0])
d_right2 = _bures_similarity_second_way(self.k1, self.k2[0])
assert_almost_equal(d_right1, d_right2)
assert_almost_equal(d_right1, result[0, 0])
assert_almost_equal(d_right2, result[0, 0])

def test_compare_bures_metric(self):
from rsatoolbox.rdm.compare import compare_bures_metric
result = compare_bures_metric(self.test_rdm1, self.test_rdm1)
assert_array_almost_equal(result, 0)
result = compare_bures_metric(self.test_rdm1, self.test_rdm2)
# check that Kernel transform is ok
from rsatoolbox.rdm.compare import _sq_bures_metric_first_way
from rsatoolbox.rdm.compare import _sq_bures_metric_second_way
d_right1 = _sq_bures_metric_first_way(self.k1, self.k2[0])
d_right2 = _sq_bures_metric_second_way(self.k1, self.k2[0])
assert_almost_equal(d_right1, d_right2)
assert_almost_equal(d_right1, result[0, 0])
assert_almost_equal(d_right2, result[0, 0])

def test_compare(self):
from rsatoolbox.rdm.compare import compare
result = compare(self.test_rdm1, self.test_rdm1)
Expand All @@ -206,6 +255,8 @@ def test_compare(self):
result = compare(self.test_rdm1, self.test_rdm2, method='cosine')
result = compare(self.test_rdm1, self.test_rdm2, method='cosine_cov')
result = compare(self.test_rdm1, self.test_rdm2, method='kendall')
result = compare(self.test_rdm1, self.test_rdm2, method='bures')
result = compare(self.test_rdm1, self.test_rdm2, method='bures_metric')


class TestCompareRDMNaN(unittest.TestCase):
Expand Down

0 comments on commit 4331bf5

Please sign in to comment.