diff --git a/bids/modeling/statsmodels.py b/bids/modeling/statsmodels.py index 279d5b48a..b08e8224e 100644 --- a/bids/modeling/statsmodels.py +++ b/bids/modeling/statsmodels.py @@ -4,6 +4,7 @@ from collections import namedtuple, OrderedDict, Counter, defaultdict import itertools from functools import reduce +from multiprocessing.sharedctypes import Value import re import fnmatch @@ -263,11 +264,11 @@ class BIDSStatsModelsNode: overridden if one is passed when run() is called on a node. """ - def __init__(self, level, name, transformations=None, model=None, - contrasts=None, dummy_contrasts=False, group_by=None): + def __init__(self, level, name, model, group_by, transformations=None, + contrasts=None, dummy_contrasts=False): self.level = level.lower() self.name = name - self.model = model or {} + self.model = model if transformations is None: transformations = {"transformer": "pybids-transforms-v1", "instructions": []} @@ -279,13 +280,7 @@ def __init__(self, level, name, transformations=None, model=None, self.children = [] self.parents = [] if group_by is None: - group_by = [] - # Loop over contrasts after first level - if self.level != "run": - group_by.append("contrast") - # Loop over node level of this node - if self.level != "dataset": - group_by.append(self.level) + raise ValueError(f"group_by is not defined for Node: {name}") self.group_by = group_by # Check for intercept only run level model and throw an error @@ -592,25 +587,18 @@ def __init__(self, node, entities={}, collections=None, inputs=None, var_names = list(self.node.model['x']) - # Handle the special 1 construct. If it's present, we add a - # column of 1's to the design matrix. But behavior varies: - # * If there's only a single contrast across all of the inputs, - # the intercept column is given the same name as the input contrast. - # It may already exist, in which case we do nothing. - # * Otherwise, we name the column 'intercept'. - int_name = None + # Handle the special 1 construct. 
+ # Add column of 1's to the design matrix called "intercept" if 1 in var_names: - if ('contrast' not in df.columns or df['contrast'].nunique() > 1): - int_name = 'intercept' - else: - int_name = df['contrast'].unique()[0] - - var_names.remove(1) + var_names = ['intercept' if i == 1 else i for i in var_names] + if 'intercept' not in df.columns: + df.insert(0, 'intercept', 1) - if int_name not in df.columns: - df.insert(0, int_name, 1) - else: - var_names.append(int_name) + # If a single incoming contrast + if ('contrast' in df.columns and df['contrast'].nunique() == 1): + unique_in_contrast = df['contrast'].unique()[0] + else: + unique_in_contrast = None var_names = expand_wildcards(var_names, df.columns) @@ -626,7 +614,7 @@ def __init__(self, node, entities={}, collections=None, inputs=None, # Create ModelSpec and build contrasts self.model_spec = create_model_spec(self.data, node.model, self.metadata) - self.contrasts = self._build_contrasts(int_name) + self.contrasts = self._build_contrasts(unique_in_contrast) def _collections_to_dfs(self, collections): """Merges collections and converts them to a pandas DataFrame.""" @@ -690,17 +678,55 @@ def _inputs_to_df(self, inputs): input_df.loc[input_df.index[i], con.name] = 1 return input_df - def _build_contrasts(self, int_name): - """Contrast list of ContrastInfo objects based on current state.""" - contrasts = {} + def _build_contrasts(self, unique_in_contrast=None): + """Contrast list of ContrastInfo objects based on current state. + + Parameters + ---------- + unique_in_contrast : string + Name of unique incoming contrast inputs (i.e. 
if there is only 1) + """ + in_contrasts = self.node.contrasts.copy() col_names = set(self.X.columns) - for con in self.node.contrasts: - name = con["name"] + + # Create dummy contrasts as regular contrasts + dummies = self.node.dummy_contrasts + if dummies: + if 'condition_list' in dummies: + conditions = set(dummies['condition_list']) + else: + conditions = col_names + + for col_name in conditions: + in_contrasts.insert(0, + { + 'name': col_name, + 'condition_list': [col_name], + 'weights': [1], + 'test': dummies.get('test') + } + ) + + # Process all contrasts, starting with dummy contrasts + # Dummy contrasts are replaced if a contrast is defined with same name + contrasts = {} + for con in in_contrasts: condition_list = list(con["condition_list"]) - if 1 in condition_list and int_name is not None: - condition_list[condition_list.index(1)] = int_name - if name == 1 and int_name is not None: - name = int_name + + # Rename special 1 construct + condition_list = ['intercept' if i == 1 else i for i in condition_list] + + name = con["name"] + + # Rename contrast name + if name == 1: + name = unique_in_contrast or 'intercept' + else: + # Node has a single incoming contrast (is grouped by contrast): + # prefix the contrast name with the incoming contrast name + if unique_in_contrast: + name = f"{unique_in_contrast}_{name}" + missing_vars = set(condition_list) - col_names if missing_vars: if self.invalid_contrasts == 'error': @@ -711,31 +737,13 @@ def _build_contrasts(self, int_name): elif self.invalid_contrasts == 'drop': continue weights = np.atleast_2d(con['weights']) + # Add contrast name to entities; can be used in grouping downstream entities = {**self.entities, 'contrast': name} ci = ContrastInfo(name, condition_list, con['weights'], con.get("test"), entities) contrasts[name] = ci - dummies = self.node.dummy_contrasts - if dummies: - conditions = col_names - if 'conditions' in dummies: - conds = set(dummies['conditions']) - if 1 in conds and int_name is not None: -
conds.discard(1) - conds.add(int_name) - conditions &= conds - conditions -= set(c.name for c in contrasts.values()) - - for col_name in conditions: - if col_name in contrasts: - continue - entities = {**self.entities, 'contrast': col_name} - ci = ContrastInfo(col_name, [col_name], [1], dummies.get("test"), - entities) - contrasts[col_name] = ci - return list(contrasts.values()) @property diff --git a/bids/modeling/tests/test_statsmodels.py b/bids/modeling/tests/test_statsmodels.py index 97ec62631..722ae6722 100644 --- a/bids/modeling/tests/test_statsmodels.py +++ b/bids/modeling/tests/test_statsmodels.py @@ -113,10 +113,22 @@ def test_entire_graph_smoketest(graph): cis = list(chain(*[op.contrasts for op in outputs])) assert len(cis) == 18 outputs = graph["participant"].run(cis, group_by=['subject', 'contrast']) - # 2 subjects x 3 contrasts + # (2 subjects x 3 contrasts) assert len(outputs) == 6 + # * 2 participant level contrasts = 12 cis = list(chain(*[op.contrasts for op in outputs])) - assert len(cis) == 6 + assert len(cis) == 12 + + # Test output names for single subject + out_contrasts = [ + c.entities['contrast'] for c in cis if c.entities['subject'] == '01' + ] + + expected_outs = [ + 'gain', 'gain_neg', 'RT', 'RT_neg', 'RT:gain', 'RT:gain_neg' + ] + + assert set(out_contrasts) == set(expected_outs) # Construct new ContrastInfo objects with name updated to reflect last # contrast.
This would normally be done by the handling tool (e.g., fitlins) @@ -140,7 +152,7 @@ def test_entire_graph_smoketest(graph): assert model_spec.X.shape == (2, 2) assert model_spec.Z is None assert len(model_spec.terms) == 2 - assert not set(model_spec.terms.keys()) - {"RT", "gain", "RT:gain", "sex"} + assert not set(model_spec.terms.keys()) - {"intercept", "sex"} # BY-GROUP NODE outputs = graph["by-group"].run(inputs, group_by=['contrast']) @@ -153,7 +165,7 @@ def test_entire_graph_smoketest(graph): assert model_spec.__class__.__name__ == "GLMMSpec" assert model_spec.X.shape == (2, 1) assert model_spec.Z is None - assert not set(model_spec.terms.keys()) - {"RT", "gain", "RT:gain"} + assert not set(model_spec.terms.keys()) - {"intercept"} def test_expand_wildcards(): diff --git a/bids/tests/data/ds005/models/ds-005_type-test_model.json b/bids/tests/data/ds005/models/ds-005_type-test_model.json index e234f2c6b..1c369fe82 100644 --- a/bids/tests/data/ds005/models/ds-005_type-test_model.json +++ b/bids/tests/data/ds005/models/ds-005_type-test_model.json @@ -25,10 +25,6 @@ "Name": "Rename", "Input": "trial_type.parametric gain", "Output": "gain" - }, - { - "Name": "Scale", - "Input": "RT" } ], "DummyContrasts": { @@ -53,6 +49,12 @@ "ConditionList": [1], "Weights": [1], "Test": "FEMA" + }, + { + "Name": "neg", + "ConditionList": [1], + "Weights": [-1], + "Test": "FEMA" } ] }, @@ -60,7 +62,8 @@ "Name": "by-group", "Level": "Dataset", "GroupBy": [ - "sex" + "sex", + "contrast" ], "Model": { "X": [ @@ -75,7 +78,7 @@ { "Name": "group-diff", "Level": "Dataset", - "GroupBy": [], + "GroupBy": ["contrast"], "Model": { "X": [ 1,