-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
841 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import pandas as pd | ||
import rdkit | ||
from rdkit import Chem | ||
|
||
|
||
def _bond_involves_hydrogen(bond) -> bool:
    """Return True when either end of ``bond`` is a hydrogen atom."""
    return "H" in (bond.GetBeginAtom().GetSymbol(), bond.GetEndAtom().GetSymbol())


def prepare_for_bde(mol: rdkit.Chem.Mol) -> pd.Series:
    """Cap the radical in ``mol`` with a hydrogen and locate the new X-H bond.

    The first atom carrying radical electrons gets one extra explicit hydrogen
    and its radical count reset to zero. The capped molecule is then written
    to SMILES and re-parsed, and the radical atom is located in the re-parsed
    molecule by matching canonical atom ranks.

    Args:
        mol: molecule containing a radical atom. It is not modified; all
            changes are made on a copy.

    Returns:
        A ``pd.Series`` with the entries

        * ``mol_smiles``: SMILES of the hydrogen-capped molecule,
        * ``radical_index_mol``: index of the former radical atom in the
          molecule parsed from ``mol_smiles``,
        * ``bond_index``: index (after ``Chem.AddHs``) of a bond between that
          atom and a hydrogen,
        * ``other_h_bonds``: indices of every other bond involving hydrogen.

    Raises:
        RuntimeError: if no atom has radical electrons, or if the former
            radical atom has no bond to a hydrogen after ``Chem.AddHs``.
    """
    # Work on a copy so the caller's molecule keeps its radical; the edits
    # below would otherwise silently change the object that was passed in.
    mol = Chem.Mol(mol)

    for radical_index, atom in enumerate(mol.GetAtoms()):
        if atom.GetNumRadicalElectrons() != 0:
            atom.SetNumExplicitHs(atom.GetNumExplicitHs() + 1)
            atom.SetNumRadicalElectrons(0)
            # Only the first radical atom is capped.
            break
    else:
        raise RuntimeError(f"No radical found: {Chem.MolToSmiles(mol)}")

    radical_rank = Chem.CanonicalRankAtoms(mol, includeChirality=True)[radical_index]

    mol_smiles = Chem.MolToSmiles(mol)
    # TODO this line seems redundant
    # NOTE(review): the indices returned below are computed on this re-parsed
    # molecule, i.e. in the atom ordering of `mol_smiles` -- confirm before
    # removing the round trip.
    mol = Chem.MolFromSmiles(mol_smiles)

    # Find the atom that has the same canonical rank in the re-parsed molecule.
    radical_index_reordered = list(
        Chem.CanonicalRankAtoms(mol, includeChirality=True)
    ).index(radical_rank)

    molH = Chem.AddHs(mol)
    for bond in molH.GetAtomWithIdx(radical_index_reordered).GetBonds():
        if _bond_involves_hydrogen(bond):
            bond_index = bond.GetIdx()
            break
    else:
        raise RuntimeError("Bond not found")

    # Every other bond to hydrogen, kept in bond order.
    other_h_bonds = [
        bond.GetIdx()
        for bond in molH.GetBonds()
        if _bond_involves_hydrogen(bond) and bond.GetIdx() != bond_index
    ]

    return pd.Series(
        {
            "mol_smiles": mol_smiles,
            "radical_index_mol": radical_index_reordered,
            "bond_index": bond_index,
            "other_h_bonds": other_h_bonds,
        }
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Config file for stable_radical_opt.py | ||
|
||
run_id: 'local_run' | ||
|
||
# Parameters for setting up the problem | ||
problem_config: | ||
initial_state: 'C' | ||
# maximum number of heavy atoms | ||
max_atoms: 10 | ||
# minimum number of heavy atoms | ||
min_atoms: 4 | ||
# atoms to use when building the molecule | ||
atom_additions: [ 'C', 'N', 'O', 'S' ] | ||
  # if set, don't construct molecules whose Synthetic Accessibility (SA) score exceeds this threshold | ||
# see: https://github.com/rdkit/rdkit/blob/master/Contrib/SA_Score/sascorer.py | ||
sa_score_threshold: 4.0 | ||
# whether to consider stereoisomers different molecules | ||
stereoisomers: True | ||
# try to get a 3D embedding of the molecule, and if this fails, remove it. | ||
try_embedding: False | ||
canonicalize_tautomers: False # this is CPU intensive | ||
cache_dir: 'builder_cache/' | ||
num_shards: 1 | ||
parallel: False | ||
redox_model: 'redox-models/models/redox_model' | ||
stability_model: 'redox-models/models/stability_model' | ||
|
||
# Parameters for training the policy model | ||
train_config: | ||
# Reward options: | ||
# if the reward for a given game is > the previous | ||
# *ranked_reward_alpha* fraction of games (e.g., 75% of games), | ||
# then it's a win. Otherwise, it's a loss. | ||
ranked_reward_alpha: 0.75 | ||
# max/min number of games to consider | ||
reward_buffer_max_size: 100 | ||
reward_buffer_min_size: 25 | ||
|
||
# Learning options: | ||
# some useful tips for selecting these parameter values: | ||
# https://stackoverflow.com/a/49924566/7483950 | ||
# learning rate | ||
lr: 1E-3 | ||
  # number of times the learning algorithm will work through the entire training dataset (PSJ -- we actually never want training to stop) | ||
epochs: 1E6 | ||
# number of batch iterations before a training epoch is considered finished | ||
steps_per_epoch: 100 | ||
# number of seconds to wait to check if enough games have been played | ||
game_count_delay: 20 | ||
verbose: 2 | ||
|
||
# AlphaZero problem options: | ||
# max/min number of games to consider (ordered by time) when training the policy | ||
max_buffer_size: 256 | ||
min_buffer_size: 32 | ||
# number of training examples to use when updating model parameters | ||
batch_size: 32 | ||
# folder in which to store the trained models | ||
policy_checkpoint_dir: 'policy_checkpoints' | ||
|
||
# MoleculeTFAlphaZeroProblem options: | ||
# size of network hidden layers | ||
features: 16 | ||
# number of global state attention heads. Must be a factor of `features` | ||
num_heads: 1 | ||
# number of message passing layers | ||
num_messages: 1 | ||
|
||
# Parameters for running the Monte Carlo Tree Search games | ||
mcts_config: | ||
# Minimum reward to return for invalid actions | ||
min_reward: 0 | ||
pbc_c_base: 1.0 | ||
pbc_c_init: 1.25 | ||
# dirichlet 'shape' parameter. Larger values spread out probability over more moves. | ||
dirichlet_alpha: 1.0 | ||
# percentage to favor dirichlet noise vs. prior estimation. Smaller means less noise | ||
dirichlet_x: 0.5 | ||
# number of samples to perform at each level of the MCTS search | ||
num_mcts_samples: 100 | ||
# Maximum number of seconds per MCTS round | ||
timeout: 30 | ||
# the maximum search depth | ||
max_depth: 1000000 | ||
#ucb_constant: math.sqrt(2) | ||
|
||
# Database settings for the Object Relational Model (ORM) | ||
# Used to store games and communicate between the policy model (run on GPUs) and rollout (run on CPUs) | ||
sql_database: | ||
# settings to connect to NREL's yuma database | ||
drivername: "postgresql+psycopg2" | ||
dbname: "rl" | ||
port: "5432" | ||
# This will be overwritten by env variable DB_HOST in rlmolecule/sql/run_config.sh | ||
host: "localhost" | ||
user: "example_user" | ||
passwd: "tmppassword" |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import gzip | ||
import logging | ||
import os | ||
import pickle | ||
from typing import Optional, Sequence | ||
|
||
import rdkit | ||
from rdkit import Chem | ||
from rdkit.Chem import FragmentCatalog, Mol | ||
from rlmolecule.molecule.builder.builder import ( | ||
AddNewAtomsAndBonds, | ||
MoleculeBuilder, | ||
MoleculeFilter, | ||
) | ||
from rlmolecule.molecule.molecule_state import MoleculeState | ||
from rlmolecule.tree_search.metrics import collect_metrics | ||
|
||
fcgen = FragmentCatalog.FragCatGenerator() | ||
fpgen = FragmentCatalog.FragFPGenerator() | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AddNewAtomsAndBondsProtectRadical(AddNewAtomsAndBonds):
    """AddNewAtomsAndBonds variant whose free valence excludes radical electrons."""

    @staticmethod
    def _get_free_valence(atom) -> int:
        """Return the parent's free valence for ``atom`` minus its radical electrons."""
        radical_electrons = atom.GetNumRadicalElectrons()
        return AddNewAtomsAndBonds._get_free_valence(atom) - radical_electrons
|
||
|
||
class MoleculeBuilderWithFingerprint(MoleculeBuilder):
    """MoleculeBuilder with a FingerprintFilter added at the end of its stack."""

    def __init__(self, *args, **kwargs):
        """Build the parent with the given arguments, then add the filter."""
        super().__init__(*args, **kwargs)
        fingerprint_filter = FingerprintFilter()
        self.transformation_stack += [fingerprint_filter]
|
||
|
||
class MoleculeBuilderProtectRadical(MoleculeBuilderWithFingerprint):
    """Fingerprint-filtering builder whose first step is radical-protecting."""

    def __init__(self, *args, **kwargs):
        """Build the parent, then replace the first transformation in the stack.

        NOTE(review): ``atom_additions`` is read from ``kwargs``, so it must be
        passed by keyword; otherwise a ``KeyError`` is raised here.
        """
        super().__init__(*args, **kwargs)
        protected_step = AddNewAtomsAndBondsProtectRadical(kwargs["atom_additions"])
        self.transformation_stack[0] = protected_step
|
||
|
||
class FingerprintFilter(MoleculeFilter):
    """Molecule filter that only passes molecules whose fragments are all known.

    The fragment catalog (``fcat``) and the collection of accepted fragment
    descriptions (``valid_fps``) are loaded from the gzip-compressed pickle
    ``redox_fragment_data.pz`` located next to this module.
    """

    def __init__(self):
        super().__init__()
        with gzip.open(os.path.join(dir_path, "redox_fragment_data.pz")) as f:
            # NOTE: pickle.load can execute arbitrary code. The file is read
            # from this module's own directory; never point it at untrusted data.
            data = pickle.load(f)
        self.fcat = data["fcat"]
        self.valid_fps = set(data["valid_fps"])

    def get_fingerprint(self, mol):
        """Yield the catalog entry description of every on-bit for ``mol``.

        ``fcgen.AddFragsFromMol`` is called with ``self.fcat`` before the
        fingerprint is computed against that same catalog.
        """
        fcgen.AddFragsFromMol(mol, self.fcat)
        fp = fpgen.GetFPForMol(mol, self.fcat)
        for i in fp.GetOnBits():
            yield self.fcat.GetEntryDescription(i)

    def filter(self, molecule: rdkit.Chem.Mol) -> bool:
        """Return True when every fragment of ``molecule`` is in ``valid_fps``."""
        return set(self.get_fingerprint(molecule)).issubset(self.valid_fps)
|
||
|
||
class StableRadMoleculeState(MoleculeState):
    """
    A MoleculeState whose successors also include radical forms of the molecule.

    Non-terminal states offer two kinds of next actions: the molecules produced by
    calling the builder on the current molecule (while below ``builder.max_atoms``),
    and, once ``builder.min_atoms`` is reached, the radicals yielded by
    ``build_radicals``, each wrapped as a forced-terminal state.
    """

    def __init__(
        self,
        molecule: Mol,
        # NOTE(review): `any` is the builtin function, not a type; `typing.Any`
        # is probably what was meant. Left unchanged to keep the signature as is.
        builder: any,
        force_terminal: bool = False,
        smiles: Optional[str] = None,
    ) -> None:
        """Forward all arguments unchanged to ``MoleculeState.__init__``."""
        super().__init__(molecule, builder, force_terminal, smiles)

    @collect_metrics
    def get_next_actions(self) -> Sequence["StableRadMoleculeState"]:
        """Return the successor states reachable from this state.

        Returns:
            An empty list when this state was forced terminal; otherwise the
            builder-generated states followed by the terminal radical states.
        """
        # Lazy %-formatting: `self` is only rendered when DEBUG is enabled.
        logger.debug("Getting next actions for %s", self)

        result = []
        if not self._forced_terminal:
            if self.num_atoms < self.builder.max_atoms:
                result.extend(
                    StableRadMoleculeState(molecule, self.builder)
                    for molecule in self.builder(self.molecule)
                )

            if self.num_atoms >= self.builder.min_atoms:
                # Radicals end the game, so they are marked terminal.
                result.extend(
                    StableRadMoleculeState(radical, self.builder, force_terminal=True)
                    for radical in build_radicals(self.molecule)
                )

        return result
|
||
|
||
def build_radicals(starting_mol):
    """Yield the distinct radicals obtainable from ``starting_mol``.

    For each atom with positive free valence, a copy of the molecule is made,
    that atom is given one radical electron, and the copy is sanitized. Each
    new SMILES is yielded once, as a molecule re-parsed from that SMILES.
    """
    seen_smiles = set()

    for atom_idx, atom in enumerate(starting_mol.GetAtoms()):
        if not AddNewAtomsAndBonds._get_free_valence(atom) > 0:
            continue

        candidate = rdkit.Chem.RWMol(starting_mol)
        radical_atom = candidate.GetAtomWithIdx(atom_idx)
        radical_atom.SetNumRadicalElectrons(1)
        Chem.SanitizeMol(candidate)

        candidate_smiles = Chem.MolToSmiles(candidate)
        if candidate_smiles in seen_smiles:
            continue

        # Re-parsing from SMILES makes sure the atom ordering is standardized
        yield Chem.MolFromSmiles(candidate_smiles)
        seen_smiles.add(candidate_smiles)
Oops, something went wrong.