Skip to content

Commit

Permalink
Merge pull request #24 from Caisusandy/main
Browse files Browse the repository at this point in the history
Support mdtraj 1.10
  • Loading branch information
bojunliu0818 authored Jul 18, 2024
2 parents f3457fe + 6042e35 commit c2efce0
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ jobs:
run: |
# conda install -yq -c ${CONDA_PREFIX}/conda-bld/ msmbuilder2022
pip install pytest==8.0.2
conda install scipy=1.13.1
conda install -yq numdifftools hmmlearn
mkdir ../../pkgs
cp -r msmbuilder/tests ../../pkgs
Expand Down
3 changes: 1 addition & 2 deletions msmbuilder/decomposition/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import hashlib
import itertools
import numpy as np
from six.moves import xrange


def iterate_tracker(maxiter, max_nc, verbose=False):
Expand All @@ -12,7 +11,7 @@ def iterate_tracker(maxiter, max_nc, verbose=False):
last_hash_count = 0
arr = yield

for i in xrange(maxiter):
for i in range(maxiter):
arr = yield i
if arr is not None:
hsh = hashlib.sha1(arr.view(np.uint8)).hexdigest()
Expand Down
6 changes: 3 additions & 3 deletions msmbuilder/featurizer/multichain.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import mdtraj as md
from mdtraj.utils import ensure_type
from mdtraj.utils.six import string_types
import numpy as np
import itertools
import warnings
Expand Down Expand Up @@ -170,14 +169,15 @@ def __init__(self, protein_chain='auto', ligand_chain='auto',

def _get_contact_pairs(self, contacts):
if self.scheme=='ca':
if not any(a for a in self.reference_frame.top.chain(ligand_chain).atoms
# possible error here with "ligand_chain" from no where, change to self.ligand_chain
if not any(a for a in self.reference_frame.top.chain(self.ligand_chain).atoms
if a.name.lower() == 'ca'):
raise ValueError("Bad scheme: the ligand has no alpha carbons")

# this is really similar to mdtraj/contact.py, but ensures that
# md.compute_contacts is always seeing an array of exactly the
# contacts we want to specify
if isinstance(contacts, string_types):
if isinstance(contacts, str):
if contacts.lower() != 'all':
raise ValueError('({}) is not a valid contacts specifier'.format(contacts.lower()))

Expand Down
4 changes: 1 addition & 3 deletions msmbuilder/tpt/committor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from __future__ import print_function, division, absolute_import
import numpy as np

from mdtraj.utils.six.moves import xrange

__all__ = ['committors', 'conditional_committors',
'_committors', '_conditional_committors']

Expand Down Expand Up @@ -190,7 +188,7 @@ def _conditional_committors(source, sink, waypoint, tprob):
# permute the transition matrix into cannonical form - send waypoint the the
# last row, and source + sink to the end after that
Bsink_indices = [source, sink, waypoint]
perm = np.array([i for i in xrange(n_states) if i not in Bsink_indices],
perm = np.array([i for i in range(n_states) if i not in Bsink_indices],
dtype=int)
perm = np.concatenate([perm, Bsink_indices])
permuted_tprob = tprob[perm, :][:, perm]
Expand Down
5 changes: 2 additions & 3 deletions msmbuilder/tpt/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from . import committors, conditional_committors

from mdtraj.utils.six.moves import xrange
import itertools

__all__ = ['fraction_visited', 'hub_scores']
Expand Down Expand Up @@ -115,15 +114,15 @@ def hub_scores(msm, waypoints=None):
if isinstance(waypoints, int):
waypoints = [waypoints]
elif waypoints is None:
waypoints = xrange(n_states)
waypoints = range(n_states)
elif not (isinstance(waypoints, list) or
isinstance(waypoints, np.ndarray)):
raise ValueError("waypoints (%s) must be an int, a list, or None" %
str(waypoints))

hub_scores = []
for waypoint in waypoints:
other_states = (i for i in xrange(n_states) if i != waypoint)
other_states = (i for i in range(n_states) if i != waypoint)

# calculate the hub score for this waypoint
hub_score = 0.0
Expand Down
3 changes: 1 addition & 2 deletions msmbuilder/tpt/mfpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from __future__ import print_function, division, absolute_import
import numpy as np
import scipy
from mdtraj.utils.six.moves import xrange
import copy
from msmbuilder.msm.core import _solve_msm_eigensystem

Expand Down Expand Up @@ -214,7 +213,7 @@ def _mfpts(tprob, populations, sinks, lag_time):

# mfpt[i,j] = (fund_matrix[j,j] - fund_matrix[i,j]) / populations[j]
mfpts = fund_matrix * -1
for j in xrange(n_states):
for j in range(n_states):
mfpts[:, j] += fund_matrix[j, j]
mfpts[:, j] /= populations[j]

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@
},
install_requires=[
'numpy',
'mdtraj==1.9.9',
'mdtraj',
'scikit-learn',
'pandas',
'fastcluster',
Expand All @@ -220,7 +220,7 @@
'tables',
'numpydoc',
'six',
'scipy',
'scipy<=1.13.1',
'pyhmc @ git+https://github.com/bojunliu0818/pyhmc.git@bojunliu0818-dev',
],
entry_points={'console_scripts':
Expand Down

0 comments on commit c2efce0

Please sign in to comment.