infer_chir.py
'''
Date: 2022-11-23 11:29:36
LastEditors: yuhhong
LastEditTime: 2022-12-12 12:57:28
'''
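# Inference script for 3DMolCSP: load a trained MolNet_CSP checkpoint, run it on the
# molecules in the configured test SDF (in both the original and flipped configurations),
# average the predictions of each enantiomer pair, and write the result table to a TSV file.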
import os
import argparse

import yaml
import pandas as pd
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, ConcatDataset

from rdkit import Chem
# suppress rdkit warning
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from dataset import ChiralityDataset_infer
from model import MolNet_CSP
from utils import set_seed, average_results_on_enantiomers


TEST_BATCH_SIZE = 1   # global variable in inference
TEST_NUM_WORKERS = 0  # global variable in inference


def inference(model, device, loader, num_points):
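    """Run `model` over every batch in `loader` and collect the raw predictions.

    Returns the molecule IDs, isomeric SMILES strings, the `mb` values carried
    by the dataset, and the stacked prediction tensor, all in loader order.
    """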
    model.eval()
    y_pred = []
    smiles_list = []
    id_list = []
    mbs = []
    for _, batch in enumerate(tqdm(loader, desc="Iteration")):
        mol_id, smiles_iso, mb, x = batch
        x = x.to(device).to(torch.float32)
        x = x.permute(0, 2, 1)
        idx_base = torch.arange(0, TEST_BATCH_SIZE, device=device).view(-1, 1, 1) * num_points
        with torch.no_grad():
            pred = model(x, idx_base)
        y_pred.append(pred.view(TEST_BATCH_SIZE, -1).detach().cpu())
        smiles_list.extend(smiles_iso)
        id_list.extend(mol_id)
        mbs.extend(mb.tolist())
    y_pred = torch.cat(y_pred, dim=0)
    return id_list, smiles_list, mbs, y_pred


def batch_filter(supp):
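    """Filter an SDMolSupplier: skip entries RDKit failed to parse and
    molecules whose molblock is essentially empty (no atom records)."""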
    for mol in supp:
        if mol is None:  # RDKit could not parse this entry
            continue
        if len(Chem.MolToMolBlock(mol).split("\n")) <= 6:  # empty molecule
            continue
        yield mol


if __name__ == "__main__":
    # Inference settings
    parser = argparse.ArgumentParser(description='3DMolCSP (infer)')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to configuration')
    parser.add_argument('--csp_no', type=int, required=True,
                        help='chirality phase number [0, 19]')
    parser.add_argument('--resume_path', type=str, required=True,
                        help='Pretrained model path')
    parser.add_argument('--result_path', type=str, required=True,
                        help='Results path')
    parser.add_argument('--device', type=int, default=0,
                        help='which GPU to use if any (default: 0)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disable CUDA and run inference on the CPU')
    args = parser.parse_args()
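
    # Example invocation (file names below are illustrative, not shipped with the repo):
    #   python infer_chir.py --config ./config/chirality.yaml --csp_no 0 \
    #       --resume_path ./check_point/chirality_csp0.pt --result_path ./results/chirality_csp0.tsv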
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    set_seed(42)

    results_dir = os.path.dirname(args.result_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
        print('Created the results directory: {}'.format(results_dir))

    # load the configuration file
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    device = torch.device("cuda:" + str(args.device)) if args.cuda else torch.device("cpu")
    model = MolNet_CSP(config['model_para'], args.device).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print('#Params: {}'.format(num_params))
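
    # The test SDF is read twice: once as-is and once with `flipping=True`, so that every
    # molecule is also presented in its flipped (enantiomer) configuration; the two sets of
    # predictions are averaged at the end by `average_results_on_enantiomers`.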
print("Loading the data...")
supp = Chem.SDMolSupplier(config['paths']['test_data'])
test_set = ChiralityDataset_infer([item for item in batch_filter(supp)],
num_points=config['model_para']['num_atoms'],
csp_no=args.csp_no,
flipping=False)
supp_ena = Chem.SDMolSupplier(config['paths']['test_data'])
test_set_ena = ChiralityDataset_infer([item for item in batch_filter(supp_ena)],
num_points=config['model_para']['num_atoms'],
csp_no=args.csp_no,
flipping=True)
test_set = ConcatDataset([test_set, test_set_ena]) # concat two configurations' datasets
test_loader = DataLoader(test_set,
batch_size=TEST_BATCH_SIZE,
num_workers=TEST_NUM_WORKERS,
drop_last=True,)
print('Load {} test data from {}.'.format(len(test_set), config['paths']['test_data']))
print("Load the model...")
model.load_state_dict(torch.load(args.resume_path, map_location=device)['model_state_dict'])
model.to(device)
print('Evaluating...')
id_list, smiles_list, mbs, y_pred = inference(model, device, test_loader,
config['model_para']['num_atoms'])
y_pred_out = []
for y in y_pred:
y_pred_out.append(','.join([str(i) for i in y.tolist()]))
res_df = pd.DataFrame({'ID': id_list, 'SMILES': smiles_list, 'MB': mbs, 'Pred': y_pred_out})
print('Average the results of enantiomers...')
res_df = average_results_on_enantiomers(res_df)
res_df.to_csv(args.result_path, sep='\t')
print('Save the test results to {}'.format(args.result_path))