Skip to content

Commit

Permalink
check on forces applied to ani1x (ORNL#258)
Browse files Browse the repository at this point in the history
Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
  • Loading branch information
allaffa and Massimiliano Lupo Pasini authored Jun 16, 2024
1 parent e031b9f commit 80f0691
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion examples/ani1_x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def __init__(self, dirpath, var_config, energy_per_atom=True, dist=False):
self.world_size = torch.distributed.get_world_size()
self.rank = torch.distributed.get_rank()

# Threshold for atomic forces in eV/angstrom
self.forces_norm_threshold = 100.0

self.convert_trajectories_to_graphs()

def convert_trajectories_to_graphs(self):
Expand Down Expand Up @@ -121,7 +124,13 @@ def convert_trajectories_to_graphs(self):
data = self.radius_graph(data)
data = transform_coordinates(data)

self.dataset.append(data)
if self.check_forces_values(data.force):
self.dataset.append(data)
else:
print(
f"L2-norm of force tensor exceeds threshold {self.forces_norm_threshold} - atomistic structure: {data}",
flush=True,
)

def iter_data_buckets(self, h5filename, keys=["wb97x_dz.energy"]):
"""Iterate over buckets of data in ANI HDF5 file.
Expand All @@ -146,6 +155,14 @@ def iter_data_buckets(self, h5filename, keys=["wb97x_dz.energy"]):
d["coordinates"] = grp["coordinates"][()][mask]
yield d

def check_forces_values(self, forces):

# Calculate the L2 norm for each row
norms = torch.norm(forces, p=2, dim=1)
# Check if all norms are less than the threshold

return torch.all(norms < self.forces_norm_threshold).item()

def len(self):
return len(self.dataset)

Expand Down

0 comments on commit 80f0691

Please sign in to comment.