Skip to content

Commit

Permalink
merge torchscript tests with regular tests
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Sep 15, 2023
1 parent d689562 commit 5dba808
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 76 deletions.
21 changes: 18 additions & 3 deletions tests/test_spherical_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import metatensor.torch
from metatensor.torch import Labels, TensorBlock, TensorMap
import numpy as np
import ase.io

from torch_spex.spherical_expansions import VectorExpansion, SphericalExpansion
Expand All @@ -18,7 +17,9 @@ class TestEthanol1SphericalExpansion:
device = "cpu"
dtype = torch.float32
frames = ase.io.read('datasets/rmd17/ethanol1.extxyz', ':1')
all_species = list(np.unique([frame.numbers for frame in frames]))
all_species = torch.unique(torch.concatenate([torch.tensor(frame.numbers)
for frame in frames]))
all_species = [int(species) for species in all_species]
with open("tests/data/expansion_coeffs-ethanol1_0-hypers.json", "r") as f:
hypers = json.load(f)

Expand All @@ -44,6 +45,12 @@ def test_vector_expansion_coeffs(self):
# now we using float64 computation the accuracy had to be decreased again
assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)

vector_expansion_script = torch.jit.script(vector_expansion)
with torch.no_grad():
tm_script = metatensor.torch.sort(vector_expansion_script.forward(**self.batch))
assert metatensor.torch.allclose(tm, tm_script, atol=1e-5,
rtol=torch.finfo(self.dtype).eps*10)

def test_spherical_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-ethanol1_0-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
Expand All @@ -56,6 +63,12 @@ def test_spherical_expansion_coeffs(self):
# now we using float64 computation the accuracy had to be decreased again
assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)

spherical_expansion_script = torch.jit.script(spherical_expansion_calculator)
with torch.no_grad():
tm_script = metatensor.torch.sort(spherical_expansion_script.forward(**self.batch))
assert metatensor.torch.allclose(tm, tm_script, atol=1e-5,
rtol=torch.finfo(self.dtype).eps*10)

def test_spherical_expansion_coeffs_alchemical(self):
with open("tests/data/expansion_coeffs-ethanol1_0-alchemical-hypers.json", "r") as f:
hypers = json.load(f)
Expand Down Expand Up @@ -88,7 +101,9 @@ class TestArtificialSphericalExpansion:
device = "cpu"
dtype = torch.float32
frames = ase.io.read('tests/datasets/artificial.extxyz', ':')
all_species = list(np.unique(np.hstack([frame.numbers for frame in frames])))
all_species = torch.unique(torch.concatenate([torch.tensor(frame.numbers)
for frame in frames]))
all_species = [int(species) for species in all_species]
with open("tests/data/expansion_coeffs-artificial-hypers.json", "r") as f:
hypers = json.load(f)

Expand Down
73 changes: 0 additions & 73 deletions tests/test_torchscript.py

This file was deleted.

0 comments on commit 5dba808

Please sign in to comment.