run_integrated_gradients.py
"""
Script to run integrated gradients on SQuAD or DuoRC
This script uses datasets, captum, omegaconf and transformers libraries.
Please install them in order to run this script.
Usage:
$python run_integrated_gradients.py --config ./configs/integrated_gradients/squad.yaml
"""
import os
import argparse
import pickle as pkl
import pandas as pd
from omegaconf import OmegaConf
# from transformers import BertTokenizer, BertForQuestionAnswering
from src.utils.integrated_gradients import BertIntegratedGradients
# from src.utils.misc import seed
dirname = os.path.dirname(__file__)
## Config
parser = argparse.ArgumentParser(
    prog="run_integrated_gradients.py",
    description="Run integrated gradients on a model.",
)
parser.add_argument(
    "--config",
    type=str,
    action="store",
    help="The configuration for integrated gradients",
    default=os.path.join(dirname, "./configs/integrated_gradients/squad.yaml"),
)
args = parser.parse_args()
ig_config = OmegaConf.load(args.config)
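# The loaded config is expected to provide at least the fields used below
# (predictions_path, n_samples, store_dir); any further options are assumed
# to be consumed by BertIntegratedGradients. Note that store_dir must already
# exist, since the script only opens files inside it. Illustrative example:
#
#   predictions_path: ./outputs/squad/predictions.json
#   n_samples: 100
#   store_dir: ./outputs/squad/ig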
# Load the predictions produced on the dataset
print("### Loading Dataset ###")
predictions = pd.read_json(ig_config.predictions_path)
# Initialize BertIntegratedGradients
big = BertIntegratedGradients(ig_config, predictions)
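# BertIntegratedGradients (see src/utils/integrated_gradients.py) is assumed to
# wrap Captum's integrated-gradients attribution for the QA model: it loads the
# model referenced in the config and scores how much each input token/word
# contributes to the predicted answer, layer by layer.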
print("### Running IG ###")
(
    samples,
    word_importances,
    token_importances,
) = big.get_random_samples_and_importances_across_all_layers(
    n_samples=ig_config.n_samples
)
print("### Saving the Scores ###")
with open(os.path.join(ig_config.store_dir, "samples"), "wb") as out_file:
    pkl.dump(samples, out_file)
with open(os.path.join(ig_config.store_dir, "token_importances"), "wb") as out_file:
    pkl.dump(token_importances, out_file)
with open(os.path.join(ig_config.store_dir, "word_importances"), "wb") as out_file:
    pkl.dump(word_importances, out_file)
print("### Finished ###")