Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
AspirinCode authored May 18, 2023
1 parent 267c2d3 commit dafd12f
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 52 deletions.
16 changes: 4 additions & 12 deletions compound_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def __init__(self, use_cuda=True):

self.use_cuda = False
self.encoder = EncoderCNN(8)
self.decoder = DecoderRNN(512, 1024, 29, 1)
self.decoder = DecoderRNN(512, 1024, 38, 1)
# self.vae_model = LigandVAE(use_cuda=use_cuda)
self.D = discriminator(nc=8,use_cuda=True)
self.G = generator(nc=8,use_cuda=True)

# self.vae_model.eval()
self.D.eval()
self.G.eval()
self.encoder.eval()
Expand All @@ -56,6 +58,7 @@ def __init__(self, use_cuda=True):
assert torch.cuda.is_available()
self.encoder.cuda()
self.decoder.cuda()
# self.vae_model.cuda()
self.D.cuda()
self.G.cuda()
self.use_cuda = True
Expand Down Expand Up @@ -102,18 +105,7 @@ def generate_molecules(self, smile_str, n_attemps=300, lam_fact=1., probab=False
:return: list of RDKit molecules.
"""

shape_input, cond_input = get_mol_voxels(smile_str)
if self.use_cuda:
shape_input = shape_input.cuda()
cond_input = cond_input.cuda()

shape_input = shape_input.unsqueeze(0).repeat(n_attemps, 1, 1, 1, 1)
cond_input = cond_input.unsqueeze(0).repeat(n_attemps, 1, 1, 1, 1)

shape_input = Variable(shape_input, volatile=True)
cond_input = Variable(cond_input, volatile=True)

#recoded_shapes, _, _ = self.G(shape_input, cond_input, lam_fact)
z = Variable(torch.randn(n_attemps, 128, 12, 12, 12)).cuda()
recoded_shapes = self.G(z)
smiles = self.caption_shape(recoded_shapes, probab=probab)
Expand Down
5 changes: 3 additions & 2 deletions decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# Copying and distribution is allowed under AGPLv3 license

vocab_list = ["pad", "start", "end",
"C", "c", "N", "n", "S", "s", "P", "O", "o",
"C", "c", "N", "n", "S", "s", "P", "p", "O", "o",
"B", "F", "I",
"Cl", "[nH]", "Br", # "X", "Y", "Z",
"X", "Y", "Z", #"Cl", "[nH]", "Br"
"1", "2", "3", "4", "5", "6",
"#", "=", "-", "(", ")","/","\\","@","[","]","H","+","7" # Misc
]
Expand All @@ -25,5 +25,6 @@ def decode_smiles(in_tensor):
if xchar == 2:
break
csmile += vocab_i2c_v1[xchar]
csmile = csmile.replace("X","Cl").replace("Y","[nH]").replace("Z","Br")
gen_smiles.append(csmile)
return gen_smiles
61 changes: 53 additions & 8 deletions generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,24 @@
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from htmd.molecule.util import uniformRandomRotation
from htmd.smallmol.smallmol import SmallMol
from htmd.molecule.voxeldescriptors import _getOccupancyC, _getGridCenters

#from htmd.molecule.util import uniformRandomRotation
#from htmd.smallmol.smallmol import SmallMol
#from htmd.molecule.voxeldescriptors import _getOccupancyC, _getGridCenters

from moleculekit.smallmol.smallmol import SmallMol
from moleculekit.tools.voxeldescriptors import getVoxelDescriptors


import numpy as np
import multiprocessing
import math
import random

vocab_list = ["pad", "start", "end",
"C", "c", "N", "n", "S", "s", "P", "O", "o",
"C", "c", "N", "n", "S", "s", "P", "p", "O", "o",
"B", "F", "I",
"Cl", "[nH]", "Br", # "X", "Y", "Z",
"X", "Y", "Z", #"Cl", "[nH]", "Br"
"1", "2", "3", "4", "5", "6",
"#", "=", "-", "(", ")","/","\\","@","[","]","H","+","7" # Misc
]
Expand All @@ -30,7 +35,7 @@
size = 24
N = [size, size, size]
bbm = (np.zeros(3) - float(size * 1. / 2))
global_centers = _getGridCenters(bbm, N, resolution)
#global_centers = _getGridCenters(bbm, N, resolution)


def string_gen_V1(in_string):
Expand Down Expand Up @@ -70,7 +75,7 @@ def generate_representation(in_smile):
AllChem.EmbedMolecule(mh)
Chem.AllChem.MMFFOptimizeMolecule(mh)
m = Chem.RemoveHs(mh)
mol = SmallMol(m)
mol = SmallMol(m, force_reading=True,fixHs=False)
return mol
except: # Rarely the conformer generation fails
return None
Expand Down Expand Up @@ -159,6 +164,8 @@ def generate_representation_v1(smile):
end_token = smile_str.index(2)
smile_str = "".join([vocab_i2c_v1[i] for i in smile_str[1:end_token]])

smile_str = smile_str.replace("X","Cl").replace("Y","[nH]").replace("Z","Br")

mol = generate_representation(smile_str)
if mol is None:
return None
Expand All @@ -170,6 +177,44 @@ def generate_representation_v1(smile):

return torch.Tensor(vox), torch.Tensor(smile), end_token + 1

def generate_representation_v2(smile):
    """
    Generate voxelized and string representation of a molecule.

    :param smile: sequence of integer token ids encoding a SMILES string.
        The token at position 0 is skipped and the first occurrence of
        id 2 is taken as the end token (presumably the "start"/"end"
        entries of ``vocab_list`` -- confirm against ``vocab_i2c_v1``).
    :return: tuple ``(voxels, tokens, length)`` where ``voxels`` is a
        ``torch.Tensor`` built from the voxel descriptors, ``tokens`` is
        ``torch.Tensor(smile)`` and ``length`` is ``end_token + 1``.
        Returns ``None`` when no 3D structure could be generated or the
        voxelization fails.
    :raises ValueError: if ``smile`` contains no end token (id 2).
    """
    # Decode the token ids between the first token and the end token
    # back into a SMILES string.
    token_ids = list(smile)
    end_token = token_ids.index(2)
    smile_str = "".join([vocab_i2c_v1[i] for i in token_ids[1:end_token]])

    # The vocabulary stores the multi-character tokens as single-letter
    # placeholders; expand them before handing the string on.
    smile_str = smile_str.replace("X", "Cl").replace("Y", "[nH]").replace("Z", "Br")

    # Convert the SMILES string to a 3D structure.
    mol = generate_representation(smile_str)
    if mol is None:
        return None

    try:
        center = mol.getCenter()
        box = [size, size, size]  # edge lengths taken from the module-level `size`
        # getVoxelDescriptors computes the per-voxel features of the molecule
        # and returns the feature array, the voxel centers and the grid
        # dimensions. The features are ('hydrophobic', 'aromatic',
        # 'hbond_acceptor', 'hbond_donor', 'positive_ionizable',
        # 'negative_ionizable', 'metal', 'occupancies').
        mol_vox, _, mol_N = getVoxelDescriptors(
            mol,
            boxsize=box,
            voxelsize=1,
            buffer=0,
            center=center,
            validitychecks=False,
        )
        # Rearrange (n_voxels, n_features) into
        # (1, n_features, N[0], N[1], N[2]).
        mol_vox_t = mol_vox.transpose().reshape(
            [1, mol_vox.shape[1], mol_N[0], mol_N[1], mol_N[2]]
        )
    except Exception:
        # Best effort: a molecule that cannot be voxelized is dropped by
        # returning None. Catching Exception rather than using a bare
        # except keeps KeyboardInterrupt/SystemExit propagating.
        print("Can not Voxelization")
        return None

    # Drop the singleton leading axis added by the reshape above.
    finish_combine = np.squeeze(mol_vox_t)

    return torch.Tensor(finish_combine), torch.Tensor(smile), end_token + 1


def gather_fn(in_data):
"""
Expand Down Expand Up @@ -200,7 +245,7 @@ def __init__(self, n_proc=6, mp_pool=None):
raise NotImplementedError("Use multiprocessing for now!")

def transform_data(self, smiles):
inputs = self.mp.map(generate_representation_v1, smiles)
inputs = self.mp.map(generate_representation_v2, smiles)

# Sometimes representation generation fails
inputs = list(filter(lambda x: x is not None, inputs))
Expand Down
12 changes: 6 additions & 6 deletions networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ def init_weights(self):
def forward(self, features, captions, lengths):
"""Decode shapes feature vectors and generates SMILES."""
embeddings = self.embed(captions)
print(embeddings.shape,captions.shape)
#print(embeddings.shape,captions.shape)

embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)

print(embeddings.shape,features.shape)
#print(embeddings.shape,features.shape)

packed = pack_padded_sequence(embeddings, lengths, batch_first=True)

print(packed,packed[0].shape,sum(packed[1]),sum(lengths))
#print(packed,packed[0].shape,sum(packed[1]),sum(lengths))

hiddens, _ = self.lstm(packed)
outputs = self.linear(hiddens[0])
Expand All @@ -147,7 +147,7 @@ def sample(self, features, states=None):
"""Samples SMILES tokens for given shape features (Greedy search)."""
sampled_ids = []
inputs = features.unsqueeze(1)
for i in range(80):
for i in range(75):
hiddens, states = self.lstm(inputs, states)
outputs = self.linear(hiddens.squeeze(1))
predicted = outputs.max(1)[1]
Expand All @@ -160,7 +160,7 @@ def sample_prob(self, features, states=None):
"""Samples SMILES tokens for given shape features (probabilistic picking)."""
sampled_ids = []
inputs = features.unsqueeze(1)
for i in range(62): # maximum sampling length
for i in range(130): # maximum sampling length
hiddens, states = self.lstm(inputs, states)
outputs = self.linear(hiddens.squeeze(1))
if i == 0:
Expand All @@ -184,7 +184,7 @@ def sample_prob(self, features, states=None):
valid_token = rand_num < iter_sum
update_indecies = np.logical_and(valid_token,
np.logical_not(tokens.astype(np.bool)))
tokens[update_indecies] = i+1
tokens[update_indecies] = i

# put back on the GPU.
if probs.is_cuda:
Expand Down
52 changes: 28 additions & 24 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copying and distribution is allowed under AGPLv3 license

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import sys
import torch
import torch.autograd
Expand All @@ -26,22 +27,22 @@

cap_loss = 0.
caption_start = 4000
batch_size = 128
batch_size = 256

savedir = args["output_dir"]
os.makedirs(savedir, exist_ok=True)
smiles = np.load(args["input"])

import multiprocessing
multiproc = multiprocessing.Pool(6)
multiproc = multiprocessing.Pool(14)
my_gen = queue_datagen(smiles, batch_size=batch_size, mp_pool=multiproc)
mg = GeneratorEnqueuer(my_gen, seed=0)
mg = GeneratorEnqueuer(my_gen)
mg.start()
mt_gen = mg.get()

# Define the networks
encoder = EncoderCNN(8)
decoder = DecoderRNN(512, 1024, 29, 1)
decoder = DecoderRNN(512, 1024, 38, 1)
D = discriminator(nc=8,use_cuda=True)
G = generator(nc=8,use_cuda=True)

Expand All @@ -51,35 +52,35 @@
G.cuda()

# Caption optimizer
criterion = nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss(ignore_index = 0)
caption_params = list(decoder.parameters()) + list(encoder.parameters())
caption_optimizer = torch.optim.Adam(caption_params, lr=0.001)

encoder.train()
decoder.train()

# GAN optimizer
dg_criterion = nn.BCELoss()
dg_criterion = nn.BCELoss() # 是单目标二分类交叉熵函数
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)
z_dimension = 32

tq_gen = tqdm(enumerate(mt_gen))
log_file = open(os.path.join(savedir, "log.txt"), "w")
cap_loss = 0.
#caption_start = 4000
caption_start = 0
caption_start = 4000


for i, (mol_batch, caption, lengths) in tq_gen:
num_img = mol_batch.size(0)
real_data = Variable(mol_batch[:, :]).cuda()
real_label = Variable(torch.ones(num_img)).cuda()
fake_label = Variable(torch.zeros(num_img)).cuda()
#print('fake label', fake_label.shape)
########Train the discriminator#######
real_out = D(real_data.float())
########判别器训练train#######
real_out = D(real_data.float()) # 将真实图片放入判别器中

d_loss_real = dg_criterion(real_out.view(-1), real_label)
d_loss_real = dg_criterion(real_out.view(-1), real_label) # 得到真实图片的loss
real_scores = real_out

z = Variable(torch.randn(num_img, 128, 12, 12, 12)).cuda()
Expand All @@ -90,35 +91,38 @@
d_loss_fake = dg_criterion(fake_out.view(-1), fake_label)
fake_scores = fake_out

d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_loss = d_loss_real + d_loss_fake # 损失包括判真损失和判假损失
d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
d_loss.backward() # 将误差反向传播
d_optimizer.step()

#==================Training generator========
#==================训练生成器========
z = Variable(torch.randn(num_img, 128, 12, 12, 12)).cuda()
fake_data = G(z)
output = D(fake_data)
#print(output.view(-1).shape, real_label.shape)
g_loss = dg_criterion(output.view(-1), real_label)
g_optimizer.zero_grad()
g_loss.backward()
g_loss.backward() #retain_graph=True retain_graph=True if i >= caption_start else False
g_optimizer.step()

recon_batch = G(z.detach())
if i >= caption_start: # Start by autoencoder optimization
captions = Variable(caption.cuda())
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

caption = Variable(caption.cuda())

targets = pack_padded_sequence(caption, lengths, batch_first=True)[0]

decoder.zero_grad()
encoder.zero_grad()
features = encoder(recon_batch)
outputs = decoder(features, captions, lengths)
outputs = decoder(features, caption, lengths)
cap_loss = criterion(outputs, targets)
cap_loss.backward()

cap_loss.backward() #
#torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1)
caption_optimizer.step()

break


if (i + 1) % 5000 == 0:
torch.save(decoder.state_dict(),
Expand Down Expand Up @@ -151,12 +155,12 @@
lr = param_group["lr"] / 2.
param_group["lr"] = lr

if i == 120000:
if i == 210000:
# We are Done!
log_file.close()
break

# Cleanup
del tq_gen
mt_gen.close()
multiproc.close()
multiproc.close()

0 comments on commit dafd12f

Please sign in to comment.