Skip to content

Commit

Permalink
Merge pull request #408 from rsagroup/bug_fix_concat
Browse files Browse the repository at this point in the history
allow passing a target descriptor for ordering in concatenation.
  • Loading branch information
HeikoSchuett authored Aug 1, 2024
2 parents 08143ff + 967be71 commit ec73820
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
4 changes: 4 additions & 0 deletions src/rsatoolbox/rdm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def compare(rdm1, rdm2, method='cosine', sigma_k=None):
'neg_riem_dist' = negative riemannian distance
'bures' = bures similarity of equivalend cented kernel matrices
'bures_metric' = distances based on bures similarity, which is a metric
sigma_k (numpy.ndarray):
covariance matrix of the pattern estimates.
Used only for methods 'corr_cov' and 'cosine_cov'.
Expand Down
35 changes: 23 additions & 12 deletions src/rsatoolbox/rdm/rdms.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def load_rdm(filename, file_type=None):
return rdms_from_dict(rdm_dict)


def concat(*rdms: RDMs) -> RDMs:
def concat(*rdms: RDMs, target_pdesc: Optional[str] = None) -> RDMs:
"""Merge into single RDMs object
requires that the rdms have the same shape
descriptor and pattern descriptors are taken from the first rdms object
Expand All @@ -552,6 +552,7 @@ def concat(*rdms: RDMs) -> RDMs:
Args:
rdms(iterable of rsatoolbox.rdm.RDMs): RDMs objects to be concatenated
or multiple RDMs as separate arguments
target_pdesc(optional, str): a pattern descriptor to use for sorting
Returns:
rsatoolbox.rdm.RDMs: concatenated rdms object
Expand All @@ -569,27 +570,37 @@ def concat(*rdms: RDMs) -> RDMs:

descriptors, rdm_descriptors = _merged_rdm_descriptors(rdms_list)

## see if we can find an authoritative descriptor for pattern order
pdescs = rdms_list[0].pattern_descriptors.keys()
pdesc_candidates = list(filter(lambda n: n!='index', pdescs))
target_pdesc = None
if len(pdesc_candidates) > 0:
target_pdesc = pdesc_candidates[0]
if len(pdesc_candidates) > 1:
warnings.warn(f'[concat] Multiple pattern descriptors found, using "{target_pdesc}"')
if target_pdesc is None:
# see if we can find an authoritative descriptor for pattern order
pdescs = rdms_list[0].pattern_descriptors.keys()
pdesc_candidates = list(filter(
lambda n: n != 'index' and (
len(rdms_list[0].pattern_descriptors[n])
== len(set(rdms_list[0].pattern_descriptors[n]))),
pdescs))
target_pdesc = None
if len(pdesc_candidates) > 0:
target_pdesc = pdesc_candidates[0]
if len(pdesc_candidates) > 1:
warnings.warn(f'[concat] Multiple pattern descriptors found, using "{target_pdesc}"')
else:
assert target_pdesc in rdms_list[0].pattern_descriptors.keys(), \
'The provided descriptor is not a pattern descriptor'
assert len(rdms_list[0].pattern_descriptors[target_pdesc]) == rdms_list[0].n_cond, \
'The provided descriptor is not unique'

for rdm_new in rdms_list[1:]:
assert isinstance(rdm_new, RDMs), 'rdm for concat should be an RDMs'
assert rdm_new.n_cond == rdms_list[0].n_cond, 'rdm for concat had wrong shape'
assert rdm_new.dissimilarity_measure == rdms_list[0].dissimilarity_measure, \
'appended rdm had wrong dissimilarity measure'
if target_pdesc:
## if we have a target descriptor, check if the order is the same
# if we have a target descriptor, check if the order is the same
auth_order = rdms_list[0].pattern_descriptors[target_pdesc]
other_order = rdm_new.pattern_descriptors[target_pdesc]
if not np.all(other_order == auth_order):
## order varies; reorder this rdms object
_, new_order = np.where(auth_order[:,None] == other_order)
# order varies; reorder this rdms object
_, new_order = np.where(auth_order[:, None] == other_order)
rdm_new.reorder(new_order)

dissimilarities = np.concatenate([
Expand Down

0 comments on commit ec73820

Please sign in to comment.