From 12d59079b72eeaaea55c9dc605e510d07e59bf34 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 15 Feb 2024 12:50:53 +0100 Subject: [PATCH] Fix non-alchemical learnable basis bug --- torch_spex/radial_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_spex/radial_basis.py b/torch_spex/radial_basis.py index e3cebf8..af552fe 100644 --- a/torch_spex/radial_basis.py +++ b/torch_spex/radial_basis.py @@ -134,7 +134,7 @@ def forward(self, r, samples_metadata: Labels): split_l_aj = l_aj.split("_") l = int(split_l_aj[0]) aj = int(split_l_aj[1]) - where_aj = torch.nonzero(neighbor_species == aj)[0] + where_aj = torch.nonzero(neighbor_species == aj)[:, 0] radial_basis_after_mlp[l][where_aj, :] = radial_mlp_l_aj(torch.index_select(radial_basis[l], 0, where_aj)) return radial_basis_after_mlp else: