-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpymol_attribution.py
37 lines (31 loc) · 999 Bytes
/
pymol_attribution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os

import matplotlib.cm as cm
import numpy as np
from matplotlib.colors import Normalize
from pymol import cmd
# --- Configuration ----------------------------------------------------------
# Diverging colormap centred on 0: attribution scores in [-1, 1] are mapped to
# [0, 1] before lookup (values outside saturate at the colormap ends).
cmap = cm.seismic
norm = Normalize(vmin=-1, vmax=1)
# '<pdb code>_<chain letter>'; also the key looked up in the `labels` array.
pdb_id = '4m21_a'
attributions_path = 'out/krashras_graph_new/interpret/attributions.npz'
data_path = 'data/KrasHras/pdb/'
# Column of the selected attribution matrix that is visualised.
# NOTE(review): which feature channel index -4 corresponds to is not visible
# here; confirm against the code that wrote attributions.npz.
attribution_column = -4

# --- Load attribution -------------------------------------------------------
# Use the NpzFile as a context manager so the underlying zip file is closed.
# Indexing it reads each array fully into memory, so they remain valid after
# the `with` block exits.
with np.load(attributions_path) as data:
    attributions = data['data']
    labels = data['labels']

matches = labels == pdb_id
if not np.any(matches):
    # Without this check the failure is an opaque IndexError from `[0]` below.
    raise ValueError(
        '{!r} not found in labels of {}'.format(pdb_id, attributions_path))
ind = np.where(matches)
# First matching sample; presumably a (residues x channels) matrix, given the
# 2-D slice below -- TODO confirm.
attribution = attributions[ind][0][:, attribution_column]
# --- Load the structure, keeping only the requested chain -------------------
# pdb_id has the form '<pdb code>_<chain letter>', e.g. '4m21_a' -> file
# '4m21.pdb' in data_path, chain 'A'.
pdb_code = pdb_id[:-2]
chain_id = pdb_id[-1].upper()
cmd.reinitialize()  # start from a clean session
cmd.bg_color('white')
cmd.load(os.path.join(data_path, pdb_code + '.pdb'))
# split_chains adds one object per chain, named '<object>_<chain>'. Match the
# full '_<chain>' suffix: checking only the last character would also keep the
# original multi-chain object whenever its name ends in the chain letter.
cmd.split_chains()
for name in cmd.get_names('objects', 0, '(all)'):
    if not name.endswith('_' + chain_id):
        cmd.delete(name)
cmd.reset()  # re-fit the view to the remaining chain
# --- Colour residues by attribution -----------------------------------------
# Every score gets its own named PyMOL colour ('saliency<n>') taken from the
# colormap (alpha dropped), which is then applied to residue number n.
# NOTE(review): the position n in `attribution` is used directly as the PDB
# residue number ('res n'); if the chain's numbering starts at 1 or has gaps,
# colours land on the wrong residues -- confirm against how the attribution
# rows were ordered.
for res_idx, score in enumerate(attribution):
    color_name = 'saliency{}'.format(res_idx)
    rgb = list(cmap(norm(score))[:3])
    cmd.select('toBecolored', 'res {}'.format(res_idx))
    cmd.set_color(color_name, rgb)
    cmd.color(color_name, 'toBecolored')

# Overlay a mesh on the whole chain, then disable the visible selection.
chain_selection = 'chain {}'.format(pdb_id[-1].upper())
cmd.select('selected', chain_selection)
cmd.show('mesh', 'selected')
cmd.deselect()