Skip to content

Commit

Permalink
stress change.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Nov 18, 2024
1 parent 3c180de commit 8c983f5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 43 deletions.
19 changes: 8 additions & 11 deletions alignn/ff/ff.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def calculate(self, atoms, properties=None, system_changes=None):
(
g.to(self.device),
lg.to(self.device),
torch.tensor(atoms.cell)
torch.tensor(np.array(atoms.cell))
.type(torch.get_default_dtype())
.to(self.device),
)
Expand All @@ -321,19 +321,16 @@ def calculate(self, atoms, properties=None, system_changes=None):
energy = result["out"].detach().cpu().numpy() * num_atoms
else:
energy = result["out"].detach().cpu().numpy()

stress = self.stress_wt * np.array(
full_3x3_to_voigt_6_stress(
result["stresses"][:3].reshape(3, 3).detach().cpu().numpy()
)
)
# print('stress',stress)
self.results = {
"energy": energy, # * num_atoms,
"forces": result["grad"].detach().cpu().numpy(),
"stress": full_3x3_to_voigt_6_stress(
# np.eye(3)
result["stresses"][:3]
.reshape(3, 3)
.detach()
.cpu()
.numpy()
)
/ 160.21766208,
"stress": stress,
"dipole": np.zeros(3),
"charges": np.zeros(len(atoms)),
"magmom": 0.0,
Expand Down
64 changes: 32 additions & 32 deletions alignn/models/alignn_atomwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,38 +581,38 @@ def forward(
# print("stress1", stress, stress.shape)
# print("g.batch_size", g.batch_size)
else:
stresses = []
count_edge = 0
count_node = 0
for graph_id in range(g.batch_size):
num_edges = g.batch_num_edges()[graph_id]
num_nodes = 0
st = -1 * (
160.21766208
* torch.matmul(
r[count_edge : count_edge + num_edges].T,
pair_forces[
count_edge : count_edge + num_edges
],
)
/ g.ndata["V"][count_node + num_nodes]
)

count_edge = count_edge + num_edges
num_nodes = g.batch_num_nodes()[graph_id]
count_node = count_node + num_nodes
stresses.append(st)
stress = self.config.stress_multiplier * torch.stack(
stresses
)
# print("stress2", stress, stress.shape)
# virial = (
# 160.21766208
# * 10
# * torch.einsum("ij, ik->jk",
# result["r"], result["dy_dr"])
# / 2
# ) # / ( g.ndata["V"][0])
# stresses = []
# count_edge = 0
# count_node = 0
# for graph_id in range(g.batch_size):
# num_edges = g.batch_num_edges()[graph_id]
# num_nodes = 0
# st = -1 * (
# 160.21766208
# * torch.matmul(
# r[count_edge : count_edge + num_edges].T,
# pair_forces[
# count_edge : count_edge + num_edges
# ],
# )
# / g.ndata["V"][count_node + num_nodes]
# )

# count_edge = count_edge + num_edges
# num_nodes = g.batch_num_nodes()[graph_id]
# count_node = count_node + num_nodes
# stresses.append(st)
# stress = self.config.stress_multiplier * torch.stack(
# stresses
# )
stress = (
# 160.21766208
# * 10
-1
* torch.einsum("ij, ik->jk", r, pair_forces)
/ 2
# / (2 * g.ndata["V"][0])
) ## / ( g.ndata["V"][0])
if self.link:
out = self.link(out)

Expand Down

0 comments on commit 8c983f5

Please sign in to comment.