diff --git a/compound_generation.py b/compound_generation.py
index f0b5fd7..64c88da 100644
--- a/compound_generation.py
+++ b/compound_generation.py
@@ -43,10 +43,12 @@ def __init__(self, use_cuda=True):
         self.use_cuda = False

         self.encoder = EncoderCNN(8)
-        self.decoder = DecoderRNN(512, 1024, 29, 1)
+        self.decoder = DecoderRNN(512, 1024, 38, 1)
+#        self.vae_model = LigandVAE(use_cuda=use_cuda)
         self.D = discriminator(nc=8,use_cuda=True)
         self.G = generator(nc=8,use_cuda=True)

+#        self.vae_model.eval()
         self.D.eval()
         self.G.eval()
         self.encoder.eval()
@@ -56,6 +58,7 @@ def __init__(self, use_cuda=True):
             assert torch.cuda.is_available()
             self.encoder.cuda()
             self.decoder.cuda()
+            # self.vae_model.cuda()
             self.D.cuda()
             self.G.cuda()
             self.use_cuda = True
@@ -102,18 +105,7 @@ def generate_molecules(self, smile_str, n_attemps=300, lam_fact=1., probab=False
        :return: list of RDKit molecules.
        """

-        shape_input, cond_input = get_mol_voxels(smile_str)
-        if self.use_cuda:
-            shape_input = shape_input.cuda()
-            cond_input = cond_input.cuda()
-
-        shape_input = shape_input.unsqueeze(0).repeat(n_attemps, 1, 1, 1, 1)
-        cond_input = cond_input.unsqueeze(0).repeat(n_attemps, 1, 1, 1, 1)
-
-        shape_input = Variable(shape_input, volatile=True)
-        cond_input = Variable(cond_input, volatile=True)
-
-        #recoded_shapes, _, _ = self.G(shape_input, cond_input, lam_fact)
+        z = Variable(torch.randn(n_attemps, 128, 12, 12, 12)).cuda()
+        recoded_shapes = self.G(z)
         smiles = self.caption_shape(recoded_shapes, probab=probab)
diff --git a/decoding.py b/decoding.py
index d6bdc7b..19e951d 100644
--- a/decoding.py
+++ b/decoding.py
@@ -2,9 +2,9 @@
 # Copying and distribution is allowed under AGPLv3 license

 vocab_list = ["pad", "start", "end",
-              "C", "c", "N", "n", "S", "s", "P", "O", "o",
+              "C", "c", "N", "n", "S", "s", "P", "p", "O", "o",
               "B", "F", "I",
-              "Cl", "[nH]", "Br",  # "X", "Y", "Z",
+              "X", "Y", "Z",  # "Cl", "[nH]", "Br"
               "1", "2", "3", "4", "5", "6",
               "#", "=", "-", "(", ")","/","\\","@","[","]","H","+","7"  # Misc
              ]
@@ -25,5 +25,6 @@ def decode_smiles(in_tensor):
             if xchar == 2:
                 break
             csmile += vocab_i2c_v1[xchar]
+        csmile = csmile.replace("X","Cl").replace("Y","[nH]").replace("Z","Br")
         gen_smiles.append(csmile)
     return gen_smiles
\ No newline at end of file
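
Note on the vocabulary change above: the multi-character tokens "Cl", "[nH]" and "Br" are swapped for the single-character placeholders "X", "Y" and "Z", so every vocabulary entry is exactly one character during generation, and decode_smiles expands the placeholders back afterwards. A minimal sketch of the round trip (helper names are hypothetical, not part of the patch):

    def encode_placeholders(smiles):
        # Collapse multi-character tokens to single characters before tokenization.
        return smiles.replace("Cl", "X").replace("[nH]", "Y").replace("Br", "Z")

    def decode_placeholders(smiles):
        # Inverse substitution, as performed at the end of decode_smiles.
        return smiles.replace("X", "Cl").replace("Y", "[nH]").replace("Z", "Br")

    assert decode_placeholders(encode_placeholders("c1ccc(Cl)cc1Br")) == "c1ccc(Cl)cc1Br"
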
diff --git a/generators.py b/generators.py
index 25cbd72..ee36b0c 100644
--- a/generators.py
+++ b/generators.py
@@ -5,9 +5,14 @@ import rdkit
 from rdkit import Chem
 from rdkit.Chem import AllChem
-from htmd.molecule.util import uniformRandomRotation
-from htmd.smallmol.smallmol import SmallMol
-from htmd.molecule.voxeldescriptors import _getOccupancyC, _getGridCenters
+
+#from htmd.molecule.util import uniformRandomRotation
+#from htmd.smallmol.smallmol import SmallMol
+#from htmd.molecule.voxeldescriptors import _getOccupancyC, _getGridCenters
+
+from moleculekit.smallmol.smallmol import SmallMol
+from moleculekit.tools.voxeldescriptors import getVoxelDescriptors
+
 import numpy as np
 import multiprocessing
@@ -15,9 +20,9 @@ import random

 vocab_list = ["pad", "start", "end",
-              "C", "c", "N", "n", "S", "s", "P", "O", "o",
+              "C", "c", "N", "n", "S", "s", "P", "p", "O", "o",
               "B", "F", "I",
-              "Cl", "[nH]", "Br",  # "X", "Y", "Z",
+              "X", "Y", "Z",  # "Cl", "[nH]", "Br"
               "1", "2", "3", "4", "5", "6",
               "#", "=", "-", "(", ")","/","\\","@","[","]","H","+","7"  # Misc
              ]
@@ -30,7 +35,7 @@ size = 24
 N = [size, size, size]
 bbm = (np.zeros(3) - float(size * 1. / 2))
-global_centers = _getGridCenters(bbm, N, resolution)
+#global_centers = _getGridCenters(bbm, N, resolution)


 def string_gen_V1(in_string):
@@ -70,7 +75,7 @@ def generate_representation(in_smile):
         AllChem.EmbedMolecule(mh)
         Chem.AllChem.MMFFOptimizeMolecule(mh)
         m = Chem.RemoveHs(mh)
-        mol = SmallMol(m)
+        mol = SmallMol(m, force_reading=True, fixHs=False)
         return mol
     except:  # Rarely the conformer generation fails
         return None
@@ -159,6 +164,8 @@ def generate_representation_v1(smile):
     end_token = smile_str.index(2)
     smile_str = "".join([vocab_i2c_v1[i] for i in smile_str[1:end_token]])

+    smile_str = smile_str.replace("X","Cl").replace("Y","[nH]").replace("Z","Br")
+
     mol = generate_representation(smile_str)
     if mol is None:
         return None
@@ -170,6 +177,44 @@ def generate_representation_v1(smile):

     return torch.Tensor(vox), torch.Tensor(smile), end_token + 1


+def generate_representation_v2(smile):
+    """
+    Generate voxelized and string representations of a molecule.
+    """
+    # Convert the token sequence back into a SMILES string
+    smile_str = list(smile)
+    end_token = smile_str.index(2)
+    smile_str = "".join([vocab_i2c_v1[i] for i in smile_str[1:end_token]])
+
+    smile_str = smile_str.replace("X","Cl").replace("Y","[nH]").replace("Z","Br")
+
+    # Convert the SMILES to a 3D structure
+    mol = generate_representation(smile_str)
+    if mol is None:
+        return None
+
+    try:
+        center = mol.getCenter()
+        box = [size, size, size]  # size = 24
+        # getVoxelDescriptors computes per-voxel features of the mol object and
+        # returns the feature array together with the voxel centers. The 8
+        # channels are: 'hydrophobic', 'aromatic', 'hbond_acceptor',
+        # 'hbond_donor', 'positive_ionizable', 'negative_ionizable', 'metal',
+        # 'occupancies'.
+        mol_vox, mol_centers, mol_N = getVoxelDescriptors(mol, boxsize=box, voxelsize=1, buffer=0, center=center, validitychecks=False)
+        #print(mol_vox, mol_centers, mol_N)  # mol_N = [35 35 35]
+        #print(mol_vox.shape, mol_centers.shape, mol_N.shape)  # (42875, 8) (42875, 3) (3,)
+
+        mol_vox_t = mol_vox.transpose().reshape([1, mol_vox.shape[1], mol_N[0], mol_N[1], mol_N[2]])
+    except:
+        print("Voxelization failed")
+        #sys.exit()
+        return None
+
+    finish_combine = np.squeeze(mol_vox_t)
+    #print(finish_combine.shape)
+
+    return torch.Tensor(finish_combine), torch.Tensor(smile), end_token + 1
+

 def gather_fn(in_data):
     """
@@ -200,7 +245,7 @@ def __init__(self, n_proc=6, mp_pool=None):
             raise NotImplementedError("Use multiprocessing for now!")

     def transform_data(self, smiles):
-        inputs = self.mp.map(generate_representation_v1, smiles)
+        inputs = self.mp.map(generate_representation_v2, smiles)

         # Sometimes representation generation fails
         inputs = list(filter(lambda x: x is not None, inputs))
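
Context for the htmd-to-moleculekit migration above: the new voxelization path boils down to the standalone snippet below, mirroring the calls in generate_representation and generate_representation_v2 (a sketch; assumes rdkit and moleculekit are installed, the molecule and box values are illustrative):

    from rdkit import Chem
    from rdkit.Chem import AllChem
    from moleculekit.smallmol.smallmol import SmallMol
    from moleculekit.tools.voxeldescriptors import getVoxelDescriptors

    # Embed a 3D conformer, as generate_representation does.
    mh = Chem.AddHs(Chem.MolFromSmiles("c1ccccc1O"))
    AllChem.EmbedMolecule(mh)
    AllChem.MMFFOptimizeMolecule(mh)
    mol = SmallMol(Chem.RemoveHs(mh), force_reading=True, fixHs=False)

    # Voxelize with the same parameters as generate_representation_v2.
    vox, centers, N = getVoxelDescriptors(mol, boxsize=[24, 24, 24], voxelsize=1,
                                          buffer=0, center=mol.getCenter(),
                                          validitychecks=False)
    # vox has shape (N[0]*N[1]*N[2], 8); reorder to channels-first for the CNN.
    grid = vox.transpose().reshape(vox.shape[1], N[0], N[1], N[2])
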
diff --git a/networks.py b/networks.py
index 5e77933..ff0f35e 100644
--- a/networks.py
+++ b/networks.py
@@ -129,15 +129,15 @@ def init_weights(self):
     def forward(self, features, captions, lengths):
         """Decode shapes feature vectors and generates SMILES."""
         embeddings = self.embed(captions)
-        print(embeddings.shape,captions.shape)
+        #print(embeddings.shape,captions.shape)

         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
-        print(embeddings.shape,features.shape)
+        #print(embeddings.shape,features.shape)

         packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
-        print(packed,packed[0].shape,sum(packed[1]),sum(lengths))
+        #print(packed,packed[0].shape,sum(packed[1]),sum(lengths))

         hiddens, _ = self.lstm(packed)
         outputs = self.linear(hiddens[0])
@@ -147,7 +147,7 @@ def sample(self, features, states=None):
         """Samples SMILES tockens for given shape features (Greedy search)."""
         sampled_ids = []
         inputs = features.unsqueeze(1)
-        for i in range(80):
+        for i in range(75):
             hiddens, states = self.lstm(inputs, states)
             outputs = self.linear(hiddens.squeeze(1))
             predicted = outputs.max(1)[1]
@@ -160,7 +160,7 @@ def sample_prob(self, features, states=None):
         """Samples SMILES tockens for given shape features (probalistic picking)."""
         sampled_ids = []
         inputs = features.unsqueeze(1)
-        for i in range(62):  # maximum sampling length
+        for i in range(130):  # maximum sampling length
             hiddens, states = self.lstm(inputs, states)
             outputs = self.linear(hiddens.squeeze(1))
             if i == 0:
@@ -184,7 +184,7 @@ def sample_prob(self, features, states=None):
                 valid_token = rand_num < iter_sum
                 update_indecies = np.logical_and(valid_token, np.logical_not(tokens.astype(np.bool)))
-                tokens[update_indecies] = i+1
+                tokens[update_indecies] = i

                 # put back on the GPU.
                 if probs.is_cuda:
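
Aside on sample_prob: the rand_num / iter_sum logic above is inverse-CDF sampling over the softmax output. A single-draw numpy sketch of the picking rule (the real code vectorizes this across the batch on the GPU):

    import numpy as np

    def sample_token(probs):
        # probs: softmax distribution over the vocabulary, shape (vocab_size,)
        rand_num = np.random.rand()
        iter_sum = np.cumsum(probs)
        # Pick the first index whose cumulative probability exceeds the draw.
        return int(np.argmax(rand_num < iter_sum))
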
diff --git a/train.py b/train.py
index 048f568..8a1926a 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@
 # Copying and distribution is allowed under AGPLv3 license

 import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "3"
 import sys
 import torch
 import torch.autograd
@@ -26,22 +27,22 @@
 cap_loss = 0.
 caption_start = 4000
-batch_size = 128
+batch_size = 256

 savedir = args["output_dir"]
 os.makedirs(savedir, exist_ok=True)
 smiles = np.load(args["input"])

 import multiprocessing
-multiproc = multiprocessing.Pool(6)
+multiproc = multiprocessing.Pool(14)
 my_gen = queue_datagen(smiles, batch_size=batch_size, mp_pool=multiproc)
-mg = GeneratorEnqueuer(my_gen, seed=0)
+mg = GeneratorEnqueuer(my_gen)
 mg.start()
 mt_gen = mg.get()

 # Define the networks
 encoder = EncoderCNN(8)
-decoder = DecoderRNN(512, 1024, 29, 1)
+decoder = DecoderRNN(512, 1024, 38, 1)

 D = discriminator(nc=8,use_cuda=True)
 G = generator(nc=8,use_cuda=True)
@@ -51,7 +52,7 @@
 G.cuda()

 # Caption optimizer
-criterion = nn.CrossEntropyLoss()
+criterion = nn.CrossEntropyLoss(ignore_index=0)
 caption_params = list(decoder.parameters()) + list(encoder.parameters())
 caption_optimizer = torch.optim.Adam(caption_params, lr=0.001)

@@ -59,7 +60,7 @@
 decoder.train()

 # GAN optimizer
-dg_criterion = nn.BCELoss()
+dg_criterion = nn.BCELoss()  # single-target binary cross-entropy
 d_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)
 g_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)
 z_dimension = 32
@@ -67,8 +68,8 @@
 tq_gen = tqdm(enumerate(mt_gen))
 log_file = open(os.path.join(savedir, "log.txt"), "w")
 cap_loss = 0.
-#caption_start = 4000
-caption_start = 0
+caption_start = 4000
+
 for i, (mol_batch, caption, lengths) in tq_gen:
     num_img = mol_batch.size(0)
@@ -76,10 +77,10 @@
     real_label = Variable(torch.ones(num_img)).cuda()
     fake_label = Variable(torch.zeros(num_img)).cuda()
     #print('fake label', fake_label.shape)
-    ########Train the discriminator#######
-    real_out = D(real_data.float())
+    ######## Train the discriminator #######
+    real_out = D(real_data.float())  # feed the real samples to the discriminator

-    d_loss_real = dg_criterion(real_out.view(-1), real_label)
+    d_loss_real = dg_criterion(real_out.view(-1), real_label)  # loss on the real samples
     real_scores = real_out

     z = Variable(torch.randn(num_img, 128, 12, 12, 12)).cuda()
@@ -90,35 +91,38 @@
     d_loss_fake = dg_criterion(fake_out.view(-1), fake_label)
     fake_scores = fake_out

-    d_loss = d_loss_real + d_loss_fake
-    d_optimizer.zero_grad()
-    d_loss.backward()
+    d_loss = d_loss_real + d_loss_fake  # sum of the real and fake losses
+    d_optimizer.zero_grad()  # zero the gradients before backpropagation
+    d_loss.backward()  # backpropagate the error
     d_optimizer.step()

-    #==================Training generator========
+    #================== Train the generator ========
     z = Variable(torch.randn(num_img, 128, 12, 12, 12)).cuda()
     fake_data = G(z)
     output = D(fake_data)
     #print(output.view(-1).shape, real_label.shape)
     g_loss = dg_criterion(output.view(-1), real_label)
     g_optimizer.zero_grad()
-    g_loss.backward()
+    g_loss.backward()  #retain_graph=True if i >= caption_start else False
     g_optimizer.step()

     recon_batch = G(z.detach())
     if i >= caption_start:  # Start by autoencoder optimization
-        captions = Variable(caption.cuda())
-        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
-
+        caption = Variable(caption.cuda())
+
+        targets = pack_padded_sequence(caption, lengths, batch_first=True)[0]
+
         decoder.zero_grad()
         encoder.zero_grad()
         features = encoder(recon_batch)
-        outputs = decoder(features, captions, lengths)
+        outputs = decoder(features, caption, lengths)
         cap_loss = criterion(outputs, targets)
-        cap_loss.backward()
+        cap_loss.backward()
+        #torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1)
         caption_optimizer.step()
-    break
+

     if (i + 1) % 5000 == 0:
         torch.save(decoder.state_dict(),
@@ -151,7 +155,7 @@
             lr = param_group["lr"] / 2.
             param_group["lr"] = lr

-    if i == 120000:
+    if i == 210000:
         # We are Done!
         log_file.close()
         break
@@ -159,4 +163,4 @@
 # Cleanup
 del tq_gen
 mt_gen.close()
-multiproc.close()
+multiproc.close()
\ No newline at end of file
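
On the criterion change in train.py: index 0 is the "pad" token in vocab_list, so CrossEntropyLoss(ignore_index=0) drops padded target positions from the caption loss instead of training the decoder to emit padding. A minimal sketch (shapes and values illustrative):

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    logits = torch.randn(6, 38)                 # 6 timesteps, 38-token vocabulary
    targets = torch.tensor([4, 3, 5, 2, 0, 0])  # trailing 0s are pads
    loss = criterion(logits, targets)           # averaged over the 4 non-pad positions
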