"""get_vocab.py: print the motif vocabulary of a SMILES dataset.

Reads the SMILES strings of the dataset selected with ``--dataset``,
decomposes each molecule with ``chemutils.brics_decomp`` and prints the
collected motifs in sorted order.
"""
import sys
import argparse
import torch
import torch_geometric
from chemutils import *
from rdkit import Chem
from multiprocessing import Pool
import pandas as pd
from chemutils import brics_decomp, get_clique_mol
def process(data):
    """Collect the vocabulary entries found in a batch of SMILES lines.

    Every line is stripped of surrounding carriage returns, newlines and
    spaces and wrapped in a ``MolGraph``.  Each node of the graph's
    ``mol_tree`` contributes its ``label`` attribute, plus one
    ``(smiles, inter_label)`` pair for every two-element entry of its
    ``inter_label`` attribute (only the second element of each entry is
    kept).

    Args:
        data: Iterable of SMILES strings, one molecule per item.

    Returns:
        A set containing the node labels and the ``(smiles, label)`` pairs.
    """
    entries = set()
    for raw_line in data:
        graph = MolGraph(raw_line.strip("\r\n "))
        for _, node_attr in graph.mol_tree.nodes(data=True):
            node_smiles = node_attr['smiles']
            entries.add(node_attr['label'])
            entries.update(
                (node_smiles, inter) for _, inter in node_attr['inter_label']
            )
    return entries
def get_motifs(data):
    """Return the motifs that ``brics_decomp`` finds in a molecule.

    Args:
        data: Molecule object; it is passed through ``Chem.SanitizeMol``
            before being decomposed.

    Returns:
        The first element of the ``(motifs, edges)`` pair returned by
        ``brics_decomp``; the edges are discarded.
    """
    Chem.SanitizeMol(data)
    fragments, _ = brics_decomp(data)
    return fragments
def get_motifs_edges(data):
    """Return both the motifs and the edges that ``brics_decomp`` produces.

    Args:
        data: Molecule object; it is passed through ``Chem.SanitizeMol``
            before being decomposed.

    Returns:
        A ``(motifs, edges)`` tuple unpacked from the result of
        ``brics_decomp``.
    """
    Chem.SanitizeMol(data)
    decomposition = brics_decomp(data)
    # Unpack explicitly so the result is always a two-element tuple.
    fragments, links = decomposition
    return (fragments, links)
def dataset_smiles_path(dataset):
    """Return the relative path of the SMILES file for *dataset*.

    The ``zinc`` dataset is read from ``raw/all.txt``; every other dataset
    is read from ``processed/smiles.csv``.  Both live under
    ``dataset/<dataset>/``.
    """
    filename = 'raw/all.txt' if dataset == 'zinc' else 'processed/smiles.csv'
    return 'dataset/' + dataset + '/' + filename


def main(argv=None):
    """Print, in sorted order, every motif found in the selected dataset.

    Args:
        argv: Command-line arguments to parse; ``None`` (the default) makes
            argparse read ``sys.argv``.

    SMILES strings that RDKit cannot parse are reported on stderr and
    skipped, so a single bad row no longer aborts the whole run.  Motifs
    are not de-duplicated: one line is printed per motif occurrence.
    """
    parser = argparse.ArgumentParser()
    # --ncpu is kept for command-line compatibility; nothing below reads it.
    parser.add_argument('--ncpu', type=int, default=1)
    parser.add_argument('--dataset', type=str, default="zinc")
    args = parser.parse_args(argv)

    # The file has no header row; the SMILES strings are in the first column.
    smiles_list = pd.read_csv(
        dataset_smiles_path(args.dataset), header=None
    )[0].tolist()

    vocab = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None, which
            # would make the sanitization inside get_motifs raise.
            print('skipping unparsable SMILES: ' + repr(smiles),
                  file=sys.stderr)
            continue
        # get_motifs sanitizes the molecule itself, so no extra
        # Chem.SanitizeMol call is needed here.
        vocab.extend(get_motifs(mol))

    for motif in sorted(vocab):
        print(motif)


if __name__ == "__main__":
    main()