# reproduce_main_results.py (forked from XuhanLiu/NGFP)
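"""Reproduce the main regression results on three datasets (solubility,
drug efficacy, photovoltaic efficiency), comparing Morgan circular
fingerprints against neural fingerprints.

Per run: load a CSV, normalize targets using train+valid statistics,
train a model, and report the mean/std of test MSE over several runs.
"""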
import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

from NeuralGraph.dataset import MolData, SmileData
from NeuralGraph.model import QSAR, MLP
FP_METHODS = ["morgan", "nfp"]  # circular (Morgan) FP vs. neural FP
EXP_NAMES = ["solubility", "drug_efficacy", "photovoltaic"]
FP_LEN = 1 << 9  # fingerprint length (512 bits) for the circular FP


def split_train_valid_test(n, p=0.8, v=0.1, seed=None):
    """Shuffle range(n) and split it into train/valid/test index arrays
    with proportions p, v, and 1 - p - v."""
    if seed is not None:  # `if seed:` would silently ignore seed=0
        np.random.seed(seed)
    idx = np.arange(n)
    np.random.shuffle(idx)
    s = int(n * p)  # number of training samples
    t = int(n * v)  # number of validation samples
    return idx[:s], idx[s:(s + t)], idx[(s + t):]  # train, valid, test
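# Example split: n=100 with the defaults p=0.8, v=0.1 yields
# 80 train / 10 valid / 10 test indices.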


def normalize_array(A):
    """Return a pair of closures: one that z-normalizes values with A's
    mean/std, and one that inverts the normalization."""
    mean, std = np.mean(A), np.std(A)

    def norm_func(X):
        return (X - mean) / std

    def restore_func(X):
        return X * std + mean

    return norm_func, restore_func
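# Illustrative round trip for normalize_array (the values are assumptions
# for the example, not from the original script):
#   norm, restore = normalize_array(np.array([0.0, 2.0, 4.0]))
#   norm(2.0) -> 0.0, and restore(norm(x)) == x up to float error.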


def load_csv(data_file, target_name):
    """Read a dataset CSV; return (SMILES series, target value array)."""
    df = pd.read_csv(data_file)
    return df['smiles'], df[target_name].values


def mse(x, y):
    """Mean squared error between two arrays."""
    return ((x - y) ** 2).mean()


def main(args):
    BSZ, RUNS, LR, N_EPOCH = args.batch_size, args.runs, args.lr, args.epochs
    if args.experiment == EXP_NAMES[0]:    # solubility (Delaney)
        OUTPUT = './output/best_delaney.pkl'
        DATAFILE = Path('./dataset/solubility/delaney-processed.csv')
        TGT_COL_NAME = 'measured log solubility in mols per litre'
    elif args.experiment == EXP_NAMES[1]:  # drug efficacy (malaria)
        OUTPUT = './output/best_efficacy.pkl'
        DATAFILE = Path('./dataset/drug_efficacy/malaria-processed.csv')
        TGT_COL_NAME = 'activity'
    elif args.experiment == EXP_NAMES[2]:  # photovoltaic efficiency (CEP)
        OUTPUT = './output/best_photovoltaic.pkl'
        DATAFILE = Path('./dataset/photovoltaic_efficiency/cep-processed.csv')
        TGT_COL_NAME = 'PCE'
    else:
        raise NotImplementedError(args.experiment)
    SMILES, TARGET = load_csv(DATAFILE, TGT_COL_NAME)

    def build_data_net(args, target):
        """Build the dataset and a model factory for the chosen FP method."""
        if args.fp_method == FP_METHODS[0]:    # circular (Morgan) FP
            data = SmileData(SMILES, target, fp_len=FP_LEN, radius=4)
            net = lambda: MLP(hid_dim=FP_LEN, n_class=1)
            return data, net
        elif args.fp_method == FP_METHODS[1]:  # neural FP
            data = MolData(SMILES, target)
            net = lambda: QSAR(hid_dim=128, n_class=1)
            return data, net
        else:
            raise NotImplementedError(args.fp_method)
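    # Note (an inference from the pairings above, not stated in this script):
    # SmileData appears to precompute fixed-length circular fingerprints for
    # the MLP, while MolData keeps per-molecule graph data for the
    # neural-fingerprint QSAR model.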

    res = []
    for _ in range(RUNS):
        # Fresh random split every run (seed=None). Normalization statistics
        # come from train+valid only, so the test targets stay unseen.
        train_idx, valid_idx, test_idx = split_train_valid_test(len(TARGET),
                                                                seed=None)
        norm_func, restore_func = normalize_array(
            np.concatenate([TARGET[train_idx], TARGET[valid_idx]], axis=0))
        target = norm_func(TARGET)
        data, net = build_data_net(args, target)
        train_loader = DataLoader(Subset(data, train_idx), batch_size=BSZ,
                                  shuffle=True, drop_last=True)
        valid_loader = DataLoader(Subset(data, valid_idx), batch_size=BSZ,
                                  shuffle=False)
        test_loader = DataLoader(Subset(data, test_idx), batch_size=BSZ,
                                 shuffle=False)
        net = net()  # instantiate a fresh model for this run
        net = net.fit(train_loader, valid_loader, epochs=N_EPOCH, path=OUTPUT,
                      criterion=nn.MSELoss(), lr=LR)
        # Evaluate on the held-out set in the original (denormalized) units.
        score = net.predict(test_loader)
        gt = restore_func(target[test_idx])
        prd = restore_func(score)
        res.append(mse(gt, prd))
        print(res[-1])
    res = np.asarray(res)
    return res.mean(), res.std()  # mean and spread of test MSE across runs


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment", type=str,
                        help="Specify the experiment name",
                        choices=EXP_NAMES)
    parser.add_argument("fp_method", type=str,
                        help="Specify the fingerprint method",
                        choices=FP_METHODS)
    parser.add_argument("-b", "--batch-size", help="batch size",
                        default=64, type=int)
    parser.add_argument("-e", "--epochs", help="number of epochs",
                        default=500, type=int)
    parser.add_argument("-r", "--runs", help="number of runs",
                        default=5, type=int)
    parser.add_argument("-l", "--lr", help="learning rate",
                        default=1e-3, type=float)
    parsed_args = parser.parse_args()
    avg_mse, std_mse = main(parsed_args)
    print(f"test MSE: {avg_mse:.4f} +/- {std_mse:.4f}")
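# Example invocations (assuming the processed CSVs exist under ./dataset/
# and ./output/ is writable for model checkpoints):
#   python reproduce_main_results.py solubility nfp
#   python reproduce_main_results.py photovoltaic morgan -b 128 -e 200 -r 3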