Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REST scaling for protein-ligand systems with the CLI #1197

Draft
wants to merge 8 commits into
base: 0.10.x
Choose a base branch
from
15 changes: 13 additions & 2 deletions examples/new-cli/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ forcefield_files:
# e.g. one of ['openff-2.0.0', 'gaff-2.11']
small_molecule_forcefield: openff-2.0.0

# Solvent model
solvent_model: "tip3p"

#
# Simulation conditions
#
Expand All @@ -37,6 +40,7 @@ temperature: 300 # kelvin
timestep: 4 # femtoseconds
ionic_strength: 0.15 # molar


# Atom mapping specification
atom_expression:
- IntType
Expand Down Expand Up @@ -79,5 +83,12 @@ phases:
use_given_geometries: true
given_geometries_tolerance: 0.4 # angstroms

# Solvent model
solvent_model: "tip3p"
# Realtime analysis frequency
offline-freq: 100

# Advanced (optional) -- Specify HTF (HybridTopologyFactory or RESTCapableHybridTopologyFactory)
#hybrid_topology_factory: RESTCapableHybridTopologyFactory
# REST specific parameters (optional)
#max_temperature: 600 # Kelvin
#rest_radius: 0.3 # nanometers
#w_lifting: 0.3 # nanometers
13 changes: 4 additions & 9 deletions perses/annihilation/relative.py
Original file line number Diff line number Diff line change
Expand Up @@ -2833,14 +2833,7 @@ class RESTCapableHybridTopologyFactory(HybridTopologyFactory):
_new_system_exceptions : dict of key: tuple of ints, value: list of floats
key: new system indices of the atoms in the exception, value: chargeProd (units of the proton charge squared), sigma (nm), and epsilon (kJ/mol) for the exception
"""

# Constants copied from: https://github.com/openmm/openmm/blob/master/platforms/reference/include/SimTKOpenMMRealType.h#L89. These will be imported directly once we have addresssed https://github.com/choderalab/openmmtools/issues/522
M_PI = 3.14159265358979323846
E_CHARGE = (1.602176634e-19)
AVOGADRO = (6.02214076e23)
EPSILON0 = (1e-6*8.8541878128e-12/(E_CHARGE*E_CHARGE*AVOGADRO))
ONE_4PI_EPS0 = (1/(4*M_PI*EPSILON0))
#from openmmtools.constants import ONE_4PI_EPS0
from openmmtools.constants import ONE_4PI_EPS0

_default_electrostatics_expression_list = [

Expand Down Expand Up @@ -3220,8 +3213,10 @@ def _generate_rest_region(self):

# Retrieve neighboring atoms within self._rest_radius nm of the query atoms
traj = md.Trajectory(np.array(self._hybrid_positions), self._hybrid_topology)
solute_atoms = list(traj.topology.select("is_protein"))
solute_atoms = list(traj.topology.select("not water"))
Copy link
Contributor

@zhang-ivy zhang-ivy Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is the right way to change this function to work for small molecules. Specifying "not water" will include solvent ions, which we definitely don't want to include as part of the REST region.

We first need to determine what the desired behavior is. The current behavior of this function for small molecules is to automatically include all small molecule atoms in the REST region, and also include protein atoms within X angstroms (where X is user specified) of the small molecule. I think we discussed that you would run some experiments testing this approach, though this will require finding transformations where REST is known to improve sampling / speed up convergence.

One worry I have with this approach is that including all small molecule atoms in the REST region is not the best practices approach (Schrodinger papers typically choose the alchemically-changing (aka unique old and new) atoms and some nearby protein atoms) and there may be risk of small molecule unbinding. Therefore, I think the desired approach should be to modify these lines:

        # Retrieve the residue index of the residue that is being alchemified based on the first unique old atom
        atom_index = list(self._atom_classes['unique_old_atoms'])[0]
        hybrid_topology_atoms = list(self._hybrid_topology.atoms)
        alchemical_residue_index = hybrid_topology_atoms[atom_index].residue.index
        
       # Retrieve indices of all atoms in the alchemical residue
        hybrid_topology_residues = list(self._hybrid_topology.residues)
        mutated_res = hybrid_topology_residues[alchemical_residue_index]
        query_indices = [atom.index for atom in mutated_res.atoms]
        _logger.info(f"Generating rest_region with query indices: {query_indices}")

such that for small molecule transformations only, query_indices is all unique old/new atoms/core atoms. This may requiring introducing an argument to the function like is_small_molecule that tells the function whether its dealing with a small molecule transformation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the things that we noticed is that the solute_atoms after the selection is empty for solvent phase. But we probably want to have the option to do REST scaling in solvent phase as well, such that the neighboring part is not needed, but the molecule atoms are still scaled according to the REST protocol.

Copy link
Contributor

@zhang-ivy zhang-ivy Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notes from discussion:

  • As noted above, there should be three code paths: 1) protein mutation, 2) small molecule complex phase, 3) small molecule solvent phase
  • Double check that "is protein" mdtraj selection only contains protein atoms (and does not interpret small molecule atoms as protein atoms). If indeed "is protein" only corresponds to protein atoms, the way that the function is written will fail to include small molecule atoms in the rest_atoms_all. Will need to make sure the relevant small molecule atoms are included in this.

rest_atoms = list(md.compute_neighbors(traj, self._rest_radius.value_in_unit_system(unit.md_unit_system), query_indices, haystack_indices=solute_atoms)[0])
# rest_atoms = list(
# md.compute_neighbors(traj, self._rest_radius.value_in_unit_system(unit.md_unit_system), query_indices)[0])

# Retrieve full residues for all atoms in rest region
residues = [atom.residue.index for atom in traj.topology.atoms if atom.index in rest_atoms]
Expand Down
46 changes: 25 additions & 21 deletions perses/app/setup_relative_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pickle
import os
import sys
import simtk.unit as unit
from openmm import unit
import logging
import warnings
from cloudpathlib import AnyPath
Expand Down Expand Up @@ -469,26 +469,28 @@ def run_setup(setup_options, serialize_systems=True, build_samplers=True):
else:
measure_shadow_work = False
_logger.info(f"\tno measure_shadow_work specified: defaulting to False.")
if isinstance(setup_options['pressure'], (float, int)):
pressure = setup_options['pressure'] * unit.atmosphere
else:
pressure = setup_options['pressure']
if isinstance(setup_options['temperature'], (float, int)):
temperature = setup_options['temperature'] * unit.kelvin
else:
temperature = setup_options['temperature']
if isinstance(setup_options['solvent_padding'], (float, int)):
solvent_padding_angstroms = setup_options['solvent_padding'] * unit.angstrom
else:
solvent_padding_angstroms = setup_options['solvent_padding']
if isinstance(setup_options['ionic_strength'], (float, int)):
ionic_strength = setup_options['ionic_strength'] * unit.molar
else:
ionic_strength = setup_options['ionic_strength']
# Read simulation options/parameters and assign units if needed
pressure = setup_options['pressure']
temperature = setup_options['temperature']
solvent_padding_angstroms = setup_options['solvent_padding']
ionic_strength = setup_options['ionic_strength']
max_temperature = setup_options.get('max_temperature')
if isinstance(pressure, (float, int)):
pressure *= unit.atmosphere
if isinstance(temperature, (float, int)):
temperature *= unit.kelvin
if isinstance(solvent_padding_angstroms, (float, int)):
solvent_padding_angstroms *= unit.angstrom
if isinstance(ionic_strength, (float, int)):
ionic_strength *= unit.molar
if isinstance(max_temperature, (float, int)):
max_temperature *= unit.kelvin

_logger.info(f"\tsetting pressure: {pressure}.")
_logger.info(f"\tsetting temperature: {temperature}.")
_logger.info(f"\tsetting temperature: {temperature}K.")
_logger.info(f"\tsetting solvent padding: {solvent_padding_angstroms}A.")
_logger.info(f"\tsetting ionic strength: {ionic_strength}M.")
_logger.info(f"\tsetting max temperature: {max_temperature}K.")

setup_pickle_file = setup_options['save_setup_pickle_as'] if 'save_setup_pickle_as' in list(setup_options) else None
_logger.info(f"\tsetup pickle file: {setup_pickle_file}")
Expand Down Expand Up @@ -723,7 +725,7 @@ def run_setup(setup_options, serialize_systems=True, build_samplers=True):
hybrid_factory=htf[phase], online_analysis_interval=setup_options['offline-freq'],
)
hss[phase].setup(n_states=n_states, temperature=temperature, storage_file=reporter,
endstates=endstates)
endstates=endstates, t_max=max_temperature)
# We need to specify contexts AFTER setup
hss[phase].energy_context_cache = energy_context_cache
hss[phase].sampler_context_cache = sampler_context_cache
Expand Down Expand Up @@ -1106,11 +1108,13 @@ def _generate_htf(phase: str, topology_proposal_dictionary: dict, setup_options:
# Add/use specified REST HTF parameters if present
rest_specific_options = dict()
try:
rest_specific_options.update({'rest_radius': setup_options['rest_radius']})
rest_radius = setup_options['rest_radius'] * unit.nanometer
rest_specific_options.update({'rest_radius': rest_radius})
except KeyError:
_logger.info("'rest_radius' not specified. Using default value.")
try:
rest_specific_options.update({'w_lifting': setup_options['w_lifting']})
w_lifting = setup_options['w_lifting'] * unit.nanometer
rest_specific_options.update({'w_lifting': w_lifting})
except KeyError:
_logger.info("'w_lifting' not specified. Using default value.")

Expand Down
42 changes: 39 additions & 3 deletions perses/samplers/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,40 @@ def __init__(self, *args, hybrid_factory=None, **kwargs):
# TODO: Should this overload the create() method from parent instead of being setup()?
def setup(self, n_states, temperature, storage_file, minimisation_steps=100,
n_replicas=None, lambda_schedule=None,
lambda_protocol=None, endstates=True, t_max=300 * unit.kelvin):

lambda_protocol=None, endstates=True, t_max=None):
"""
Set up the simulation with the specified parameters.

Parameters:
-----------
n_states : int
The number of alchemical states to simulate.
temperature : openmm.unit.Quantity
The temperature of the simulation in Kelvin.
storage_file : str
The path to the storage file to store the simulation results.
minimisation_steps : int, optional
The number of minimisation steps to perform before simulation. Default is 100.
n_replicas : int, optional
The number of replicas for replica exchange. If not specified, it will be set to `n_states`.
lambda_schedule : array-like, optional
The schedule of lambda values for the alchemical states. Default is a linear schedule from 0 to 1.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change this to something like "Default is None, in which case a linear schedule from 0 to 1 will be used"

lambda_protocol : object, optional
The lambda protocol object that defines the alchemical transformation protocol. Default is None.
endstates : bool, optional
Whether to generate unsampled endstates. Default is True.
t_max : openmm.unit.Quantity, optional
The maximum temperature for REST scaling. Default is None.

Raises:
-------
ValueError
If the hybrid factory name is not supported.

Returns:
--------
None
"""
from perses.dispersed import feptasks

# Retrieve class name, hybrid system, and hybrid positions
Expand All @@ -50,11 +82,15 @@ def setup(self, n_states, temperature, storage_file, minimisation_steps=100,
lambda_zero_alchemical_state = RESTCapableRelativeAlchemicalState.from_system(hybrid_system)
lambda_protocol = RESTCapableLambdaProtocol() if lambda_protocol is None else lambda_protocol

# Default to current temperature if t_max is not specified (no REST scaling)
if t_max is None:
t_max = temperature

# Set beta_0 and beta_m
beta_0 = 1 / (kB * temperature)
beta_m = 1 / (kB * t_max)
else:
raise Exception(f"{factory_name} not supported")
raise ValueError(f"{factory_name} not supported")

# Create reference compound thermodynamic state
thermostate = ThermodynamicState(hybrid_system, temperature=temperature)
Expand Down