Skip to content

Commit

Permalink
adding single-node postgresql code
Browse files Browse the repository at this point in the history
  • Loading branch information
pstjohn committed Jan 26, 2022
1 parent a7c3f7b commit 9d4fcff
Show file tree
Hide file tree
Showing 11 changed files with 841 additions and 76 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,7 @@ rlops_pass
.vscode
*.sw[op]
*.out
examples/code_ocean/redox-models
examples/code_ocean/*.db
examples/code_ocean/policy_checkpoints
examples/code_ocean/builder_cache
57 changes: 57 additions & 0 deletions examples/code_ocean/bde_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pandas as pd
import rdkit
from rdkit import Chem


def prepare_for_bde(mol: rdkit.Chem.Mol) -> pd.Series:
radical_index = None
for i, atom in enumerate(mol.GetAtoms()):
if atom.GetNumRadicalElectrons() != 0:
assert radical_index is None
radical_index = i

atom.SetNumExplicitHs(atom.GetNumExplicitHs() + 1)
atom.SetNumRadicalElectrons(0)
break
else:
raise RuntimeError(f"No radical found: {Chem.MolToSmiles(mol)}")

radical_rank = Chem.CanonicalRankAtoms(mol, includeChirality=True)[radical_index]

mol_smiles = Chem.MolToSmiles(mol)
# TODO this line seems redundant
mol = Chem.MolFromSmiles(mol_smiles)

radical_index_reordered = list(
Chem.CanonicalRankAtoms(mol, includeChirality=True)
).index(radical_rank)

molH = Chem.AddHs(mol)
for bond in molH.GetAtomWithIdx(radical_index_reordered).GetBonds():
if "H" in {bond.GetBeginAtom().GetSymbol(), bond.GetEndAtom().GetSymbol()}:
bond_index = bond.GetIdx()
break
else:
raise RuntimeError("Bond not found")

h_bond_indices = [
bond.GetIdx()
for bond in filter(
lambda bond: (
(bond.GetEndAtom().GetSymbol() == "H")
| (bond.GetBeginAtom().GetSymbol() == "H")
),
molH.GetBonds(),
)
]

other_h_bonds = list(set(h_bond_indices) - {bond_index})

return pd.Series(
{
"mol_smiles": mol_smiles,
"radical_index_mol": radical_index_reordered,
"bond_index": bond_index,
"other_h_bonds": other_h_bonds,
}
)
97 changes: 97 additions & 0 deletions examples/code_ocean/config_local.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Config file for stable_radical_opt.py

run_id: 'local_run'

# Parameters for setting up the problem
problem_config:
initial_state: 'C'
# maximum number of heavy atoms
max_atoms: 10
# minimum number of heavy atoms
min_atoms: 4
# atoms to use when building the molecule
atom_additions: [ 'C', 'N', 'O', 'S' ]
# if set, don't construct molecules greater than a given Synthetic Accessibility (SA) score
# see: https://github.com/rdkit/rdkit/blob/master/Contrib/SA_Score/sascorer.py
sa_score_threshold: 4.0
# whether to consider stereoisomers different molecules
stereoisomers: True
# try to get a 3D embedding of the molecule, and if this fails, remove it.
try_embedding: False
canonicalize_tautomers: False # this is CPU intensive
cache_dir: 'builder_cache/'
num_shards: 1
parallel: False
redox_model: 'redox-models/models/redox_model'
stability_model: 'redox-models/models/stability_model'

# Parameters for training the policy model
train_config:
# Reward options:
# if the reward for a given game is > the previous
# *ranked_reward_alpha* fraction of games (e.g., 75% of games),
# then it's a win. Otherwise, it's a loss.
ranked_reward_alpha: 0.75
# max/min number of games to consider
reward_buffer_max_size: 100
reward_buffer_min_size: 25

# Learning options:
# some useful tips for selecting these parameter values:
# https://stackoverflow.com/a/49924566/7483950
# learning rate
lr: 1E-3
# number times that the learning algorithm will work through the entire training dataset (PSJ -- we actually never want training to stop)
epochs: 1E6
# number of batch iterations before a training epoch is considered finished
steps_per_epoch: 100
# number of seconds to wait to check if enough games have been played
game_count_delay: 20
verbose: 2

# AlphaZero problem options:
# max/min number of games to consider (ordered by time) when training the policy
max_buffer_size: 256
min_buffer_size: 32
# number of training examples to use when updating model parameters
batch_size: 32
# folder in which to store the trained models
policy_checkpoint_dir: 'policy_checkpoints'

# MoleculeTFAlphaZeroProblem options:
# size of network hidden layers
features: 16
# number of global state attention heads. Must be a factor of `features`
num_heads: 1
# number of message passing layers
num_messages: 1

# Parameters for running the Monte Carlo Tree Search games
mcts_config:
# Minimum reward to return for invalid actions
min_reward: 0
pbc_c_base: 1.0
pbc_c_init: 1.25
# dirichlet 'shape' parameter. Larger values spread out probability over more moves.
dirichlet_alpha: 1.0
# percentage to favor dirichlet noise vs. prior estimation. Smaller means less noise
dirichlet_x: 0.5
# number of samples to perform at each level of the MCTS search
num_mcts_samples: 100
# Maximum number of seconds per MCTS round
timeout: 30
# the maximum search depth
max_depth: 1000000
#ucb_constant: math.sqrt(2)

# Database settings for the Object Relational Model (ORM)
# Used to store games and communicate between the policy model (run on GPUs) and rollout (run on CPUs)
sql_database:
# settings to connect to NREL's yuma database
drivername: "postgresql+psycopg2"
dbname: "rl"
port: "5432"
# This will be overwritten by env variable DB_HOST in rlmolecule/sql/run_config.sh
host: "localhost"
user: "example_user"
passwd: "tmppassword"
Binary file added examples/code_ocean/redox_fragment_data.pz
Binary file not shown.
130 changes: 130 additions & 0 deletions examples/code_ocean/stable_radical_molecule_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import gzip
import logging
import os
import pickle
from typing import Optional, Sequence

import rdkit
from rdkit import Chem
from rdkit.Chem import FragmentCatalog, Mol
from rlmolecule.molecule.builder.builder import (
AddNewAtomsAndBonds,
MoleculeBuilder,
MoleculeFilter,
)
from rlmolecule.molecule.molecule_state import MoleculeState
from rlmolecule.tree_search.metrics import collect_metrics

fcgen = FragmentCatalog.FragCatGenerator()
fpgen = FragmentCatalog.FragFPGenerator()
dir_path = os.path.dirname(os.path.realpath(__file__))
logger = logging.getLogger(__name__)


class AddNewAtomsAndBondsProtectRadical(AddNewAtomsAndBonds):
@staticmethod
def _get_free_valence(atom) -> int:
fv = AddNewAtomsAndBonds._get_free_valence(atom)
return fv - atom.GetNumRadicalElectrons()


class MoleculeBuilderWithFingerprint(MoleculeBuilder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.transformation_stack += [FingerprintFilter()]


class MoleculeBuilderProtectRadical(MoleculeBuilderWithFingerprint):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.transformation_stack[0] = AddNewAtomsAndBondsProtectRadical(
kwargs["atom_additions"]
)


class FingerprintFilter(MoleculeFilter):
def __init__(self):
super(FingerprintFilter, self).__init__()
with gzip.open(os.path.join(dir_path, "redox_fragment_data.pz")) as f:
data = pickle.load(f)
self.fcat = data["fcat"]
self.valid_fps = set(data["valid_fps"])

def get_fingerprint(self, mol):
fcgen.AddFragsFromMol(mol, self.fcat)
fp = fpgen.GetFPForMol(mol, self.fcat)
for i in fp.GetOnBits():
yield self.fcat.GetEntryDescription(i)

def filter(self, molecule: rdkit.Chem.Mol) -> bool:
fps = set(self.get_fingerprint(molecule))
if fps.difference(self.valid_fps) == set():
return True
else:
return False


class StableRadMoleculeState(MoleculeState):
"""
A State implementation which uses simple transformations (such as adding a bond) to
define a graph of molecules that can be navigated.
Molecules are stored as rdkit Mol instances, and the rdkit-generated SMILES string
is also stored for efficient hashing.
"""

def __init__(
self,
molecule: Mol,
builder: any,
force_terminal: bool = False,
smiles: Optional[str] = None,
) -> None:
super(StableRadMoleculeState, self).__init__(
molecule, builder, force_terminal, smiles
)

@collect_metrics
def get_next_actions(self) -> Sequence["StableRadMoleculeState"]:

logger.debug(f"Getting next actions for {self}")

result = []
if not self._forced_terminal:
if self.num_atoms < self.builder.max_atoms:
result.extend(
(
StableRadMoleculeState(molecule, self.builder)
for molecule in self.builder(self.molecule)
)
)

if self.num_atoms >= self.builder.min_atoms:
result.extend(
(
StableRadMoleculeState(
radical, self.builder, force_terminal=True
)
for radical in build_radicals(self.molecule)
)
)

return result


def build_radicals(starting_mol):
"""Build organic radicals. """

generated_smiles = set()

for i, atom in enumerate(starting_mol.GetAtoms()):
if AddNewAtomsAndBonds._get_free_valence(atom) > 0:
rw_mol = rdkit.Chem.RWMol(starting_mol)
rw_mol.GetAtomWithIdx(i).SetNumRadicalElectrons(1)

Chem.SanitizeMol(rw_mol)
smiles = Chem.MolToSmiles(rw_mol)
if smiles not in generated_smiles:
# This makes sure the atom ordering is standardized
yield Chem.MolFromSmiles(smiles)
generated_smiles.add(smiles)
Loading

0 comments on commit 9d4fcff

Please sign in to comment.