diff --git a/.gitignore b/.gitignore index 6eebc143..a71c9644 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,7 @@ rlops_pass .vscode *.sw[op] *.out +examples/code_ocean/redox-models +examples/code_ocean/*.db +examples/code_ocean/policy_checkpoints +examples/code_ocean/builder_cache diff --git a/examples/code_ocean/bde_utils.py b/examples/code_ocean/bde_utils.py new file mode 100644 index 00000000..1451e8a8 --- /dev/null +++ b/examples/code_ocean/bde_utils.py @@ -0,0 +1,57 @@ +import pandas as pd +import rdkit +from rdkit import Chem + + +def prepare_for_bde(mol: rdkit.Chem.Mol) -> pd.Series: + radical_index = None + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetNumRadicalElectrons() != 0: + assert radical_index is None + radical_index = i + + atom.SetNumExplicitHs(atom.GetNumExplicitHs() + 1) + atom.SetNumRadicalElectrons(0) + break + else: + raise RuntimeError(f"No radical found: {Chem.MolToSmiles(mol)}") + + radical_rank = Chem.CanonicalRankAtoms(mol, includeChirality=True)[radical_index] + + mol_smiles = Chem.MolToSmiles(mol) + # TODO this line seems redundant + mol = Chem.MolFromSmiles(mol_smiles) + + radical_index_reordered = list( + Chem.CanonicalRankAtoms(mol, includeChirality=True) + ).index(radical_rank) + + molH = Chem.AddHs(mol) + for bond in molH.GetAtomWithIdx(radical_index_reordered).GetBonds(): + if "H" in {bond.GetBeginAtom().GetSymbol(), bond.GetEndAtom().GetSymbol()}: + bond_index = bond.GetIdx() + break + else: + raise RuntimeError("Bond not found") + + h_bond_indices = [ + bond.GetIdx() + for bond in filter( + lambda bond: ( + (bond.GetEndAtom().GetSymbol() == "H") + | (bond.GetBeginAtom().GetSymbol() == "H") + ), + molH.GetBonds(), + ) + ] + + other_h_bonds = list(set(h_bond_indices) - {bond_index}) + + return pd.Series( + { + "mol_smiles": mol_smiles, + "radical_index_mol": radical_index_reordered, + "bond_index": bond_index, + "other_h_bonds": other_h_bonds, + } + ) diff --git a/examples/code_ocean/config_local.yaml b/examples/code_ocean/config_local.yaml new file mode 100644 index 00000000..4dd70289 --- /dev/null +++ b/examples/code_ocean/config_local.yaml @@ -0,0 +1,97 @@ +# Config file for stable_radical_opt.py + +run_id: 'local_run' + +# Parameters for setting up the problem +problem_config: + initial_state: 'C' + # maximum number of heavy atoms + max_atoms: 10 + # minimum number of heavy atoms + min_atoms: 4 + # atoms to use when building the molecule + atom_additions: [ 'C', 'N', 'O', 'S' ] + # if set, don't construct molecules greater than a given Synthetic Accessibility (SA) score + # see: https://github.com/rdkit/rdkit/blob/master/Contrib/SA_Score/sascorer.py + sa_score_threshold: 4.0 + # whether to consider stereoisomers different molecules + stereoisomers: True + # try to get a 3D embedding of the molecule, and if this fails, remove it. + try_embedding: False + canonicalize_tautomers: False # this is CPU intensive + cache_dir: 'builder_cache/' + num_shards: 1 + parallel: False + redox_model: 'redox-models/models/redox_model' + stability_model: 'redox-models/models/stability_model' + +# Parameters for training the policy model +train_config: + # Reward options: + # if the reward for a given game is > the previous + # *ranked_reward_alpha* fraction of games (e.g., 75% of games), + # then it's a win. Otherwise, it's a loss. + ranked_reward_alpha: 0.75 + # max/min number of games to consider + reward_buffer_max_size: 100 + reward_buffer_min_size: 25 + + # Learning options: + # some useful tips for selecting these parameter values: + # https://stackoverflow.com/a/49924566/7483950 + # learning rate + lr: 1E-3 + # number times that the learning algorithm will work through the entire training dataset (PSJ -- we actually never want training to stop) + epochs: 1E6 + # number of batch iterations before a training epoch is considered finished + steps_per_epoch: 100 + # number of seconds to wait to check if enough games have been played + game_count_delay: 20 + verbose: 2 + + # AlphaZero problem options: + # max/min number of games to consider (ordered by time) when training the policy + max_buffer_size: 256 + min_buffer_size: 32 + # number of training examples to use when updating model parameters + batch_size: 32 + # folder in which to store the trained models + policy_checkpoint_dir: 'policy_checkpoints' + + # MoleculeTFAlphaZeroProblem options: + # size of network hidden layers + features: 16 + # number of global state attention heads. Must be a factor of `features` + num_heads: 1 + # number of message passing layers + num_messages: 1 + +# Parameters for running the Monte Carlo Tree Search games +mcts_config: + # Minimum reward to return for invalid actions + min_reward: 0 + pbc_c_base: 1.0 + pbc_c_init: 1.25 + # dirichlet 'shape' parameter. Larger values spread out probability over more moves. + dirichlet_alpha: 1.0 + # percentage to favor dirichlet noise vs. prior estimation. Smaller means less noise + dirichlet_x: 0.5 + # number of samples to perform at each level of the MCTS search + num_mcts_samples: 100 + # Maximum number of seconds per MCTS round + timeout: 30 + # the maximum search depth + max_depth: 1000000 + #ucb_constant: math.sqrt(2) + +# Database settings for the Object Relational Model (ORM) +# Used to store games and communicate between the policy model (run on GPUs) and rollout (run on CPUs) +sql_database: + # settings to connect to NREL's yuma database + drivername: "postgresql+psycopg2" + dbname: "rl" + port: "5432" + # This will be overwritten by env variable DB_HOST in rlmolecule/sql/run_config.sh + host: "localhost" + user: "example_user" + passwd: "tmppassword" \ No newline at end of file diff --git a/examples/code_ocean/redox_fragment_data.pz b/examples/code_ocean/redox_fragment_data.pz new file mode 100644 index 00000000..e94469e4 Binary files /dev/null and b/examples/code_ocean/redox_fragment_data.pz differ diff --git a/examples/code_ocean/stable_radical_molecule_state.py b/examples/code_ocean/stable_radical_molecule_state.py new file mode 100644 index 00000000..9ec676d9 --- /dev/null +++ b/examples/code_ocean/stable_radical_molecule_state.py @@ -0,0 +1,130 @@ +import gzip +import logging +import os +import pickle +from typing import Optional, Sequence + +import rdkit +from rdkit import Chem +from rdkit.Chem import FragmentCatalog, Mol +from rlmolecule.molecule.builder.builder import ( + AddNewAtomsAndBonds, + MoleculeBuilder, + MoleculeFilter, +) +from rlmolecule.molecule.molecule_state import MoleculeState +from rlmolecule.tree_search.metrics import collect_metrics + +fcgen = FragmentCatalog.FragCatGenerator() +fpgen = FragmentCatalog.FragFPGenerator() +dir_path = os.path.dirname(os.path.realpath(__file__)) +logger = logging.getLogger(__name__) + + +class AddNewAtomsAndBondsProtectRadical(AddNewAtomsAndBonds): + @staticmethod + def _get_free_valence(atom) -> int: + fv = AddNewAtomsAndBonds._get_free_valence(atom) + return fv - atom.GetNumRadicalElectrons() + + +class MoleculeBuilderWithFingerprint(MoleculeBuilder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.transformation_stack += [FingerprintFilter()] + + +class MoleculeBuilderProtectRadical(MoleculeBuilderWithFingerprint): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.transformation_stack[0] = AddNewAtomsAndBondsProtectRadical( + kwargs["atom_additions"] + ) + + +class FingerprintFilter(MoleculeFilter): + def __init__(self): + super(FingerprintFilter, self).__init__() + with gzip.open(os.path.join(dir_path, "redox_fragment_data.pz")) as f: + data = pickle.load(f) + self.fcat = data["fcat"] + self.valid_fps = set(data["valid_fps"]) + + def get_fingerprint(self, mol): + fcgen.AddFragsFromMol(mol, self.fcat) + fp = fpgen.GetFPForMol(mol, self.fcat) + for i in fp.GetOnBits(): + yield self.fcat.GetEntryDescription(i) + + def filter(self, molecule: rdkit.Chem.Mol) -> bool: + fps = set(self.get_fingerprint(molecule)) + if fps.difference(self.valid_fps) == set(): + return True + else: + return False + + +class StableRadMoleculeState(MoleculeState): + """ + A State implementation which uses simple transformations (such as adding a bond) to + define a graph of molecules that can be navigated. + + Molecules are stored as rdkit Mol instances, and the rdkit-generated SMILES string + is also stored for efficient hashing. + """ + + def __init__( + self, + molecule: Mol, + builder: any, + force_terminal: bool = False, + smiles: Optional[str] = None, + ) -> None: + super(StableRadMoleculeState, self).__init__( + molecule, builder, force_terminal, smiles + ) + + @collect_metrics + def get_next_actions(self) -> Sequence["StableRadMoleculeState"]: + + logger.debug(f"Getting next actions for {self}") + + result = [] + if not self._forced_terminal: + if self.num_atoms < self.builder.max_atoms: + result.extend( + ( + StableRadMoleculeState(molecule, self.builder) + for molecule in self.builder(self.molecule) + ) + ) + + if self.num_atoms >= self.builder.min_atoms: + result.extend( + ( + StableRadMoleculeState( + radical, self.builder, force_terminal=True + ) + for radical in build_radicals(self.molecule) + ) + ) + + return result + + +def build_radicals(starting_mol): + """Build organic radicals. """ + + generated_smiles = set() + + for i, atom in enumerate(starting_mol.GetAtoms()): + if AddNewAtomsAndBonds._get_free_valence(atom) > 0: + rw_mol = rdkit.Chem.RWMol(starting_mol) + rw_mol.GetAtomWithIdx(i).SetNumRadicalElectrons(1) + + Chem.SanitizeMol(rw_mol) + smiles = Chem.MolToSmiles(rw_mol) + if smiles not in generated_smiles: + # This makes sure the atom ordering is standardized + yield Chem.MolFromSmiles(smiles) + generated_smiles.add(smiles) diff --git a/examples/code_ocean/stable_radical_opt.py b/examples/code_ocean/stable_radical_opt.py new file mode 100644 index 00000000..061283a8 --- /dev/null +++ b/examples/code_ocean/stable_radical_opt.py @@ -0,0 +1,158 @@ +import argparse +import logging +import math +import multiprocessing +import os +import time + +from rlmolecule.sql.run_config import RunConfig + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def run_games(run_config: RunConfig, **kwargs) -> None: + from rlmolecule.alphazero.alphazero import AlphaZero + + from stable_radical_problem import construct_problem + + logger.info("starting run_games script") + + config = run_config.mcts_config + game = AlphaZero( + construct_problem(run_config, **kwargs), + min_reward=config.get("min_reward", 0.0), + pb_c_base=config.get("pb_c_base", 1.0), + pb_c_init=config.get("pb_c_init", 1.25), + dirichlet_noise=config.get("dirichlet_noise", True), + dirichlet_alpha=config.get("dirichlet_alpha", 1.0), + dirichlet_x=config.get("dirichlet_x", 0.25), + # MCTS parameters + ucb_constant=config.get("ucb_constant", math.sqrt(2)), + ) + while True: + path, reward = game.run( + num_mcts_samples=config.get("num_mcts_samples", 50), + timeout=config.get("timeout", None), + max_depth=config.get("max_depth", 1000000), + ) + logger.info( + f"Game Finished -- Reward {reward.raw_reward:.3f} -- Final state {path[-1][0]}" + ) + + +def train_model(run_config: RunConfig, **kwargs) -> None: + from stable_radical_problem import construct_problem + + logger.info("starting train_model script") + + config = run_config.train_config + construct_problem(run_config, **kwargs).train_policy_model( + steps_per_epoch=config.get("steps_per_epoch", 100), + lr=float(config.get("lr", 1e-3)), + epochs=int(float(config.get("epochs", 1e4))), + game_count_delay=config.get("game_count_delay", 20), + verbose=config.get("verbose", 2), + ) + + +def monitor(run_config: RunConfig, **kwargs): + from rlmolecule.sql.tables import GameStore, RewardStore + + from stable_radical_problem import construct_problem + + logger.info("starting monitor script") + problem = construct_problem(run_config, **kwargs) + + while True: + best_reward = ( + problem.session.query(RewardStore) + .filter_by(run_id=problem.run_id) + .order_by(RewardStore.reward.desc()) + .first() + ) + + num_games = ( + problem.session.query(GameStore).filter_by(run_id=problem.run_id).count() + ) + + if best_reward: + logger.info( + f"Best Reward: {best_reward.reward:.3f} for molecule " + f"{best_reward.data['smiles']} with {num_games} games played" + ) + + else: + logger.debug("Monitor script looping, no reward found") + + time.sleep(5) + + +def setup_argparser(): + parser = argparse.ArgumentParser( + description="Optimize stable radicals to work as both the anode" + " and cathode of a redox-flow battery." + ) + + parser.add_argument("--config", type=str, help="Configuration file") + parser.add_argument( + "--train-policy", + action="store_true", + default=False, + help="Train the policy model only (on GPUs)", + ) + parser.add_argument( + "--rollout", + action="store_true", + default=False, + help="Run the game simulations only (on CPUs)", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_argparser() + args = parser.parse_args() + kwargs = vars(args) + + run_config = RunConfig(args.config) + + if args.train_policy: + train_model(run_config, **kwargs) + elif args.rollout: + # make sure the rollouts do not use the GPU + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + run_games(run_config, **kwargs) + # else: + # logger.warning("Must specify either --train-policy or --rollout") + + else: + + # run_games(run_config) + # train_model(run_config) + + jobs = [ + multiprocessing.Process(target=monitor, args=(run_config,), kwargs=kwargs) + ] + jobs[0].start() + time.sleep(1) + + for i in range(34): + jobs += [ + multiprocessing.Process( + target=run_games, args=(run_config,), kwargs=kwargs + ) + ] + + jobs += [ + multiprocessing.Process( + target=train_model, args=(run_config,), kwargs=kwargs + ) + ] + + for job in jobs[1:]: + job.start() + + for job in jobs: + job.join(300) diff --git a/examples/code_ocean/stable_radical_problem.py b/examples/code_ocean/stable_radical_problem.py new file mode 100644 index 00000000..8da613c0 --- /dev/null +++ b/examples/code_ocean/stable_radical_problem.py @@ -0,0 +1,247 @@ +import os +import pathlib +from typing import Dict, Tuple + +import numpy as np +import rdkit +import tensorflow as tf +from alfabet import model as alfabet_model +from rdkit.Chem.rdDistGeom import EmbedMolecule +from rlmolecule.molecule.builder.builder import MoleculeBuilder +from rlmolecule.molecule.molecule_problem import MoleculeTFAlphaZeroProblem +from rlmolecule.molecule.molecule_state import MoleculeState +from rlmolecule.sql.run_config import RunConfig +from rlmolecule.tree_search.metrics import collect_metrics +from rlmolecule.tree_search.reward import RankedRewardFactory + +from bde_utils import prepare_for_bde +from stable_radical_molecule_state import ( + MoleculeBuilderProtectRadical, + MoleculeBuilderWithFingerprint, + StableRadMoleculeState, +) + + +@tf.function(experimental_relax_shapes=True) +def predict(model: "tf.keras.Model", inputs): + return model.predict_step(inputs) + + +def windowed_loss(target: float, desired_range: Tuple[float, float]) -> float: + """ Returns 0 if the molecule is in the middle of the desired range, + scaled loss otherwise. """ + + span = desired_range[1] - desired_range[0] + + lower_lim = desired_range[0] + span / 6 + upper_lim = desired_range[1] - span / 6 + + if target < lower_lim: + return max(1 - 3 * (abs(target - lower_lim) / span), 0) + elif target > upper_lim: + return max(1 - 3 * (abs(target - upper_lim) / span), 0) + else: + return 1 + + +class StableRadOptProblem(MoleculeTFAlphaZeroProblem): + def __init__( + self, + engine: "sqlalchemy.engine.Engine", + builder: "MoleculeBuilder", + stability_model: "tf.keras.Model", + redox_model: "tf.keras.Model", + initial_state: str, + **kwargs + ) -> None: + """A class to estimate the suitability of radical species in redox flow batteries. + + :param engine: A sqlalchemy engine pointing to a suitable database backend + :param builder: A MoleculeBuilder class to handle molecule construction + :param stability_model: A tensorflow model to estimate spin and buried volumes + :param redox_model: A tensorflow model to estimate electron affinity and ionization energies + :param bde_model: A tensorflow model to estimate bond dissociation energies + :param initial_state: The initial starting state for the molecule search. + """ + self.initial_state = initial_state + self.engine = engine + self._builder = builder + self.stability_model = stability_model + self.redox_model = redox_model + super(StableRadOptProblem, self).__init__(engine, builder, **kwargs) + + def get_initial_state(self) -> MoleculeState: + if self.initial_state == "C": + return StableRadMoleculeState(rdkit.Chem.MolFromSmiles("C"), self._builder) + else: + return MoleculeState( + rdkit.Chem.MolFromSmiles(self.initial_state), self._builder + ) + + def get_reward(self, state: MoleculeState) -> Tuple[float, dict]: + + # Make sure the molecule has a 3D representation + try: + molH = rdkit.Chem.AddHs(state.molecule) + assert EmbedMolecule(molH, maxAttempts=30, randomSeed=42) >= 0 + + except (AssertionError, RuntimeError): + return 0.0, {"forced_terminal": False, "smiles": state.smiles} + + policy_inputs = self.get_policy_inputs(state) + + # Node is outside the domain of validity + if (policy_inputs["atom"] == 1).any() | (policy_inputs["bond"] == 1).any(): + return 0.0, {"forced_terminal": False, "smiles": state.smiles} + + if state.forced_terminal: + reward, stats = self.calc_reward(state) + stats.update({"forced_terminal": True, "smiles": state.smiles}) + return reward, stats + + # Reward called on a non-terminal state, likely built into a corner + return 0.0, {"forced_terminal": False, "smiles": state.smiles} + + @collect_metrics + def calc_reward(self, state: MoleculeState) -> Tuple[float, Dict]: + model_inputs = { + key: tf.constant(np.expand_dims(val, 0)) + for key, val in self.get_policy_inputs(state).items() + } + spins, buried_vol = predict(self.stability_model, model_inputs) + + spins = spins.numpy().flatten() + buried_vol = buried_vol.numpy().flatten() + + atom_index = int(spins.argmax()) + max_spin = spins[atom_index] + spin_buried_vol = buried_vol[atom_index] + + atom_type = state.molecule.GetAtomWithIdx(atom_index).GetSymbol() + + ionization_energy, electron_affinity = ( + predict(self.redox_model, model_inputs).numpy().tolist()[0] + ) + + v_diff = ionization_energy - electron_affinity + bde, bde_diff = self.calc_bde(state) + + ea_range = (-0.5, 0.2) + ie_range = (0.5, 1.2) + v_range = (1, 1.7) + bde_range = (60, 80) + + reward = ( + (1 - max_spin) * 50 + + spin_buried_vol + + 100 + * ( + windowed_loss(electron_affinity, ea_range) + + windowed_loss(ionization_energy, ie_range) + + windowed_loss(v_diff, v_range) + + windowed_loss(bde, bde_range) + ) + / 4 + ) + + stats = { + "max_spin": max_spin, + "spin_buried_vol": spin_buried_vol, + "ionization_energy": ionization_energy, + "electron_affinity": electron_affinity, + "bde": bde, + "bde_diff": bde_diff, + } + stats = {key: str(val) for key, val in stats.items()} + + return reward, stats + + def calc_bde(self, state: MoleculeState): + """calculate the X-H bde, and the difference to the next-weakest X-H bde in kcal/mol""" + + bde_inputs = prepare_for_bde(state.molecule) + pred_bdes = alfabet_model.predict( + [bde_inputs.mol_smiles], drop_duplicates=False + ) + pred_bdes = pred_bdes.set_index("bond_index").bde_pred + + bde_radical = pred_bdes.loc[bde_inputs.bond_index] + + if len(bde_inputs.other_h_bonds) == 0: + bde_diff = 30.0 # Just an arbitrary large number + + else: + other_h_bdes = pred_bdes.loc[bde_inputs.other_h_bonds] + bde_diff = (other_h_bdes - bde_radical).min() + + return bde_radical, bde_diff + + +def construct_problem(run_config: RunConfig, **kwargs): + prob_config = run_config.problem_config + + stability_model = tf.keras.models.load_model( + prob_config.get("stability_model"), compile=False + ) + redox_model = tf.keras.models.load_model( + prob_config.get("redox_model"), compile=False + ) + + initial_state = prob_config.get("initial_state", "C") + if initial_state == "C": + builder_class = MoleculeBuilderWithFingerprint + else: + builder_class = MoleculeBuilderProtectRadical + + if "cache_dir" in prob_config: + cache_dir = os.path.join(prob_config["cache_dir"], run_config.run_id) + else: + cache_dir = None + + builder = builder_class( + max_atoms=prob_config.get("max_atoms", 15), + min_atoms=prob_config.get("min_atoms", 4), + try_embedding=prob_config.get("try_embedding", True), + sa_score_threshold=prob_config.get("sa_score_threshold", 3.5), + stereoisomers=prob_config.get("stereoisomers", True), + canonicalize_tautomers=prob_config.get("canonicalize_tautomers", True), + atom_additions=prob_config.get("atom_additions", ("C", "N", "O", "S")), + cache_dir=cache_dir, + num_shards=prob_config.get("num_shards", 1), + parallel=prob_config.get("parallel", True), + ) + + engine = run_config.start_engine() + + run_id = run_config.run_id + train_config = run_config.train_config + reward_factory = RankedRewardFactory( + engine=engine, + run_id=run_id, + reward_buffer_min_size=train_config.get("reward_buffer_min_size", 50), + reward_buffer_max_size=train_config.get("reward_buffer_max_size", 250), + ranked_reward_alpha=train_config.get("ranked_reward_alpha", 0.75), + ) + + problem = StableRadOptProblem( + engine, + builder, + stability_model, + redox_model, + run_id=run_id, + initial_state=initial_state, + reward_class=reward_factory, + features=train_config.get("features", 64), + # Number of attention heads + num_heads=train_config.get("num_heads", 4), + num_messages=train_config.get("num_messages", 3), + max_buffer_size=train_config.get("max_buffer_size", 200), + # Don't start training the model until this many games have occurred + min_buffer_size=train_config.get("min_buffer_size", 15), + batch_size=train_config.get("batch_size", 32), + policy_checkpoint_dir=os.path.join( + train_config.get("policy_checkpoint_dir", "policy_checkpoints"), run_id + ), + ) + + return problem diff --git a/examples/code_ocean/start_postgres.sh b/examples/code_ocean/start_postgres.sh new file mode 100755 index 00000000..511a40a0 --- /dev/null +++ b/examples/code_ocean/start_postgres.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +initdb -D psql_data.db +pg_ctl -D psql_data.db -l postgres.log -o "-i" start +psql -c "CREATE USER example_user WITH PASSWORD 'tmppassword'" postgres +createdb --owner=example_user rl +exit 0 # In case there are errors with the database calls + + +# stop with pg_ctl -D qed_data_psql.db stop \ No newline at end of file diff --git a/examples/code_ocean/submit_local.sh b/examples/code_ocean/submit_local.sh new file mode 100644 index 00000000..80026954 --- /dev/null +++ b/examples/code_ocean/submit_local.sh @@ -0,0 +1,11 @@ +#!/bin/bash +#SBATCH --partition=debug +#SBATCH --account=rlmolecule +#SBATCH --time=1:00:00 +#SBATCH --job-name=test_stable_rad_opt_local +#SBATCH --nodes=1 +#SBATCH --ntasks=1 + +conda activate rlmol + +python stable_radical_opt.py --config=config_local.yaml \ No newline at end of file diff --git a/rlmolecule/molecule/builder/builder.py b/rlmolecule/molecule/builder/builder.py index b87e937a..898781f3 100644 --- a/rlmolecule/molecule/builder/builder.py +++ b/rlmolecule/molecule/builder/builder.py @@ -4,19 +4,21 @@ from abc import ABC, abstractmethod from functools import partial from multiprocessing import Pool -from typing import Iterable, List, Optional, Dict +from typing import Dict, Iterable, List, Optional import numpy as np import rdkit -from diskcache import FanoutCache, Cache +from diskcache import Cache, FanoutCache from rdkit import Chem, RDConfig -from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions +from rdkit.Chem.EnumerateStereoisomers import ( + EnumerateStereoisomers, + StereoEnumerationOptions, +) from rdkit.Chem.MolStandardize import rdMolStandardize from rdkit.Chem.rdDistGeom import EmbedMolecule - from rlmolecule.molecule.builder.gdb_filters import check_all_filters -sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score')) +sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score")) # noinspection PyUnresolvedReferences import sascorer @@ -30,21 +32,23 @@ from rdkit import RDLogger -RDLogger.DisableLog('rdApp.warning') +RDLogger.DisableLog("rdApp.warning") class MoleculeBuilder: - def __init__(self, - max_atoms: int = 10, - min_atoms: int = 4, - atom_additions: Optional[List] = None, - stereoisomers: bool = False, - canonicalize_tautomers: bool = False, - sa_score_threshold: Optional[float] = None, - try_embedding: bool = False, - cache_dir: Optional[str] = None, - num_shards: int = 1, - parallel: bool = False) -> None: + def __init__( + self, + max_atoms: int = 10, + min_atoms: int = 4, + atom_additions: Optional[List] = None, + stereoisomers: bool = False, + canonicalize_tautomers: bool = False, + sa_score_threshold: Optional[float] = None, + try_embedding: bool = False, + cache_dir: Optional[str] = None, + num_shards: int = 1, + parallel: bool = False, + ) -> None: """A class to build molecules according to a number of different options :param max_atoms: Maximum number of heavy atoms @@ -81,7 +85,9 @@ def __init__(self, AddNewAtomsAndBonds(atom_additions), ] - parallel_stack = [GdbFilter(), ] + parallel_stack = [ + GdbFilter(), + ] if sa_score_threshold is not None: parallel_stack += [SAScoreFilter(sa_score_threshold, min_atoms)] @@ -102,7 +108,8 @@ def __init__(self, if parallel: self.transformation_stack += [ ParallelTransformer(parallel_stack), - UniqueMoleculeFilter()] + UniqueMoleculeFilter(), + ] else: self.transformation_stack += parallel_stack @@ -122,7 +129,7 @@ def __call__(self, parent_molecule: rdkit.Chem.Mol) -> Iterable[rdkit.Chem.Mol]: def __getstate__(self): attributes = self.__dict__ - attributes['cached_call'] = None + attributes["cached_call"] = None return attributes @@ -142,7 +149,9 @@ def __call__(self, inputs: Iterable[rdkit.Chem.Mol]) -> Iterable[rdkit.Chem.Mol] yield from self.call(molecule) -def process_call(molecule: rdkit.Chem.Mol, transformation_stack: List[MoleculeTransformer]) -> List[rdkit.Chem.Mol]: +def process_call( + molecule: rdkit.Chem.Mol, transformation_stack: List[MoleculeTransformer] +) -> List[rdkit.Chem.Mol]: inputs = (molecule,) for transformer in transformation_stack: inputs = transformer(inputs) @@ -150,9 +159,9 @@ def process_call(molecule: rdkit.Chem.Mol, transformation_stack: List[MoleculeTr class ParallelTransformer(BaseTransformer): - def __init__(self, - transformation_stack: List[MoleculeTransformer], - chunk_size: int = 10): + def __init__( + self, transformation_stack: List[MoleculeTransformer], chunk_size: int = 10 + ): self.chunk_size = chunk_size self.transformation_stack = transformation_stack self.pool = Pool() @@ -189,7 +198,7 @@ def __init__(self, atom_additions: Optional[List] = None, **kwargs): if atom_additions is not None: self.atom_additions = atom_additions else: - self.atom_additions = ('C', 'N', 'O') + self.atom_additions = ("C", "N", "O") @staticmethod def sanitize(molecule: rdkit.Chem.Mol) -> Optional[rdkit.Chem.Mol]: @@ -218,24 +227,41 @@ def _get_free_valence(atom) -> int: """ For a given atom, calculate the free valence remaining """ return pt.GetDefaultValence(atom.GetSymbol()) - atom.GetExplicitValence() - def _get_valid_partners(self, starting_mol: rdkit.Chem.Mol, atom: rdkit.Chem.Atom) -> List[int]: + def _get_valid_partners( + self, starting_mol: rdkit.Chem.Mol, atom: rdkit.Chem.Atom + ) -> List[int]: """ For a given atom, return other atoms it can be connected to """ return list( - set(range(starting_mol.GetNumAtoms())) - set((neighbor.GetIdx() for neighbor in atom.GetNeighbors())) - - set(range(atom.GetIdx())) - # Prevent duplicates by only bonding forward - {atom.GetIdx()} | set(np.arange(len(self.atom_additions)) + starting_mol.GetNumAtoms())) - - def _get_valid_bonds(self, starting_mol: rdkit.Chem.Mol, atom1_idx: int, atom2_idx: int) -> range: + set(range(starting_mol.GetNumAtoms())) + - set((neighbor.GetIdx() for neighbor in atom.GetNeighbors())) + - set(range(atom.GetIdx())) + - {atom.GetIdx()} # Prevent duplicates by only bonding forward + | set(np.arange(len(self.atom_additions)) + starting_mol.GetNumAtoms()) + ) + + def _get_valid_bonds( + self, starting_mol: rdkit.Chem.Mol, atom1_idx: int, atom2_idx: int + ) -> range: """ Compare free valences of two atoms to calculate valid bonds """ free_valence_1 = self._get_free_valence(starting_mol.GetAtomWithIdx(atom1_idx)) if atom2_idx < starting_mol.GetNumAtoms(): - free_valence_2 = self._get_free_valence(starting_mol.GetAtomWithIdx(int(atom2_idx))) + free_valence_2 = self._get_free_valence( + starting_mol.GetAtomWithIdx(int(atom2_idx)) + ) else: - free_valence_2 = pt.GetDefaultValence(self.atom_additions[atom2_idx - starting_mol.GetNumAtoms()]) + free_valence_2 = pt.GetDefaultValence( + self.atom_additions[atom2_idx - starting_mol.GetNumAtoms()] + ) return range(min(min(free_valence_1, free_valence_2), 3)) - def _add_bond(self, starting_mol: rdkit.Chem.Mol, atom1_idx: int, atom2_idx: int, bond_type: int) -> Chem.RWMol: + def _add_bond( + self, + starting_mol: rdkit.Chem.Mol, + atom1_idx: int, + atom2_idx: int, + bond_type: int, + ) -> Chem.RWMol: """ Given two atoms and a bond type, execute the addition using rdkit """ num_atom = starting_mol.GetNumAtoms() rw_mol = Chem.RWMol(starting_mol) @@ -279,9 +305,8 @@ def call(self, molecule: rdkit.Chem.Mol) -> Iterable[rdkit.Chem.Mol]: smiles_out = rdkit.Chem.MolToSmiles(out) stereo_count = count_stereocenters(smiles_out) - if stereo_count['atom_unassigned'] != 0: - print(f'{smiles_in}: {smiles_out}') - # if stereo_count['bond_unassigned'] == 0, f'{smiles_in}: {smiles_out}' + if stereo_count["atom_unassigned"] != 0: + logger.debug(f"unassigned stereo in output {smiles_in}: {smiles_out}") yield out @@ -315,7 +340,9 @@ def filter(self, molecule: rdkit.Chem.Mol) -> bool: try: return check_all_filters(molecule) except Exception as ex: - logger.warning(f"Issue with GDBFilter and molecule {Chem.MolToSmiles(molecule)}: {ex}") + logger.warning( + f"Issue with GDBFilter and molecule {Chem.MolToSmiles(molecule)}: {ex}" + ) return False @@ -327,18 +354,34 @@ def count_stereocenters(smiles: str) -> Dict: rdkit.Chem.FindPotentialStereoBonds(mol) stereocenters = rdkit.Chem.FindMolChiralCenters(mol, includeUnassigned=True) - stereobonds = [bond for bond in mol.GetBonds() if bond.GetStereo() is not - rdkit.Chem.rdchem.BondStereo.STEREONONE] - - atom_assigned = len([center for center in stereocenters if center[1] != '?']) - atom_unassigned = len([center for center in stereocenters if center[1] == '?']) + stereobonds = [ + bond + for bond in mol.GetBonds() + if bond.GetStereo() is not rdkit.Chem.rdchem.BondStereo.STEREONONE + ] + + atom_assigned = len([center for center in stereocenters if center[1] != "?"]) + atom_unassigned = len([center for center in stereocenters if center[1] == "?"]) + + bond_assigned = len( + [ + bond + for bond in stereobonds + if bond.GetStereo() is not rdkit.Chem.rdchem.BondStereo.STEREOANY + ] + ) + bond_unassigned = len( + [ + bond + for bond in stereobonds + if bond.GetStereo() is rdkit.Chem.rdchem.BondStereo.STEREOANY + ] + ) - bond_assigned = len([bond for bond in stereobonds if bond.GetStereo() is not - rdkit.Chem.rdchem.BondStereo.STEREOANY]) - bond_unassigned = len([bond for bond in stereobonds if bond.GetStereo() is - rdkit.Chem.rdchem.BondStereo.STEREOANY]) + return { + "atom_assigned": atom_assigned, + "atom_unassigned": atom_unassigned, + "bond_assigned": bond_assigned, + "bond_unassigned": bond_unassigned, + } - return {'atom_assigned': atom_assigned, - 'atom_unassigned': atom_unassigned, - 'bond_assigned': bond_assigned, - 'bond_unassigned': bond_unassigned} diff --git a/rlmolecule/sql/run_config.py b/rlmolecule/sql/run_config.py index d07cd709..ed2824a1 100644 --- a/rlmolecule/sql/run_config.py +++ b/rlmolecule/sql/run_config.py @@ -5,7 +5,7 @@ import yaml from sqlalchemy import create_engine - +from sqlalchemy.pool import NullPool # TODO add the command line args that correspond to the config file options here @@ -15,7 +15,7 @@ def __init__(self, config_file, **kwargs): self.config_map = {} if config_file is not None: - with open(config_file, 'r') as f: + with open(config_file, "r") as f: # self.config_map = yaml.safe_load(f) # expandvars is a neat trick to expand bash variables within the yaml file # from here: https://stackoverflow.com/a/60283894/7483950 @@ -24,13 +24,13 @@ def __init__(self, config_file, **kwargs): # TODO overwrite settings in the config file if they were passed in via kwargs # Settings for setting up scripts to run everything # self.run_config = self.config_map.get('run_config',{}) - self.run_id = self.config_map.get('run_id', 'test') + self.run_id = self.config_map.get("run_id", "test") # Settings specific to the problem at hand - self.problem_config = self.config_map.get('problem_config', {}) + self.problem_config = self.config_map.get("problem_config", {}) # Settings for training the policy model - self.train_config = self.config_map.get('train_config', {}) - self.mcts_config = self.config_map.get('mcts_config', {}) + self.train_config = self.config_map.get("train_config", {}) + self.mcts_config = self.config_map.get("mcts_config", {}) # def load_config_file(config_file): # with open(config_file, 'r') as conf: @@ -39,7 +39,9 @@ def __init__(self, config_file, **kwargs): # return config_map def start_engine(self): - self.engine = RunConfig.start_db_engine(**self.config_map.get('sql_database', {})) + self.engine = RunConfig.start_db_engine( + **self.config_map.get("sql_database", {}) + ) return self.engine @staticmethod @@ -47,43 +49,49 @@ def start_db_engine(**kwargs): """ Connect to the sql database that will store the game and reward data used by the policy model and game runner """ - drivername = kwargs.get('drivername', 'sqlite') - db_file = kwargs.get('db_file', 'game_data.db') - if drivername == 'sqlite': + drivername = kwargs.get("drivername", "sqlite") + db_file = kwargs.get("db_file", "game_data.db") + if drivername == "sqlite": engine = create_engine( - f'sqlite:///{db_file}', + f"sqlite:///{db_file}", # The 'check_same_thread' option only works for sqlite - connect_args={'check_same_thread': False}, - execution_options={"isolation_level": "AUTOCOMMIT"}) + connect_args={"check_same_thread": False}, + execution_options={"isolation_level": "AUTOCOMMIT"}, + poolclass=NullPool, + ) else: engine = RunConfig.start_server_db_engine(**kwargs) return engine @staticmethod - def start_server_db_engine(drivername="postgresql+psycopg2", - dbname='bde', - port=None, - host=None, - user=None, - passwd_file=None, - passwd=None, - **kwargs): + def start_server_db_engine( + drivername="postgresql+psycopg2", + dbname="bde", + port=None, + host=None, + user=None, + passwd_file=None, + passwd=None, + **kwargs, + ): # By default, use the host defined in the environment - host = os.getenv('DB_HOST', host) + host = os.getenv("DB_HOST", host) # add the ':' to separate host from port port = ":" + str(port) if port is not None else "" if passwd_file is not None: # read the password from a file - with open(passwd_file, 'r') as f: + with open(passwd_file, "r") as f: passwd = f.read().strip() # add the ':' to separate user from passwd passwd = ":" + str(passwd) if passwd is not None else "" - engine_str = f'{drivername}://{user}{passwd}@{host}{port}/{dbname}' + engine_str = f"{drivername}://{user}{passwd}@{host}{port}/{dbname}" # don't print since it has the user's password # print(f'connecting to database using: {engine_str}') - engine = create_engine(engine_str, execution_options={"isolation_level": "AUTOCOMMIT"}) + engine = create_engine( + engine_str, execution_options={"isolation_level": "AUTOCOMMIT"} + ) return engine