Merge pull request #98 from BIMSBbioinfo/gradientshap
Add support for GradientShap Method as an alternative feature attribution method
borauyar authored Jan 25, 2025
2 parents 3222ace + 04e7c53 commit 883430d
Showing 7 changed files with 256 additions and 142 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/benchmarks.yml
@@ -83,6 +83,18 @@ jobs:
conda activate my_env
flexynesis --data_path lgggbm_tcga_pub_processed --model_class DirectPred --target_variables STUDY --fusion_type intermediate --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types mut,cna --outdir . --prefix lgg_surv --early_stop_patience 3 --use_loss_weighting False --surv_event_var OS_STATUS --surv_time_var OS_MONTHS
- name: Run DirectPred_TestCovariates
shell: bash -l {0}
run: |
conda activate my_env
flexynesis --data_path lgggbm_tcga_pub_processed --model_class DirectPred --target_variables STUDY --fusion_type intermediate --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types mut --outdir . --prefix lgg_surv --early_stop_patience 3 --use_loss_weighting False --covariates BCR_STATUS
- name: Run DirectPred_Test_Explainers
shell: bash -l {0}
run: |
conda activate my_env
flexynesis --data_path lgggbm_tcga_pub_processed --model_class DirectPred --target_variables STUDY --fusion_type intermediate --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types mut --outdir . --prefix lgg_surv --early_stop_patience 3 --use_loss_weighting False --feature_importance_method Both
- name: Run supervised_vae
shell: bash -l {0}
run: |
23 changes: 17 additions & 6 deletions flexynesis/__main__.py
Expand Up @@ -47,6 +47,7 @@ def main():
--threads (int): How many threads to use when using CPU. Default is 4.
--num_workers (int): How many workers to use for model training. Default is 0.
--use_gpu (bool): If set, the system will attempt to use CUDA/GPU if available.
--feature_importance_method (str): Which method(s) to use to compute feature importance scores. Options are: IntegratedGradients, GradientShap, or Both. Default: IntegratedGradients.
--disable_marker_finding (bool): If set, marker discovery after model training is disabled.
--string_organism (int): STRING DB organism id. Default is 9606.
--string_node_name (str): Type of node name. Choices are ["gene_name", "gene_id"]. Default is "gene_name".
@@ -103,6 +104,8 @@ def main():
parser.add_argument("--num_workers", help="(Optional) How many workers to use for model training (default is 0)", type=int, default = 0)
parser.add_argument("--use_gpu", action="store_true",
help="(Optional) If set, the system will attempt to use CUDA/GPU if available.")
parser.add_argument("--feature_importance_method", help="Choose feature importance score method", type=str,
choices=["IntegratedGradients", "GradientShap", "Both"], default="IntegratedGradients")
parser.add_argument("--disable_marker_finding", action="store_true",
help="(Optional) If set, marker discovery after model training is disabled.")
# GNN args.
@@ -339,12 +342,20 @@ def main():
if any([args.target_variables, args.surv_event_var]):
if not args.disable_marker_finding: # unless marker discovery is disabled
# compute feature importance values
print("[INFO] Computing variable importance scores")
for var in model.target_variables:
model.compute_feature_importance(train_dataset, var, steps = 25)
df_imp = pd.concat([model.feature_importances[x] for x in model.target_variables],
ignore_index = True)
df_imp.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'feature_importance.csv'])), header=True, index=False)

if args.feature_importance_method == 'Both':
explainers = ['IntegratedGradients', 'GradientShap']
else:
explainers = [args.feature_importance_method]

for explainer in explainers:
print("[INFO] Computing variable importance scores using explainer:",explainer)
for var in model.target_variables:
model.compute_feature_importance(train_dataset, var, steps_or_samples = 25, method=explainer)
df_imp = pd.concat([model.feature_importances[x] for x in model.target_variables],
ignore_index = True)
df_imp['explainer'] = explainer
df_imp.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'feature_importance', explainer, 'csv'])), header=True, index=False)

# print known/predicted labels
predicted_labels = pd.concat([flexynesis.get_predicted_labels(model.predict(train_dataset), train_dataset, 'train'),
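For orientation, the loop above writes one CSV per explainer, named `<prefix>.feature_importance.<explainer>.csv`, and tags each row with an `explainer` column. A minimal sketch of how the two outputs could be compared downstream (the file names, target variable, and top-n cutoff below are illustrative, not part of this PR):

```python
import pandas as pd

# Illustrative file names following the '<prefix>.feature_importance.<explainer>.csv'
# pattern produced above when --feature_importance_method Both is used.
ig = pd.read_csv("lgg_surv.feature_importance.IntegratedGradients.csv")
gs = pd.read_csv("lgg_surv.feature_importance.GradientShap.csv")

def top_features(df, target_var, target_class, n=10):
    """Return the names of the n highest-importance features for one variable/class."""
    sub = df[(df["target_variable"] == target_var) & (df["target_class"] == target_class)]
    return set(sub.nlargest(n, "importance")["name"])

# Agreement between the two explainers on the top-ranked features.
overlap = top_features(ig, "STUDY", 0) & top_features(gs, "STUDY", 0)
print("Features ranked in the top 10 by both explainers:", sorted(overlap))
```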
74 changes: 49 additions & 25 deletions flexynesis/models/crossmodal_pred.py
@@ -10,7 +10,7 @@
import lightning as pl
from scipy import stats

from captum.attr import IntegratedGradients
from captum.attr import IntegratedGradients, GradientShap

from ..modules import *

@@ -482,32 +482,34 @@ def forward_target(self, *args):
outputs_list.append(outputs[target_var])
return torch.cat(outputs_list, dim = 0)

def compute_feature_importance(self, dataset, target_var, steps = 5, batch_size = 64):
def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64):
"""
Computes the feature importance for each variable in the dataset using the Integrated Gradients method.
This method measures the importance of each feature by attributing the prediction output to each input feature.
Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP.
Args:
dataset: The dataset object containing the features and data.
target_var (str): The target variable for which feature importance is calculated.
steps (int, optional): The number of steps to use for integrated gradients approximation. Defaults to 5.
method (str, optional): The attribution method to use ("IntegratedGradients" or "GradientShap").
Defaults to "IntegratedGradients".
steps_or_samples (int, optional): Number of steps for Integrated Gradients or samples for Gradient SHAP.
Defaults to 5.
batch_size (int, optional): The size of the batch to process the dataset. Defaults to 64.
Returns:
pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities.
Columns include 'target_variable', 'target_class', 'target_class_label', 'layer', 'name',
and 'importance'.
This function adjusts the device setting based on the availability of GPUs and performs the computation using
the selected attribution method. It processes batches of data, aggregates results across batches, and formats
the output into a readable DataFrame, which is then stored in the model's attribute for later use or analysis.
"""
device = torch.device("cuda" if self.device_type == 'gpu' and torch.cuda.is_available() else 'cpu')
self.to(device)

print("[INFO] Computing feature importance for variable:",target_var,"on device:",device)
# Initialize the Integrated Gradients method
ig = IntegratedGradients(self.forward_target)

# Choose the attribution method dynamically
if method == "IntegratedGradients":
explainer = IntegratedGradients(self.forward_target)
elif method == "GradientShap":
explainer = GradientShap(self.forward_target)
else:
raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.")

# Get the number of classes for the target variable
if self.variable_types[target_var] == 'numerical':
@@ -523,19 +525,38 @@ def compute_feature_importance(self, dataset, target_var, steps = 5, batch_size = 64):
dat, _, _ = batch
x_list = [dat[x].to(device) for x in self.input_layers]
input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list])
baseline = tuple(torch.zeros_like(x) for x in input_data)

if method == 'IntegratedGradients':
baseline = tuple(torch.zeros_like(x) for x in input_data)
elif method == 'GradientShap': # provide multiple baselines for GradientShap
baseline = tuple(
torch.cat([torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0)
for x in input_data
)

if num_class == 1:
# returns a tuple of tensors (one per data modality)
attributions = ig.attribute(input_data, baseline,
additional_forward_args=(target_var, steps),
n_steps=steps)
if method == 'IntegratedGradients':
attributions = explainer.attribute(input_data, baseline,
additional_forward_args=(target_var, steps_or_samples),
n_steps=steps_or_samples)
elif method == 'GradientShap':
attributions = explainer.attribute(input_data, baseline,
additional_forward_args=(target_var, steps_or_samples),
n_samples=steps_or_samples)
aggregated_attributions[0].append(attributions)
else:
for target_class in range(num_class):
# returns a tuple of tensors (one per data modality)
attributions = ig.attribute(input_data, baseline,
additional_forward_args=(target_var, steps),
target=target_class, n_steps=steps)
if method == 'IntegratedGradients':
attributions = explainer.attribute(input_data, baseline,
additional_forward_args=(target_var, steps_or_samples),
target=target_class,
n_steps=steps_or_samples)
elif method == 'GradientShap':
attributions = explainer.attribute(input_data, baseline,
additional_forward_args=(target_var, steps_or_samples),
target=target_class,
n_samples=steps_or_samples)
aggregated_attributions[target_class].append(attributions)

# For each target class and for each data modality/layer, concatenate attributions across batches
@@ -571,9 +592,12 @@
for j in range(len(layers)):
features = dataset.features[layers[j]]
importances = imp[i][j][0].detach().numpy()
target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else ''
df_list.append(pd.DataFrame({'target_variable': target_var,
'target_class': i, 'layer': layers[j],
'name': features, 'importance': importances}))
'target_class': i,
'target_class_label': target_class_label,
'layer': layers[j],
'name': features, 'importance': importances}))
df_imp = pd.concat(df_list, ignore_index = True)

# save scores in model
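To make the API difference handled above concrete — a single baseline plus `n_steps` for IntegratedGradients versus a distribution of baselines plus `n_samples` for GradientShap — here is a minimal, self-contained Captum sketch on a toy model. The model and tensor shapes are invented for illustration; flexynesis passes `self.forward_target` and per-modality inputs instead:

```python
import torch
from captum.attr import IntegratedGradients, GradientShap

# Toy network and single input, purely for illustration.
model = torch.nn.Sequential(torch.nn.Linear(10, 4), torch.nn.ReLU(), torch.nn.Linear(4, 3))
x = torch.randn(1, 10, requires_grad=True)

# IntegratedGradients: one all-zeros baseline; n_steps controls the interpolation path.
ig = IntegratedGradients(model)
ig_attr = ig.attribute(x, baselines=torch.zeros_like(x), target=0, n_steps=25)

# GradientShap: expects a *distribution* of baselines and samples it n_samples times,
# which is why the diff above stacks steps_or_samples zero tensors per modality.
gs = GradientShap(model)
baseline_dist = torch.cat([torch.zeros_like(x) for _ in range(5)], dim=0)
gs_attr = gs.attribute(x, baselines=baseline_dist, target=0, n_samples=5)

print(ig_attr.shape, gs_attr.shape)  # both torch.Size([1, 10]): one score per input feature
```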
(Diffs for the remaining 4 changed files are not shown.)
