Skip to content

Commit

Permalink
Add tests and change behavior when applying types, update connection.…
Browse files Browse the repository at this point in the history
…connection_members from types (#808)
  • Loading branch information
CalCraven authored Mar 17, 2024
1 parent 69fd5e2 commit d3fca73
Show file tree
Hide file tree
Showing 12 changed files with 242 additions and 66 deletions.
17 changes: 12 additions & 5 deletions gmso/formats/lammpsdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
from gmso.lib.potential_templates import PotentialTemplateLibrary
from gmso.utils.compatibility import check_compatibility
from gmso.utils.conversions import convert_kelvin_to_energy_units
from gmso.utils.sorting import reindex_molecules, sort_by_types
from gmso.utils.sorting import (
reindex_molecules,
sort_by_types,
sort_connection_members,
)
from gmso.utils.units import LAMMPS_UnitSystems, write_out_parameter_and_units


Expand Down Expand Up @@ -1223,11 +1227,14 @@ def _write_conn_data(out_file, top, connStr, sorted_typesList):
]
for index in indexList:
typeStr = f"{i+1:<6d}\t{index+1:<6d}\t"
sorted_membersList = sort_connection_members(
conn, sort_by="index", top=top
)
indexStr = "\t".join(
map(
lambda x: str(top.sites.index(x) + 1).ljust(6),
conn.connection_members,
)
[
str(top.get_index(member) + 1).ljust(6)
for member in sorted_membersList
]
)
out_file.write(typeStr + indexStr + "\n")
i += 1
Expand Down
21 changes: 16 additions & 5 deletions gmso/formats/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,14 @@ def _write_connection(top, connection, potential_name, shifted_idx_map):

def _harmonic_bond_potential_writer(top, bond, shifted_idx_map):
"""Write harmonic bond information."""
eq_connsList = bond.equivalent_members()
indexList = [
tuple(map(lambda x: top.get_index(x), conn)) for conn in eq_connsList
]
sorted_indicesList = sorted(indexList)[0]
line = "{0:8s}{1:8s}{2:4s}{3:15.5f}{4:15.5f}\n".format(
str(shifted_idx_map[top.get_index(bond.connection_members[0])] + 1),
str(shifted_idx_map[top.get_index(bond.connection_members[1])] + 1),
str(shifted_idx_map[sorted_indicesList[0]] + 1),
str(shifted_idx_map[sorted_indicesList[1]] + 1),
"1",
bond.connection_type.parameters["r_eq"].in_units(u.nm).value,
bond.connection_type.parameters["k"]
Expand All @@ -456,10 +461,16 @@ def _harmonic_bond_potential_writer(top, bond, shifted_idx_map):

def _harmonic_angle_potential_writer(top, angle, shifted_idx_map):
"""Write harmonic angle information."""
eq_connsList = angle.equivalent_members()
indexList = [
tuple(map(lambda x: top.get_index(x), conn)) for conn in eq_connsList
]
sorted_indicesList = sorted(indexList)[0]

line = "{0:8s}{1:8s}{2:8s}{3:4s}{4:15.5f}{5:15.5f}\n".format(
str(shifted_idx_map[top.get_index(angle.connection_members[0])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[1])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[2])] + 1),
str(shifted_idx_map[sorted_indicesList[0]] + 1),
str(shifted_idx_map[sorted_indicesList[1]] + 1),
str(shifted_idx_map[sorted_indicesList[2]] + 1),
"1",
angle.connection_type.parameters["theta_eq"].in_units(u.degree).value,
angle.connection_type.parameters["k"]
Expand Down
2 changes: 1 addition & 1 deletion gmso/parameterization/topology_parameterizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _apply_connection_parameters(
matched_order = [
connection.connection_members[i] for i in match[1]
]
# connection.connection_members = matched_order
connection.connection_members = matched_order
if not match[0].member_types:
connection.connection_type.member_types = tuple(
member.atom_type.name for member in matched_order
Expand Down
12 changes: 12 additions & 0 deletions gmso/tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,18 @@ def parmed_benzene(self):
)
return benzene

@pytest.fixture
def benzeneTopology(self):
untyped_benzene = mb.load(get_fn("benzene.mol2"))
top_benzene = untyped_benzene.to_gmso()
ff_improper = ForceField(get_fn("benzeneaa_improper.xml"))
return apply(
top_benzene,
ff_improper,
identify_connections=True,
ignore_params=[],
)

# TODO: now
# add in some fixtures for (connects), amber

Expand Down
6 changes: 3 additions & 3 deletions gmso/tests/files/charmm.lammps
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ Dihedrals
7 5 2 1 7 9
8 3 3 1 7 8
9 3 3 1 7 9
10 1 4 3 1 7
11 1 5 3 1 7
12 1 6 3 1 7
10 1 7 1 3 4
11 1 7 1 3 5
12 1 7 1 3 6

Impropers

Expand Down
50 changes: 25 additions & 25 deletions gmso/tests/files/restrained_benzene_ua.top
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
; File Topology written by GMSO at 2023-05-04 17:01:43.396827
; File Topology written by GMSO at 2024-03-05 18:12:52.039136

[ defaults ]
; nbfunc comb-rule gen-pairs fudgeLJ fudgeQQ
Expand Down Expand Up @@ -38,12 +38,12 @@ Compound 3

[ bonds ] ;Harmonic potential restraint
; ai aj funct b0 kb
2 1 6 0.14000 1000.00000
6 1 6 0.14000 1000.00000
3 2 6 0.14000 1000.00000
4 3 6 0.14000 1000.00000
5 4 6 0.14000 1000.00000
6 5 6 0.14000 1000.00000
1 2 6 0.14000 1000.00000
1 6 6 0.14000 1000.00000
2 3 6 0.14000 1000.00000
3 4 6 0.14000 1000.00000
4 5 6 0.14000 1000.00000
5 6 6 0.14000 1000.00000

[ angles ]
; ai aj ak funct phi_0 k0
Expand All @@ -56,31 +56,31 @@ Compound 3

[ angle_restraints ]
; ai aj ai ak funct theta_eq k multiplicity
1 2 1 6 1 120.00000 1000.00000 1
2 1 2 3 1 120.00000 1000.00000 1
3 2 3 4 1 120.00000 1000.00000 1
4 3 4 5 1 120.00000 1000.00000 1
5 4 5 6 1 120.00000 1000.00000 1
6 1 6 5 1 120.00000 1000.00000 1
1 6 1 2 1 120.00000 1000.00000 1
2 3 2 1 1 120.00000 1000.00000 1
3 4 3 2 1 120.00000 1000.00000 1
4 5 4 3 1 120.00000 1000.00000 1
5 6 5 4 1 120.00000 1000.00000 1
6 5 6 1 1 120.00000 1000.00000 1

[ dihedrals ]
; ai aj ak al funct c0 c1 c2 c3 c4 c5
2 1 6 5 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
3 2 1 6 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
1 2 3 4 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
2 3 4 5 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
3 4 5 6 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
1 6 5 4 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
5 6 1 2 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
6 1 2 3 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
4 3 2 1 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
5 4 3 2 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
6 5 4 3 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
4 5 6 1 3 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000

[ dihedral_restraints ]
#ifdef DIHRES
; ai aj ak al funct theta_eq delta_theta kd
2 1 6 5 1 0.00000 0.00000 1000.00000
3 2 1 6 1 0.00000 0.00000 1000.00000
1 2 3 4 1 0.00000 0.00000 1000.00000
2 3 4 5 1 0.00000 0.00000 1000.00000
3 4 5 6 1 0.00000 0.00000 1000.00000
1 6 5 4 1 0.00000 0.00000 1000.00000
5 6 1 2 1 0.00000 0.00000 1000.00000
6 1 2 3 1 0.00000 0.00000 1000.00000
4 3 2 1 1 0.00000 0.00000 1000.00000
5 4 3 2 1 0.00000 0.00000 1000.00000
6 5 4 3 1 0.00000 0.00000 1000.00000
4 5 6 1 1 0.00000 0.00000 1000.00000
#endif DIHRES

[ system ]
Expand Down
67 changes: 42 additions & 25 deletions gmso/tests/parameterization/parameterization_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,51 @@ def trappe_ua_foyer(self):
def assert_same_connection_params(self):
def _assert_same_connection_params(top1, top2, connection_type="bonds"):
"""Match connection parameters between two gmso topologies."""
connection_types_original = {}
connection_types_mirror = {}
for connection in getattr(top2, connection_type):
connection_types_mirror[
tuple(
top2.get_index(member)
for member in connection.connection_members
)
] = connection

connection_types_top1 = {}
for connection in getattr(top1, connection_type):
connection_types_original[
tuple(
top1.get_index(member)
for member in connection.connection_members
)
] = connection

for key in connection_types_original:
conn = connection_types_original[key]
conn_mirror = connection_types_mirror[key]
eq_connsList = connection.equivalent_members()
indexList = [
tuple(map(lambda x: top1.get_index(x), conn))
for conn in eq_connsList
]
atom_indicesList = sorted(indexList)[0]
connection_types_top1[atom_indicesList] = connection
connection_types_top2 = {}
for connection in getattr(top2, connection_type):
eq_connsList = connection.equivalent_members()
indexList = [
tuple(map(lambda x: top2.get_index(x), conn))
for conn in eq_connsList
]
atom_indicesList = sorted(indexList)[0]
connection_types_top2[atom_indicesList] = connection

# for connection in getattr(top2, connection_type):
# connection_types_mirror[
# tuple(
# top2.get_index(member)
# for member in sort_connection_members(connection, "atom_type")
# )
# ] = connection

# for connection in getattr(top1, connection_type):
# connection_types_original[
# tuple(
# top1.get_index(member)
# for member in sort_connection_members(connection, "atom_type")
# )
# ] = connection

for key in connection_types_top1:
conn1 = connection_types_top1[key]
conn2 = connection_types_top2[key]
conn_type_attr = connection_type[:-1] + "_type"
conn_type_mirror = getattr(conn_mirror, conn_type_attr)
conn_type = getattr(conn, conn_type_attr)
for param in conn_type.parameters:
conn_type1 = getattr(conn1, conn_type_attr)
conn_type2 = getattr(conn2, conn_type_attr)
for param in conn_type1.parameters:
assert u.allclose_units(
conn_type_mirror.parameters[param],
conn_type.parameters[param],
conn_type2.parameters[param],
conn_type1.parameters[param],
)

return _assert_same_connection_params
Expand Down
94 changes: 94 additions & 0 deletions gmso/tests/test_improper.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,97 @@ def test_sort_improper_types(self):
)
assert sort_by_classes(imptype) == expected_sortingList
assert sort_by_types(imptype) == expected_sortingList

def test_sorting_improper_based_on_impropertype(self):
from gmso.exceptions import MissingParameterError
from gmso.utils.sorting import sort_by_classes, sort_by_types

def sort_improper_connection_members(improper):
if improper.improper_type is None:
return improper.connection_members
improper_classes = improper.improper_type.member_classes
improper_types = improper.improper_type.member_types
if improper_classes:
order_improperList = improper_classes
orderStr = "atomclass" # String to access site attribute
elif improper_types:
order_improperList = improper_types
orderStr = "name" # String to access site attribute
else:
missing_types = [site.atom_type.atomclass for site in improper]
raise MissingParameterError(
improper.improper_type, missing_types
)

# get the site atomtypes and make a dictionary map to match to the order_improperList
cmemList = improper.connection_members
assert order_improperList[0] == getattr(
cmemList[0].atom_type, orderStr
) # first atoms should be the same
first_site = cmemList[0]
middle_sitesList = []
for site in cmemList[1:]:
if getattr(site.atom_type, orderStr) == order_improperList[-1]:
last_site = site
else:
middle_sitesList.append(site)
assert (
len(middle_sitesList) == 2
), f"The improper_type {improper.improper_type} could not find 2 middle sites from {middle_sitesList}"
middle_sitesList = sorted(
middle_sitesList,
key=lambda site: getattr(site.atom_type, orderStr),
)
return [first_site] + middle_sitesList + [last_site]

atom1 = Atom(
name="atom1",
position=[0, 0, 0],
atom_type=AtomType(name="A", atomclass="A"),
)
atom2 = Atom(
name="atom2",
position=[1, 0, 0],
atom_type=AtomType(name="B", atomclass="B"),
)
atom3 = Atom(
name="atom3",
position=[1, 1, 0],
atom_type=AtomType(name="C", atomclass="C"),
)
atom4 = Atom(
name="atom4",
position=[1, 1, 4],
atom_type=AtomType(name="D", atomclass="D"),
)

connect = Improper(connection_members=[atom2, atom1, atom4, atom3])

consituentList = [
atom2.atom_type.name,
atom4.atom_type.name,
atom3.atom_type.name,
atom1.atom_type.name,
]

imptype = ImproperType(
member_types=consituentList, member_classes=consituentList
)
connect.improper_type = imptype
expected_membersList = [atom2, atom3, atom4, atom1]
assert sort_by_types(connect.improper_type) == tuple(
[site.atom_type.name for site in expected_membersList]
)
assert (
tuple([site.atom_type for site in connect.connection_members])
!= imptype.member_types
)
assert sort_improper_connection_members(connect) == expected_membersList

def test_applied_improper_updates_connection_members(self, benzeneTopology):
improper = benzeneTopology.impropers[0]
classes_connectionList = tuple(
[site.atom_type.atomclass for site in improper.connection_members]
)
classes_typeList = improper.improper_type.member_classes
assert classes_connectionList == classes_typeList
1 change: 0 additions & 1 deletion gmso/tests/test_lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ def test_read_n_diherals(self, typed_ethane_opls):
"typed_ethane",
"typed_methylnitroaniline",
"typed_methaneUA",
"typed_water_system",
],
)
def test_lammps_vs_parmed_by_mol(self, top, request):
Expand Down
25 changes: 25 additions & 0 deletions gmso/utils/files/benzeneaa_improper.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<ForceField name="Test-Torsions" version="0.0.1" combining_rule="geometric">
<AtomTypes>
<Type name="CT" class="CT" element="C" mass="12.01100" def="C"/>
<Type name="HC" class="HC" element="H" mass="1.0081" def="H"/>
</AtomTypes>
<HarmonicBondForce>
<Bond class1="CT" class2="HC" length="0.163" k="251040.0"/>
<Bond class1="CT" class2="CT" length="0.163" k="251040.0"/>
</HarmonicBondForce>
<HarmonicAngleForce>
<Angle class1="CT" class2="CT" class3="HC" angle="2.7053" k="397.48"/>
<Angle class1="HC" class2="CT" class3="HC" angle="2.7053" k="397.48"/>
<Angle class1="CT" class2="CT" class3="CT" angle="1.9111" k="397.48"/>
</HarmonicAngleForce>
<PeriodicTorsionForce>
<Proper class1="CT" class2="CT" class3="CT" class4="HC" periodicity1="1" phase1="3.14" k1="3.1"/>
<Proper class1="CT" class2="CT" class3="CT" class4="CT" periodicity1="1" phase1="3.14" k1="3.1"/>
<Proper class1="HC" class2="CT" class3="CT" class4="HC" periodicity1="1" phase1="3.14" k1="3.1"/>
<Improper class1="CT" class2="CT" class3="HC" class4="CT" periodicity1="1" phase1="3.14" k1="3.1"/>
</PeriodicTorsionForce>
<NonbondedForce coulomb14scale="0" lj14scale="0">
<Atom type="CT" charge="-0.100" sigma="0.302" epsilon="0.773245"/>
<Atom type="HC" charge="0.100" sigma="0.302" epsilon="0.01"/>
</NonbondedForce>
</ForceField>
Loading

0 comments on commit d3fca73

Please sign in to comment.