diff --git a/examples/qm7x/train.py b/examples/qm7x/train.py index 00aea2c33..936c29464 100644 --- a/examples/qm7x/train.py +++ b/examples/qm7x/train.py @@ -194,17 +194,17 @@ def hdf5_to_graph(self, fMOL, molid): forces ), f"qm7x dataset - molid:{molid} - confid:{confid} - L2-norm of atomic forces exceeds {self.forces_norm_threshold}" + if self.energy_per_atom: + energy = EPBE0 / natoms + else: + energy = EPBE0 + # data = Data( # pos=xyz, x=Z, molid=molid, confid=confid # ) - data = Data(pos=xyz, x=Z) + data = Data(pos=xyz, x=Z, force=forces, energy=energy, y=energy) data.x = torch.cat((data.x, xyz, forces, hCHG, hVDIP, hRAT), dim=1) - if self.energy_per_atom: - data.y = EPBE0 / natoms - else: - data.y = EPBE0 - data = create_graph_fromXYZ(data) # Add edge length as edge feature