From 3041f0d8856e0d228a234a3bebc4f74d77a7d464 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Heiko=20Sch=C3=BCtt?= Date: Mon, 29 Jul 2024 12:09:44 +0200 Subject: [PATCH 1/2] allow passing a target descriptor for ordering in concatenation. Also check uniqueness --- src/rsatoolbox/rdm/rdms.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/rsatoolbox/rdm/rdms.py b/src/rsatoolbox/rdm/rdms.py index 447ed480..5a379e1a 100644 --- a/src/rsatoolbox/rdm/rdms.py +++ b/src/rsatoolbox/rdm/rdms.py @@ -542,7 +542,7 @@ def load_rdm(filename, file_type=None): return rdms_from_dict(rdm_dict) -def concat(*rdms: RDMs) -> RDMs: +def concat(*rdms: RDMs, target_pdesc: Optional[str] = None) -> RDMs: """Merge into single RDMs object requires that the rdms have the same shape descriptor and pattern descriptors are taken from the first rdms object @@ -552,6 +552,7 @@ def concat(*rdms: RDMs) -> RDMs: Args: rdms(iterable of rsatoolbox.rdm.RDMs): RDMs objects to be concatenated or multiple RDMs as separate arguments + target_pdesc(optional, str): a pattern descriptor to use for sorting Returns: rsatoolbox.rdm.RDMs: concatenated rdms object @@ -569,14 +570,24 @@ def concat(*rdms: RDMs) -> RDMs: descriptors, rdm_descriptors = _merged_rdm_descriptors(rdms_list) - ## see if we can find an authoritative descriptor for pattern order - pdescs = rdms_list[0].pattern_descriptors.keys() - pdesc_candidates = list(filter(lambda n: n!='index', pdescs)) - target_pdesc = None - if len(pdesc_candidates) > 0: - target_pdesc = pdesc_candidates[0] - if len(pdesc_candidates) > 1: - warnings.warn(f'[concat] Multiple pattern descriptors found, using "{target_pdesc}"') + if target_pdesc is None: + # see if we can find an authoritative descriptor for pattern order + pdescs = rdms_list[0].pattern_descriptors.keys() + pdesc_candidates = list(filter( + lambda n: n != 'index' and ( + len(rdms_list[0].pattern_descriptors[n]) + == len(set(rdms_list[0].pattern_descriptors[n]))), + pdescs)) + target_pdesc = None + if len(pdesc_candidates) > 0: + target_pdesc = pdesc_candidates[0] + if len(pdesc_candidates) > 1: + warnings.warn(f'[concat] Multiple pattern descriptors found, using "{target_pdesc}"') + else: + assert target_pdesc in rdms_list[0].pattern_descriptors.keys(), \ + 'The provided descriptor is not a pattern descriptor' + assert len(rdms_list[0].pattern_descriptors[target_pdesc]) == rdms_list[0].n_cond, \ + 'The provided descriptor is not unique' for rdm_new in rdms_list[1:]: assert isinstance(rdm_new, RDMs), 'rdm for concat should be an RDMs' @@ -584,12 +595,12 @@ def concat(*rdms: RDMs) -> RDMs: assert rdm_new.dissimilarity_measure == rdms_list[0].dissimilarity_measure, \ 'appended rdm had wrong dissimilarity measure' if target_pdesc: - ## if we have a target descriptor, check if the order is the same + # if we have a target descriptor, check if the order is the same auth_order = rdms_list[0].pattern_descriptors[target_pdesc] other_order = rdm_new.pattern_descriptors[target_pdesc] if not np.all(other_order == auth_order): - ## order varies; reorder this rdms object - _, new_order = np.where(auth_order[:,None] == other_order) + # order varies; reorder this rdms object + _, new_order = np.where(auth_order[:, None] == other_order) rdm_new.reorder(new_order) dissimilarities = np.concatenate([ From 967be717d004d16304b05d003f22dd64132087cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Heiko=20Sch=C3=BCtt?= Date: Thu, 1 Aug 2024 12:45:03 +0200 Subject: [PATCH 2/2] added bures similarities to compare tooltip --- src/rsatoolbox/rdm/compare.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/rsatoolbox/rdm/compare.py b/src/rsatoolbox/rdm/compare.py index c9f057a4..768c65c1 100644 --- a/src/rsatoolbox/rdm/compare.py +++ b/src/rsatoolbox/rdm/compare.py @@ -48,6 +48,10 @@ def compare(rdm1, rdm2, method='cosine', sigma_k=None): 'neg_riem_dist' = negative riemannian distance + 'bures' = bures similarity of equivalend cented kernel matrices + + 'bures_metric' = distances based on bures similarity, which is a metric + sigma_k (numpy.ndarray): covariance matrix of the pattern estimates. Used only for methods 'corr_cov' and 'cosine_cov'.