Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ruff to pre-commit hooks #573

Merged
merged 6 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,26 @@ ci:
skip: []
submodules: false
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.7
hooks:
# Run the linter.
- id: ruff
args: [--line-length=80, --fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
exclude: 'setup.cfg'
- repo: https://github.com/psf/black
rev: 24.8.0
hooks:
- id: black
args: [--line-length=80]
exclude: 'setup.cfg|foyer/tests/files/.*'
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
args: [--profile=black, --line-length=80]
- repo: https://github.com/pycqa/pydocstyle
rev: '6.3.0'
hooks:
- id: pydocstyle
exclude: ^(foyer/tests/|docs/|devtools/|setup.py)
args: [--convention=numpy]
exclude: "foyer/tests/files/.*"
5 changes: 2 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
import pathlib
import sys

import sphinx_rtd_theme

sys.path.insert(0, os.path.abspath("../.."))
sys.path.insert(0, os.path.abspath("sphinxext"))

base_path = pathlib.Path(__file__).parent
os.system("python {} --name".format((base_path / "../../setup.py").resolve()))


import foyer

# -- Project information -----------------------------------------------------

project = "foyer"
Expand Down Expand Up @@ -147,7 +147,6 @@
# a list of builtin themes.
#
# html_theme = 'alabaster'
import sphinx_rtd_theme

html_theme = "sphinx_rtd_theme"
hhtml_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
Expand Down
16 changes: 4 additions & 12 deletions foyer/atomtyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ def find_atomtypes(structure, forcefield, max_iter=10):
topology_graph = TopologyGraph.from_gmso_topology(structure)

if isinstance(forcefield, Forcefield):
atomtype_rules = AtomTypingRulesProvider.from_foyer_forcefield(
forcefield
)
atomtype_rules = AtomTypingRulesProvider.from_foyer_forcefield(forcefield)
elif isinstance(forcefield, AtomTypingRulesProvider):
atomtype_rules = forcefield
else:
Expand Down Expand Up @@ -110,9 +108,7 @@ def find_atomtypes(structure, forcefield, max_iter=10):
atomic_number = atom_data.atomic_number
atomic_symbol = atom_data.element
try:
element_from_num = ele.element_from_atomic_number(
atomic_number
).symbol
element_from_num = ele.element_from_atomic_number(atomic_number).symbol
element_from_sym = ele.element_from_symbol(atomic_symbol).symbol
assert element_from_num == element_from_sym
system_elements.add(element_from_num)
Expand Down Expand Up @@ -210,13 +206,9 @@ def _iterate_rules(rules, topology_graph, typemap, max_iter):

def _resolve_atomtypes(topology_graph, typemap):
"""Determine the final atomtypes from the white- and blacklists."""
atoms = {
atom_idx: data for atom_idx, data in topology_graph.atoms(data=True)
}
atoms = {atom_idx: data for atom_idx, data in topology_graph.atoms(data=True)}
for atom_id, atom in typemap.items():
atomtype = [
rule_name for rule_name in atom["whitelist"] - atom["blacklist"]
]
atomtype = [rule_name for rule_name in atom["whitelist"] - atom["blacklist"]]
if len(atomtype) == 1:
atom["atomtype"] = atomtype[0]
elif len(atomtype) > 1:
Expand Down
118 changes: 29 additions & 89 deletions foyer/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def generate_topology(non_omm_topology, non_element_types=None, residues=None):
return _topology_from_parmed(non_omm_topology, non_element_types)
elif has_mbuild:
mb = import_("mbuild")
if (non_omm_topology, mb.Compound):
if all([non_omm_topology, mb.Compound]):
pmd_comp_struct = non_omm_topology.to_parmed(residues=residues)
return _topology_from_parmed(pmd_comp_struct, non_element_types)
else:
Expand All @@ -162,16 +162,12 @@ def generate_topology(non_omm_topology, non_element_types=None, residues=None):
def _structure_from_residue(residue, parent_structure):
"""Convert a ParmEd Residue to an equivalent Structure."""
structure = pmd.Structure()
orig_to_copy = (
dict()
) # Clone a lot of atoms to avoid any of parmed's tracking
orig_to_copy = dict() # Clone a lot of atoms to avoid any of parmed's tracking
for atom in residue.atoms:
new_atom = copy(atom)
new_atom._idx = atom.idx
orig_to_copy[atom] = new_atom
structure.add_atom(
new_atom, resname=residue.name, resnum=residue.number
)
structure.add_atom(new_atom, resname=residue.name, resnum=residue.number)

for bond in parent_structure.bonds:
if bond.atom1 in residue.atoms and bond.atom2 in residue.atoms:
Expand All @@ -198,10 +194,7 @@ def _topology_from_parmed(structure, non_element_types):
if pmd_atom.name in non_element_types:
element = non_element_types[pmd_atom.name]
else:
if (
isinstance(pmd_atom.atomic_number, int)
and pmd_atom.atomic_number != 0
):
if isinstance(pmd_atom.atomic_number, int) and pmd_atom.atomic_number != 0:
element = elem.Element.getByAtomicNumber(pmd_atom.atomic_number)
else:
element = elem.Element.getBySymbol(pmd_atom.name)
Expand All @@ -221,9 +214,7 @@ def _topology_from_parmed(structure, non_element_types):
topology.addBond(atom1, atom2)
atom1.bond_partners.append(atom2)
atom2.bond_partners.append(atom1)
if structure.box_vectors and np.any(
[x._value for x in structure.box_vectors]
):
if structure.box_vectors and np.any([x._value for x in structure.box_vectors]):
topology.setPeriodicBoxVectors(structure.box_vectors)

positions = structure.positions
Expand Down Expand Up @@ -293,9 +284,7 @@ def _unwrap_typemap(structure, residue_map):
for res_ref, val in residue_map.items():
if id(res.name) == id(res_ref):
for i, atom in enumerate(res.atoms):
master_typemap[int(atom.idx)]["atomtype"] = val[i][
"atomtype"
]
master_typemap[int(atom.idx)]["atomtype"] = val[i]["atomtype"]
return master_typemap


Expand Down Expand Up @@ -325,9 +314,7 @@ def _separate_urey_bradleys(system, topology):
) not in bonds:
ub_force.addBond(*force.getBondParameters(bond_idx))
else:
harmonic_bond_force.addBond(
*force.getBondParameters(bond_idx)
)
harmonic_bond_force.addBond(*force.getBondParameters(bond_idx))
system.removeForce(force_idx)

system.addForce(harmonic_bond_force)
Expand Down Expand Up @@ -499,9 +486,7 @@ class Forcefield(app.ForceField):

"""

def __init__(
self, forcefield_files=None, name=None, validation=True, debug=False
):
def __init__(self, forcefield_files=None, name=None, validation=True, debug=False):
self.atomTypeDefinitions = dict()
self.atomTypeOverrides = dict()
self.atomTypeDesc = dict()
Expand Down Expand Up @@ -539,13 +524,9 @@ def __init__(
if len(preprocessed_files) == 1:
self._version = self._parse_version_number(preprocessed_files[0])
self._name = self._parse_name(preprocessed_files[0])
self._combining_rule = self._parse_combining_rule(
preprocessed_files[0]
)
self._combining_rule = self._parse_combining_rule(preprocessed_files[0])
elif len(preprocessed_files) > 1:
self._version = [
self._parse_version_number(f) for f in preprocessed_files
]
self._version = [self._parse_version_number(f) for f in preprocessed_files]
self._name = [self._parse_name(f) for f in preprocessed_files]
self._combining_rule = [
self._parse_combining_rule(f) for f in preprocessed_files
Expand Down Expand Up @@ -639,9 +620,7 @@ def _parse_name(self, forcefield_file):
try:
return root.attrib["name"]
except KeyError:
warnings.warn(
"No force field name found in force field XML file."
)
warnings.warn("No force field name found in force field XML file.")
return None

def _parse_combining_rule(self, forcefield_file):
Expand All @@ -651,9 +630,7 @@ def _parse_combining_rule(self, forcefield_file):
try:
return root.attrib["combining_rule"]
except KeyError:
warnings.warn(
"No combining rule found in force field XML file."
)
warnings.warn("No combining rule found in force field XML file.")
return "lorentz"

def _create_element(self, element, mass):
Expand All @@ -679,9 +656,7 @@ def registerAtomType(self, parameters):
"""Register a new atom type."""
name = parameters["name"]
if name in self._atomTypes:
raise ValueError(
"Found multiple definitions for atom type: " + name
)
raise ValueError("Found multiple definitions for atom type: " + name)
atom_class = parameters["class"]
mass = _convertParameterToNumber(parameters["mass"])
element = None
Expand Down Expand Up @@ -846,10 +821,7 @@ def run_atomtyping(self, structure, use_residue_map=True, **kwargs):

# Need to call this only once and store results for later id() comparisons
for res_id, res in enumerate(structure.residues):
if (
structure.residues[res_id].name
not in residue_map.keys()
):
if structure.residues[res_id].name not in residue_map.keys():
tmp_res = _structure_from_residue(res, structure)
typemap = find_atomtypes(tmp_res, forcefield=self)
residue_map[res.name] = typemap
Expand Down Expand Up @@ -877,9 +849,7 @@ def parametrize_system(
**kwargs,
):
"""Create system based on resulting typemapping."""
topology, positions = _topology_from_parmed(
structure, self.non_element_types
)
topology, positions = _topology_from_parmed(structure, self.non_element_types)

system = self.createSystem(topology, *args, **kwargs)

Expand Down Expand Up @@ -918,9 +888,7 @@ def parametrize_system(
)

if self.combining_rule == "geometric":
self._patch_parmed_adjusts(
structure, combining_rule=self.combining_rule
)
self._patch_parmed_adjusts(structure, combining_rule=self.combining_rule)

total_charge = sum([atom.charge for atom in structure.atoms])
if not np.allclose(total_charge, 0):
Expand Down Expand Up @@ -1032,9 +1000,7 @@ def createSystem(
elem.hydrogen,
None,
):
transfer_mass = hydrogenMass - sys.getParticleMass(
atom2.index
)
transfer_mass = hydrogenMass - sys.getParticleMass(atom2.index)
sys.setParticleMass(atom2.index, hydrogenMass)
mass = sys.getParticleMass(atom1.index) - transfer_mass
sys.setParticleMass(atom1.index, mass)
Expand Down Expand Up @@ -1091,9 +1057,7 @@ def createSystem(
bonded_to = data.bondedToAtom[atom]
if len(bonded_to) > 2:
for subset in itertools.combinations(bonded_to, 3):
data.impropers.append(
(atom, subset[0], subset[1], subset[2])
)
data.impropers.append((atom, subset[0], subset[1], subset[2]))

# Identify bonds that should be implemented with constraints
if constraints == AllBonds or constraints == HAngles:
Expand Down Expand Up @@ -1188,15 +1152,9 @@ def createSystem(
site.originWeights[1],
site.originWeights[2],
),
mm.Vec3(
site.xWeights[0], site.xWeights[1], site.xWeights[2]
),
mm.Vec3(
site.yWeights[0], site.yWeights[1], site.yWeights[2]
),
mm.Vec3(
site.localPos[0], site.localPos[1], site.localPos[2]
),
mm.Vec3(site.xWeights[0], site.xWeights[1], site.xWeights[2]),
mm.Vec3(site.yWeights[0], site.yWeights[1], site.yWeights[2]),
mm.Vec3(site.localPos[0], site.localPos[1], site.localPos[2]),
)
sys.setVirtualSite(index, local_coord_site)

Expand Down Expand Up @@ -1263,9 +1221,7 @@ def _write_references_to_file(self, atom_types, references_file):
for atomtype, dois in atomtype_references.items():
for doi in dois:
unique_references[doi].append(atomtype)
unique_references = collections.OrderedDict(
sorted(unique_references.items())
)
unique_references = collections.OrderedDict(sorted(unique_references.items()))
with open(references_file, "w") as f:
for doi, atomtypes in unique_references.items():
url = "http://api.crossref.org/works/{}/transform/application/x-bibtex".format(
Expand Down Expand Up @@ -1338,11 +1294,7 @@ def get_parameters(self, group, key, keys_are_atom_classes=False):
if group not in param_extractors:
raise ValueError(f"Cannot extract parameters for {group}")

key = (
[key]
if isinstance(key, str) or not isinstance(key, Iterable)
else key
)
key = [key] if isinstance(key, str) or not isinstance(key, Iterable) else key

validate_type(key, str)

Expand All @@ -1367,18 +1319,14 @@ def _extract_non_bonded_params(self, atom_type):

atom_type = atom_type[0]

non_bonded_forces_gen = self.get_generator(
ff=self, gen_type=NonbondedGenerator
)
non_bonded_forces_gen = self.get_generator(ff=self, gen_type=NonbondedGenerator)

non_bonded_params = non_bonded_forces_gen.params.paramsForType

try:
return non_bonded_params[atom_type]
except KeyError:
raise MissingParametersError(
f"Missing parameters for atom {atom_type}"
)
raise MissingParametersError(f"Missing parameters for atom {atom_type}")

def _extract_harmonic_bond_params(self, atom_types):
"""Return parameters for a specific HarmonicBondForce between atom types."""
Expand Down Expand Up @@ -1548,9 +1496,7 @@ def _extract_rb_proper_params(self, atom_types):
f"be extracted for four atoms. Provided {len(atom_types)}"
)

rb_torsion_force_gen = self.get_generator(
ff=self, gen_type=RBTorsionGenerator
)
rb_torsion_force_gen = self.get_generator(ff=self, gen_type=RBTorsionGenerator)

wildcard = self._atomClasses[""]
(
Expand Down Expand Up @@ -1600,9 +1546,7 @@ def _extract_rb_improper_params(self, atom_types):
f"be extracted for four atoms. Provided {len(atom_types)}"
)

rb_torsion_force_gen = self.get_generator(
ff=self, gen_type=RBTorsionGenerator
)
rb_torsion_force_gen = self.get_generator(ff=self, gen_type=RBTorsionGenerator)

match = self._match_impropers(atom_types, rb_torsion_force_gen)

Expand All @@ -1622,9 +1566,7 @@ def map_atom_classes_to_types(self, atom_classes_keys, strict=False):
# When to do this substitution with wildcards?
substitution = self._atomClasses.get(key)
if not substitution:
raise ValueError(
f"Atom class {key} is missing from the Forcefield"
)
raise ValueError(f"Atom class {key} is missing from the Forcefield")
atom_type_keys.append(next(iter(substitution)))

return atom_type_keys
Expand Down Expand Up @@ -1715,9 +1657,7 @@ def get_generator(ff, gen_type):
@staticmethod
def substitute_wildcards(atom_types, wildcard):
"""Return possible wildcard options."""
return tuple(
atom_type or next(iter(wildcard)) for atom_type in atom_types
)
return tuple(atom_type or next(iter(wildcard)) for atom_type in atom_types)


pmd.Structure.write_foyer = write_foyer
Loading
Loading