Skip to content

Commit

Permalink
replace temporary sort with official one
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Sep 15, 2023
1 parent 9f7ca67 commit d689562
Showing 1 changed file with 4 additions and 40 deletions.
44 changes: 4 additions & 40 deletions tests/test_spherical_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def test_vector_expansion_coeffs(self):
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
# we need to sort both computed and reference pair expansion coeffs,
# because ase.neighborlist can get different neighborlist order for some reasons
tm_ref = sort_tm(tm_ref)
tm_ref = metatensor.torch.sort(tm_ref)
vector_expansion = VectorExpansion(self.hypers, self.all_species,
device=self.device, dtype=self.dtype)
with torch.no_grad():
tm = sort_tm(vector_expansion.forward(**self.batch))
tm = metatensor.torch.sort(vector_expansion.forward(**self.batch))
# Default types are float32 so we cannot get higher accuracy than 1e-7.
# Because the reference value have been cacluated using float32 and
# now we using float64 computation the accuracy had to be decreased again
Expand Down Expand Up @@ -101,11 +101,11 @@ class TestArtificialSphericalExpansion:
def test_vector_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/vector_expansion_coeffs-artificial-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
tm_ref = sort_tm(tm_ref)
tm_ref = metatensor.torch.sort(tm_ref)
vector_expansion = VectorExpansion(self.hypers, self.all_species,
device=self.device, dtype=self.dtype)
with torch.no_grad():
tm = sort_tm(vector_expansion.forward(**self.batch))
tm = metatensor.torch.sort(vector_expansion.forward(**self.batch))
assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)

def test_spherical_expansion_coeffs(self):
Expand Down Expand Up @@ -137,39 +137,3 @@ def test_spherical_expansion_coeffs_artificial(self):
with torch.no_grad():
tm = spherical_expansion_calculator.forward(**self.batch)
assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)

### these util functions will be removed once lab-cosmo/metatensor/pull/281 is merged
def native_list_argsort(native_list):
return sorted(range(len(native_list)), key=native_list.__getitem__)

def sort_tm(tm):
blocks = []
for _, block in tm.items():
values = block.values

samples_values = block.samples.values
sorted_idx = native_list_argsort([tuple(row.tolist()) for row in block.samples.values])
samples_values = samples_values[sorted_idx]
values = values[sorted_idx]

components_values = []
for i, component in enumerate(block.components):
component_values = component.values
sorted_idx = native_list_argsort([tuple(row.tolist()) for row in component.values])
components_values.append( component_values[sorted_idx] )
values = np.take(values, sorted_idx, axis=i+1)

properties_values = block.properties.values
sorted_idx = native_list_argsort([tuple(row.tolist()) for row in block.properties.values])
properties_values = properties_values[sorted_idx]
values = values[..., sorted_idx]

blocks.append(
TensorBlock(
values=values,
samples=Labels(values=samples_values, names=block.samples.names),
components=[Labels(values=components_values[i], names=component.names) for i, component in enumerate(block.components)],
properties=Labels(values=properties_values, names=block.properties.names)
)
)
return TensorMap(keys=tm.keys, blocks=blocks)

0 comments on commit d689562

Please sign in to comment.