From c86c3738fd5277a94e951e567bf7316298189cfe Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Thu, 14 Sep 2023 11:53:37 +0200 Subject: [PATCH] merge torchscript tests with regular tests --- tests/test_spherical_expansions.py | 21 +++++++-- tests/test_torchscript.py | 73 ------------------------------ 2 files changed, 18 insertions(+), 76 deletions(-) delete mode 100644 tests/test_torchscript.py diff --git a/tests/test_spherical_expansions.py b/tests/test_spherical_expansions.py index c0dfdac..91eb974 100644 --- a/tests/test_spherical_expansions.py +++ b/tests/test_spherical_expansions.py @@ -4,7 +4,6 @@ import metatensor.torch from metatensor.torch import Labels, TensorBlock, TensorMap -import numpy as np import ase.io from torch_spex.spherical_expansions import VectorExpansion, SphericalExpansion @@ -18,7 +17,9 @@ class TestEthanol1SphericalExpansion: device = "cpu" dtype = torch.float32 frames = ase.io.read('datasets/rmd17/ethanol1.extxyz', ':1') - all_species = list(np.unique([frame.numbers for frame in frames])) + all_species = torch.unique(torch.concatenate([torch.tensor(frame.numbers) + for frame in frames])) + all_species = [int(species) for species in all_species] with open("tests/data/expansion_coeffs-ethanol1_0-hypers.json", "r") as f: hypers = json.load(f) @@ -44,6 +45,12 @@ def test_vector_expansion_coeffs(self): # now we using float64 computation the accuracy had to be decreased again assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5) + vector_expansion_script = torch.jit.script(vector_expansion) + with torch.no_grad(): + tm_script = metatensor.torch.sort(vector_expansion_script.forward(**self.batch)) + assert metatensor.torch.allclose(tm, tm_script, atol=1e-5, + rtol=torch.finfo(self.dtype).eps*10) + def test_spherical_expansion_coeffs(self): tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-ethanol1_0-data.npz") tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype) @@ -56,6 +63,12 @@ def test_spherical_expansion_coeffs(self): # now we using float64 computation the accuracy had to be decreased again assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5) + spherical_expansion_script = torch.jit.script(spherical_expansion_calculator) + with torch.no_grad(): + tm_script = metatensor.torch.sort(spherical_expansion_script.forward(**self.batch)) + assert metatensor.torch.allclose(tm, tm_script, atol=1e-5, + rtol=torch.finfo(self.dtype).eps*10) + def test_spherical_expansion_coeffs_alchemical(self): with open("tests/data/expansion_coeffs-ethanol1_0-alchemical-hypers.json", "r") as f: hypers = json.load(f) @@ -88,7 +101,9 @@ class TestArtificialSphericalExpansion: device = "cpu" dtype = torch.float32 frames = ase.io.read('tests/datasets/artificial.extxyz', ':') - all_species = list(np.unique(np.hstack([frame.numbers for frame in frames]))) + all_species = torch.unique(torch.concatenate([torch.tensor(frame.numbers) + for frame in frames])) + all_species = [int(species) for species in all_species] with open("tests/data/expansion_coeffs-artificial-hypers.json", "r") as f: hypers = json.load(f) diff --git a/tests/test_torchscript.py b/tests/test_torchscript.py deleted file mode 100644 index 8bbeac8..0000000 --- a/tests/test_torchscript.py +++ /dev/null @@ -1,73 +0,0 @@ -import json - -import torch - -import metatensor.torch -from metatensor.torch import Labels, TensorBlock, TensorMap -import numpy as np -import ase.io - -from torch_spex.spherical_expansions import VectorExpansion, SphericalExpansion -from torch_spex.structures import InMemoryDataset, TransformerNeighborList, collate_nl -from torch.utils.data import DataLoader - - -class TestEthanol1SphericalExpansion: - """ - Tests on the ethanol1 dataset - """ - device = "cpu" - frames = ase.io.read('datasets/rmd17/ethanol1.extxyz', ':1') - all_species = list(np.unique([frame.numbers for frame in frames])) - all_species = [int(species) for species in all_species] - with open("tests/data/expansion_coeffs-ethanol1_0-hypers.json", "r") as f: - hypers = json.load(f) - - transformers = [TransformerNeighborList(cutoff=hypers["cutoff radius"])] - dataset = InMemoryDataset(frames, transformers) - loader = DataLoader(dataset, batch_size=1, collate_fn=collate_nl) - batch = next(iter(loader)) - - def test_vector_expansion_coeffs(self): - vector_expansion = torch.jit.script(VectorExpansion(self.hypers, self.all_species, device=self.device)) - vector_expansion.forward(**self.batch) - - def test_spherical_expansion_coeffs(self): - spherical_expansion_calculator = torch.jit.script(SphericalExpansion(self.hypers, self.all_species, device=self.device)) - spherical_expansion_calculator.forward(**self.batch) - - def test_spherical_expansion_coeffs_alchemical(self): - with open("tests/data/expansion_coeffs-ethanol1_0-alchemical-hypers.json", "r") as f: - hypers = json.load(f) - spherical_expansion_calculator = torch.jit.script(SphericalExpansion(hypers, self.all_species, device=self.device)) - spherical_expansion_calculator.forward(**self.batch) - -class TestArtificialSphericalExpansion: - """ - Tests on the artificial dataset - """ - device = "cpu" - frames = ase.io.read('tests/datasets/artificial.extxyz', ':') - all_species = list(np.unique(np.hstack([frame.numbers for frame in frames]))) - all_species = [int(species) for species in all_species] - with open("tests/data/expansion_coeffs-artificial-hypers.json", "r") as f: - hypers = json.load(f) - - transformers = [TransformerNeighborList(cutoff=hypers["cutoff radius"])] - dataset = InMemoryDataset(frames, transformers) - loader = DataLoader(dataset, batch_size=len(frames), collate_fn=collate_nl) - batch = next(iter(loader)) - - def test_vector_expansion_coeffs(self): - vector_expansion = torch.jit.script(VectorExpansion(self.hypers, self.all_species, device=self.device)) - vector_expansion.forward(**self.batch) - - def test_spherical_expansion_coeffs(self): - spherical_expansion_calculator = torch.jit.script(SphericalExpansion(self.hypers, self.all_species, device=self.device)) - spherical_expansion_calculator.forward(**self.batch) - - def test_spherical_expansion_coeffs_artificial(self): - with open("tests/data/expansion_coeffs-artificial-alchemical-hypers.json", "r") as f: - hypers = json.load(f) - spherical_expansion_calculator = torch.jit.script(SphericalExpansion(hypers, self.all_species, device=self.device)) - spherical_expansion_calculator.forward(**self.batch)