diff --git a/gmso/abc/abstract_site.py b/gmso/abc/abstract_site.py index 86b3ae944..399a2b6dc 100644 --- a/gmso/abc/abstract_site.py +++ b/gmso/abc/abstract_site.py @@ -1,14 +1,13 @@ """Basic interaction site in GMSO that all other sites will derive from.""" import warnings -from typing import Any, ClassVar, NamedTuple, Optional, Sequence, TypeVar, Union +from typing import Any, ClassVar, Optional, Sequence, TypeVar, Union import numpy as np import unyt as u from pydantic import ( ConfigDict, Field, - StrictInt, StrictStr, field_serializer, field_validator, @@ -20,8 +19,124 @@ from gmso.exceptions import GMSOError PositionType = Union[Sequence[float], np.ndarray, u.unyt_array] -MoleculeType = NamedTuple("Molecule", name=StrictStr, number=StrictInt) -ResidueType = NamedTuple("Residue", name=StrictStr, number=StrictInt) + + +class Molecule(GMSOBase): + def __repr__(self): + return ( + f"Molecule(name={self.name}, residue={self.residue}, isrigid={self.isrigid}" + ) + + __iterable_attributes__: ClassVar[set] = { + "name", + "number", + "isrigid", + } + + __base_doc__: ClassVar[str] = "Molecule label for interaction sites." + + name_: str = Field( + "", + validate_default=True, + description="Name of the molecule", + alias="name", + ) + number_: int = Field( + 0, + description="The index/number of the molecule", + alias="number", + ) + isrigid_: bool = Field( + False, + description="Indicate whether the molecule is rigid", + ) + model_config = ConfigDict( + alias_to_fields={ + "name": "name_", + "number": "number_", + "isrigid": "isrigid_", + } + ) + + @property + def name(self) -> str: + """Return the name of the molecule.""" + return self.__dict__.get("name_") + + @property + def number(self) -> int: + """Return the index/number of the moleucle.""" + return self.__dict__.get("number_") + + @property + def isrigid(self) -> bool: + """Return the rigid label of the molecule.""" + return self.__dict__.get("isrigid_") + + def __hash__(self): + return hash(tuple([(name, val) for name, val in self.__dict__.items()])) + + def __eq__(self, other): + """Test if two objects are equivalent.""" + if isinstance(other, (list, tuple)): + return all( + [val1 == val2 for val1, val2 in zip(self.__dict__.values(), other)] + ) + else: + return self.__dict__ == other.__dict__ + + +class Residue(GMSOBase): + def __repr__(self): + return f"Residue(name={self.name}, residue={self.residue}" + + __iterable_attributes__: ClassVar[set] = { + "name", + "number", + } + + __base_doc__: ClassVar[str] = "Residue label for interaction sites." + + name_: str = Field( + "", + validate_default=True, + description="Name of the residue", + alias="name", + ) + number_: int = Field( + 0, + description="The index/number of the residue", + alias="number", + ) + model_config = ConfigDict( + alias_to_fields={ + "name": "name_", + "number": "number_", + } + ) + + @property + def name(self) -> str: + """Return the name of the residue.""" + return self.__dict__.get("name_") + + @property + def number(self) -> int: + """Return the index/number of the residue.""" + return self.__dict__.get("number_") + + def __hash__(self): + return hash(tuple([(name, val) for name, val in self.__dict__.items()])) + + def __eq__(self, other): + """Test if two objects are equivalent.""" + if isinstance(other, (list, tuple)): + return all( + [val1 == val2 for val1, val2 in zip(self.__dict__.values(), other)] + ) + else: + return self.__dict__ == other.__dict__ + SiteT = TypeVar("SiteT", bound="Site") @@ -76,13 +191,13 @@ class Site(GMSOBase): alias="group", ) - molecule_: Optional[MoleculeType] = Field( + molecule_: Optional[Union[Molecule, list, tuple]] = Field( None, description="Molecule label for the site, format of (molecule_name, molecule_number)", alias="molecule", ) - residue_: Optional[ResidueType] = Field( + residue_: Optional[Union[Residue, list, tuple]] = Field( None, description="Residue label for the site, format of (residue_name, residue_number)", alias="residue", @@ -126,7 +241,7 @@ def group(self) -> str: return self.__dict__.get("group_") @property - def molecule(self) -> tuple: + def molecule(self): """Return the molecule of the site.""" return self.__dict__.get("molecule_") @@ -185,12 +300,28 @@ def is_valid_position(cls, position): return position @field_validator("name_") - def inject_name(cls, value): + def parse_name(cls, value): if value == "" or value is None: return cls.__name__ else: return value + @field_validator("residue_") + def parse_residue(cls, value): + if isinstance(value, (tuple, list)): + assert len(value) == 2 + value = Residue(name=value[0], number=value[1]) + return value + + @field_validator("molecule_") + def parse_molecule(cls, value): + if isinstance(value, (tuple, list)): + if len(value) == 2: + value = Molecule(name=value[0], number=value[1]) + elif len(value) == 3: + value = Molecule(name=value[0], number=value[1], isrigid=value[2]) + return value + @classmethod def __new__(cls, *args: Any, **kwargs: Any) -> SiteT: if cls is Site: diff --git a/gmso/core/element.py b/gmso/core/element.py index d0e93ada5..d3ab8dd84 100644 --- a/gmso/core/element.py +++ b/gmso/core/element.py @@ -254,6 +254,8 @@ def element_by_smarts_string(smarts_string, verbose=False): GMSOError If no matching element is found for the provided smarts string """ + from lark import UnexpectedCharacters + from gmso.utils.io import import_ foyer = import_("foyer") @@ -261,7 +263,13 @@ def element_by_smarts_string(smarts_string, verbose=False): PARSER = SMARTS() - symbols = PARSER.parse(smarts_string).iter_subtrees_topdown() + try: + symbols = PARSER.parse(smarts_string).iter_subtrees_topdown() + except UnexpectedCharacters: + raise GMSOError( + f"Failed to find an element from SMARTS string {smarts_string}. " + f"The SMARTS string contained unexpected characters." + ) first_symbol = None for symbol in symbols: diff --git a/gmso/core/topology.py b/gmso/core/topology.py index 856adaa7f..54580f7c7 100644 --- a/gmso/core/topology.py +++ b/gmso/core/topology.py @@ -2,6 +2,7 @@ import itertools import warnings +from copy import copy from pathlib import Path import numpy as np @@ -9,7 +10,7 @@ from boltons.setutils import IndexedSet import gmso -from gmso.abc.abstract_site import Site +from gmso.abc.abstract_site import Molecule, Residue, Site from gmso.abc.serialization_utils import unyt_to_dict from gmso.core.angle import Angle from gmso.core.angle_type import AngleType @@ -313,7 +314,7 @@ def unique_site_labels(self, label_type="molecule", name_only=False): unique_tags.add(label.name if label else None) else: for site in self.sites: - unique_tags.add(getattr(site, label_type)) + unique_tags.add(copy(getattr(site, label_type))) return unique_tags @property @@ -642,6 +643,18 @@ def get_scaling_factors(self, *, molecule_id=None): ] ) + def set_rigid(self, molecule): + """Set molecule tags to rigid if they match the name or number specified. + + Parameters + ---------- + molecule : str, Molecule, or tuple of 2 + Specified the molecule name and number to be set rigid. + If only string is provided, make all molecule of that name rigid. + """ + for site in self.iter_sites(key="molecule", value=molecule): + site.molecule.isrigid = True + def remove_site(self, site): """Remove a site from the topology. @@ -1382,9 +1395,30 @@ def iter_sites(self, key, value): for site in self._sites: if getattr(site, key) and getattr(site, key).name == value: yield site - for site in self._sites: - if getattr(site, key) == value: - yield site + elif isinstance(value, (tuple, list)): + containers_dict = {"molecule": Molecule, "residue": Residue} + if len(value) == 2: + tmp = containers_dict[key](name=value[0], number=value[1]) + elif len(value) == 3: + tmp = containers_dict[key]( + name=value[0], number=value[1], isrigid=value[2] + ) + else: + raise ValueError( + f""" + Argument value was passed as {value}, + but should be an indexible iterable of + [name, number, isrigid] where name is type string, + number is type int, and isrigid is type bool. + """ + ) + for site in self._sites: + if getattr(site, key) and getattr(site, key) == tmp: + yield site + else: + for site in self._sites: + if getattr(site, key) == value: + yield site def iter_sites_by_residue(self, residue_tag): """Iterate through this topology's sites which contain this specific residue name. diff --git a/gmso/external/convert_mbuild.py b/gmso/external/convert_mbuild.py index 4dfa6f03e..5ac724ce6 100644 --- a/gmso/external/convert_mbuild.py +++ b/gmso/external/convert_mbuild.py @@ -8,6 +8,7 @@ from boltons.setutils import IndexedSet from unyt import Unit +from gmso.abc.abstract_site import Residue from gmso.core.atom import Atom from gmso.core.bond import Bond from gmso.core.box import Box @@ -179,12 +180,14 @@ def to_mbuild(topology, infer_hierarchy=True): particle = _parse_particle(particle_map, site) # Try to add the particle to a residue level residue_tag = ( - site.residue if site.residue else ("DefaultResidue", 0) + site.residue + if site.residue + else Residue(name="DefaultResidue", number=0) ) # the 0 idx is placeholder and does nothing if residue_tag in residue_dict: residue_dict_particles[residue_tag] += [particle] else: - residue_dict[residue_tag] = mb.Compound(name=residue_tag[0]) + residue_dict[residue_tag] = mb.Compound(name=residue_tag.name) residue_dict_particles[residue_tag] = [particle] for key, item in residue_dict.items(): diff --git a/gmso/formats/lammpsdata.py b/gmso/formats/lammpsdata.py index 7eabc6bea..953300bb2 100644 --- a/gmso/formats/lammpsdata.py +++ b/gmso/formats/lammpsdata.py @@ -14,7 +14,7 @@ from unyt.array import allclose_units import gmso -from gmso.abc.abstract_site import MoleculeType +from gmso.abc.abstract_site import Molecule from gmso.core.angle import Angle from gmso.core.atom import Atom from gmso.core.atom_type import AtomType @@ -488,7 +488,9 @@ def _get_atoms(filename, topology, base_unyts, type_list): charge=charge, position=coord, atom_type=copy.deepcopy(type_list[int(atom_type) - 1]), # 0-index - molecule=MoleculeType(atom_line[1], int(atom_line[1]) - 1), # 0-index + molecule=Molecule( + name=atom_line[1], number=int(atom_line[1]) - 1 + ), # 0-index ) element = element_by_mass(site.atom_type.mass.value) site.name = element.name if element else site.atom_type.name diff --git a/gmso/formats/mol2.py b/gmso/formats/mol2.py index 61bb2876d..91e3b5ffd 100644 --- a/gmso/formats/mol2.py +++ b/gmso/formats/mol2.py @@ -7,7 +7,7 @@ import unyt as u from gmso import Atom, Bond, Box, Topology -from gmso.abc.abstract_site import MoleculeType, ResidueType +from gmso.abc.abstract_site import Molecule, Residue from gmso.core.element import element_by_name, element_by_symbol from gmso.formats.formats_registry import loads_as @@ -151,8 +151,8 @@ def parse_ele(*symbols): position=position.to("nm"), element=element, charge=charge, - residue=ResidueType(content[7], int(content[6])), - molecule=MoleculeType(molecule, 0), + residue=Residue(name=content[7], number=int(content[6])), + molecule=Molecule(name=molecule, number=0), ) top.add_site(atom) diff --git a/gmso/formats/top.py b/gmso/formats/top.py index 293244cff..bfe4644f5 100644 --- a/gmso/formats/top.py +++ b/gmso/formats/top.py @@ -3,6 +3,7 @@ import datetime import warnings +import numpy as np import unyt as u from gmso.core.dihedral import Dihedral @@ -159,12 +160,13 @@ def write_top(top, filename, top_vars=None): site.atom_type.name, str(site.molecule.number + 1 if site.molecule else 1), tag, - site.atom_type.tags.get("element", site.element.symbol), + site.atom_type.tags.get("element", site.name), "1", # TODO: care about charge groups site.atom_type.charge.in_units(u.elementary_charge).value, site.atom_type.mass.in_units(u.amu).value, ) ) + if unique_molecules[tag]["position_restraints"]: out_file.write(headers["position_restraints"]) for site in unique_molecules[tag]["position_restraints"]: @@ -174,6 +176,50 @@ def write_top(top, filename, top_vars=None): ) ) + # Special treatment for water, may ned to consider a better way to tag rigid water + # Built using this https://github.com/gromacs/gromacs/blob/main/share/top/oplsaa.ff/spce.itp as reference + if "water" in tag.lower() and all( + site.molecule.isrigid for site in unique_molecules[tag]["sites"] + ): + sites_list = unique_molecules[tag]["sites"] + + water_sites = { + "O": [site for site in sites_list if site.element.symbol == "O"], + "H": [site for site in sites_list if site.element.symbol == "H"], + } + + ow_idx = shifted_idx_map[top.get_index(water_sites["O"][0])] + 1 + doh = np.linalg.norm( + water_sites["O"][0].position.to(u.nm) + - water_sites["H"][0].position.to(u.nm) + ).to_value("nm") + dhh = np.linalg.norm( + water_sites["H"][0].position.to(u.nm) + - water_sites["H"][1].position.to(u.nm) + ).to_value("nm") + + # Write settles + out_file.write( + "\n[ settles ] ;Water specific constraint algorithm\n" + "; OW_idx\tfunct\tdoh\tdhh\n" + ) + out_file.write( + "{0:4s}{1:4s}{2:15.5f}{3:15.5f}\n".format( + str(ow_idx), "1", doh, dhh + ) + ) + + # Write exclusion + out_file.write( + "\n[ exclusions ] ;Exclude all interactions between water's atoms\n" + "1\t2\t3\n" + "2\t1\t3\n" + "3\t1\t2\n" + ) + + # Break out of the loop, skipping connection info + continue + for conn_group in [ "pairs", "bonds", @@ -240,7 +286,7 @@ def write_top(top, filename, top_vars=None): ) if conn_group == "dihedral_restraints": warnings.warn( - "The diehdral_restraints writer is designed to work with" + "The dihedral_restraints writer is designed to work with" "`define = DDIHRES` clause in the GROMACS input file (.mdp)" ) out_file.write("#endif DIHRES\n") diff --git a/gmso/tests/base_test.py b/gmso/tests/base_test.py index e8f771827..55c7b080b 100644 --- a/gmso/tests/base_test.py +++ b/gmso/tests/base_test.py @@ -182,6 +182,9 @@ def spce_water(self): def water_system(self): water = Topology(name="water") water = water.load(get_path("tip3p.mol2")) + for site in water.sites: + site.molecule.name = "WaterTIP3P" + return water @pytest.fixture @@ -277,11 +280,25 @@ def typed_water_system(self, water_system): return top @pytest.fixture - def typed_tip3p_rigid_system(self, water_system): + def typed_tip3p_system(self, water_system): top = water_system top.identify_connections() ff = ForceField(get_path("tip3p-rigid.xml")) top = apply(top, ff) + + return top + + @pytest.fixture + def typed_tip3p_rigid_system(self, water_system): + top = water_system + top.identify_connections() + ff = ForceField(get_path("tip3p.xml")) + top = apply(top, ff) + + molecules = top.unique_site_labels(name_only=False) + for molecule in molecules: + top.set_rigid(molecule) + return top @pytest.fixture diff --git a/gmso/tests/files/settles_ref.top b/gmso/tests/files/settles_ref.top new file mode 100644 index 000000000..9bed22355 --- /dev/null +++ b/gmso/tests/files/settles_ref.top @@ -0,0 +1,37 @@ +; File tip3p written by GMSO at 2024-06-16 22:44:09.646856 + +[ defaults ] +; nbfunc comb-rule gen-pairs fudgeLJ fudgeQQ +1 2 yes 0.5 0.5 + +[ atomtypes ] +; name at.num mass charge ptype sigma epsilon +opls_111 8 16.00000 -0.83400 A 0.31506 0.63639 +opls_112 1 1.01100 0.41700 A 1.00000 0.00000 + +[ moleculetype ] +; name nrexcl +WaterTIP3P 3 + +[ atoms ] +; nr type resnr residue atom cgnr charge mass +1 opls_111 1 WaterTIP3P O 1 -0.83400 16.00000 +2 opls_112 1 WaterTIP3P H 1 0.41700 1.01100 +3 opls_112 1 WaterTIP3P H 1 0.41700 1.01100 + +[ settles ] ;Water specific constraint algorithm +; OW_idx funct doh dhh +1 1 0.08925 0.15133 + +[ exclusions ] ;Exclude all interactions between water's atoms +1 2 3 +2 1 3 +3 1 2 + +[ system ] +; name +tip3p + +[ molecules ] +; molecule nmols +WaterTIP3P 1 diff --git a/gmso/tests/files/typed_water_system_ref.top b/gmso/tests/files/typed_water_system_ref.top index f08e3d6c0..4b242e410 100644 --- a/gmso/tests/files/typed_water_system_ref.top +++ b/gmso/tests/files/typed_water_system_ref.top @@ -1,4 +1,4 @@ -; File Topology written by GMSO at 2023-04-21 15:15:49.414556 +; File tip3p written by GMSO at 2024-06-16 23:05:45.808217 [ defaults ] ; nbfunc comb-rule gen-pairs fudgeLJ fudgeQQ @@ -11,13 +11,13 @@ opls_112 1 1.01100 0.41700 A 1.00000 0.00000 [ moleculetype ] ; name nrexcl -RES 3 +WaterTIP3P 3 [ atoms ] ; nr type resnr residue atom cgnr charge mass -1 opls_111 1 RES O 1 -0.83400 16.00000 -2 opls_112 1 RES H 1 0.41700 1.01100 -3 opls_112 1 RES H 1 0.41700 1.01100 +1 opls_111 1 WaterTIP3P O 1 -0.83400 16.00000 +2 opls_112 1 WaterTIP3P H 1 0.41700 1.01100 +3 opls_112 1 WaterTIP3P H 1 0.41700 1.01100 [ bonds ] ; ai aj funct b0 kb @@ -34,4 +34,4 @@ tip3p [ molecules ] ; molecule nmols -RES 1 +WaterTIP3P 1 diff --git a/gmso/tests/test_convert_mbuild.py b/gmso/tests/test_convert_mbuild.py index ef3f0e918..7b7a9c447 100644 --- a/gmso/tests/test_convert_mbuild.py +++ b/gmso/tests/test_convert_mbuild.py @@ -53,7 +53,7 @@ def test_from_mbuild_argon(self, ar_system): assert top.sites[i].name == top.sites[i].element.symbol for site in top.sites: - assert site.molecule[0] == "Ar" + assert site.molecule.name == "Ar" def test_from_mbuild_single_particle(self): compound = mb.Compound() diff --git a/gmso/tests/test_convert_parmed.py b/gmso/tests/test_convert_parmed.py index 3a702554b..eb927e29c 100644 --- a/gmso/tests/test_convert_parmed.py +++ b/gmso/tests/test_convert_parmed.py @@ -249,8 +249,8 @@ def test_residues_info(self, parmed_hexane_box): ) == len(struc.residues) for site in top_from_struc.sites: - assert site.residue[0] == "HEX" - assert site.residue[1] in list(range(0, 6)) + assert site.residue.name == "HEX" + assert site.residue.number in list(range(0, 6)) struc_from_top = to_parmed(top_from_struc) assert len(struc_from_top.residues) == len(struc.residues) diff --git a/gmso/tests/test_mcf.py b/gmso/tests/test_mcf.py index 51426be0b..036d8f3d5 100644 --- a/gmso/tests/test_mcf.py +++ b/gmso/tests/test_mcf.py @@ -319,8 +319,8 @@ def test_typed_ethylene(self): assert np.allclose(ff_coeffs, 2.0 * mcf_coeffs) - def test_fixed_angles(self, typed_tip3p_rigid_system): - top = typed_tip3p_rigid_system + def test_fixed_angles(self, typed_tip3p_system): + top = typed_tip3p_system write_mcf(top, "tip3p-rigid.mcf") mcf_data, mcf_idx = parse_mcf("tip3p-rigid.mcf") diff --git a/gmso/tests/test_mol2.py b/gmso/tests/test_mol2.py index d7b651350..16a4f3fc5 100644 --- a/gmso/tests/test_mol2.py +++ b/gmso/tests/test_mol2.py @@ -63,21 +63,21 @@ def test_read_mol2(self): def test_residue(self): top = Topology.load(get_fn("ethanol_aa.mol2")) - assert np.all([site.residue[0] == "ETO" for site in top.sites]) - assert np.all([site.residue[1] == 1 for site in top.sites]) + assert np.all([site.residue.name == "ETO" for site in top.sites]) + assert np.all([site.residue.number == 1 for site in top.sites]) top = Topology.load(get_fn("benzene_ua.mol2"), site_type="lj") assert np.all( - [site.residue[0] == "BEN1" for site in top.iter_sites_by_residue("BEN1")] + [site.residue.name == "BEN1" for site in top.iter_sites_by_residue("BEN1")] ) assert np.all( - [site.residue[1] == 1 for site in top.iter_sites_by_residue("BEN1")] + [site.residue.number == 1 for site in top.iter_sites_by_residue("BEN1")] ) assert np.all( - [site.residue[0] == "BEN2" for site in top.iter_sites_by_residue("BEN2")] + [site.residue.name == "BEN2" for site in top.iter_sites_by_residue("BEN2")] ) assert np.all( - [site.residue[1] == 2 for site in top.iter_sites_by_residue("BEN2")] + [site.residue.number == 2 for site in top.iter_sites_by_residue("BEN2")] ) def test_lj_system(self): diff --git a/gmso/tests/test_top.py b/gmso/tests/test_top.py index aabcffcb9..2cdd312d2 100644 --- a/gmso/tests/test_top.py +++ b/gmso/tests/test_top.py @@ -149,6 +149,18 @@ def test_custom_defaults(self, typed_ethane): assert struct.defaults.fudgeLJ == 0.5 assert struct.defaults.fudgeQQ == 0.5 + def test_settles(self, typed_tip3p_rigid_system): + typed_tip3p_rigid_system.save("settles.top", overwrite=True) + + with open("settles.top", "r") as f1: + current = f1.readlines() + + with open(get_path("settles_ref.top"), "r") as f2: + ref = f2.readlines() + + for line, ref_line in zip(current[1:], ref[1:]): + assert line, ref_line + def test_benzene_restraints(self, typed_benzene_ua_system): top = typed_benzene_ua_system diff --git a/gmso/utils/sorting.py b/gmso/utils/sorting.py index 6382ee353..93b6d4f52 100644 --- a/gmso/utils/sorting.py +++ b/gmso/utils/sorting.py @@ -253,6 +253,4 @@ def reindex_molecules(top): for site in top.sites: mol_num = site.molecule.number - site.molecule = site.molecule._replace( - number=mol_num - offsetDict[site.molecule.name] - ) + site.molecule.number = mol_num - offsetDict[site.molecule.name]