Skip to content

Commit

Permalink
Allow double backward propagation in RascalineAutograd
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Jul 28, 2023
1 parent c62ef49 commit f2d9a66
Show file tree
Hide file tree
Showing 2 changed files with 513 additions and 172 deletions.
193 changes: 122 additions & 71 deletions python/rascaline-torch/tests/autograd.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
import warnings

import ase
import pytest
import torch

from rascaline.torch import SoapPowerSpectrum, SphericalExpansion, System


HYPERS = {
"cutoff": 3,
"max_radial": 2,
"max_angular": 0,
"max_radial": 10,
"max_angular": 5,
"atomic_gaussian_width": 0.3,
"center_atom_weight": 1.0,
"cutoff_function": {"ShiftedCosine": {"width": 0.5}},
"radial_basis": {"Gto": {}},
}


def create_random_system(n_atoms, cell_size):
def _create_random_system(n_atoms, cell_size):
torch.manual_seed(0)
species = torch.randint(3, (n_atoms,), dtype=torch.int)

Expand All @@ -27,61 +30,65 @@ def create_random_system(n_atoms, cell_size):
return species, positions, cell


def test_spherical_expansion_positions_grad():
species, positions, cell = create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True
def _compute_spherical_expansion(species, positions, cell):
system = System(
positions=positions,
species=species,
cell=cell,
)

calculator = SphericalExpansion(**HYPERS)
descriptor = calculator(system)
descriptor = descriptor.components_to_properties("spherical_harmonics_m")
descriptor = descriptor.keys_to_properties("spherical_harmonics_l")

def compute(species, positions, cell):
system = System(
positions=positions,
species=species,
cell=cell,
)
descriptor = calculator(system)
descriptor = descriptor.components_to_properties("spherical_harmonics_m")
descriptor = descriptor.keys_to_properties("spherical_harmonics_l")
descriptor = descriptor.keys_to_samples("species_center")
descriptor = descriptor.keys_to_properties("species_neighbor")

return descriptor.block(0).values


def _compute_power_spectrum(species, positions, cell):
system = System(
positions=positions,
species=species,
cell=cell,
)

calculator = SoapPowerSpectrum(**HYPERS)
descriptor = calculator(system)
descriptor = descriptor.keys_to_samples("species_center")
descriptor = descriptor.keys_to_properties(
["species_neighbor_1", "species_neighbor_2"]
)

return descriptor.block(0).values

descriptor = descriptor.keys_to_samples("species_center")
descriptor = descriptor.keys_to_properties("species_neighbor")

return descriptor.block(0).values
def test_spherical_expansion_positions_grad():
species, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

assert torch.autograd.gradcheck(
compute,
_compute_spherical_expansion,
(species, positions, cell),
fast_mode=True,
)


def test_spherical_expansion_cell_grad():
species, positions, cell = create_random_system(n_atoms=75, cell_size=5.0)
species, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)

original_cell = cell.clone()
cell.requires_grad = True

calculator = SphericalExpansion(**HYPERS)

def compute(species, positions, cell):
# modifying the cell for numerical gradients should also displace
# the atoms
fractional = positions @ torch.linalg.inv(original_cell)
positions = fractional @ cell.detach()

system = System(
positions=positions,
species=species,
cell=cell,
)
descriptor = calculator(system)
descriptor = descriptor.components_to_properties("spherical_harmonics_m")
descriptor = descriptor.keys_to_properties("spherical_harmonics_l")

descriptor = descriptor.keys_to_samples("species_center")
descriptor = descriptor.keys_to_properties("species_neighbor")

return descriptor.block(0).values
return _compute_spherical_expansion(species, positions, cell)

assert torch.autograd.gradcheck(
compute,
Expand All @@ -91,63 +98,107 @@ def compute(species, positions, cell):


def test_power_spectrum_positions_grad():
species, positions, cell = create_random_system(n_atoms=75, cell_size=5.0)
species, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

calculator = SoapPowerSpectrum(**HYPERS)

def compute(species, positions, cell):
system = System(
positions=positions,
species=species,
cell=cell,
)
descriptor = calculator(system)

descriptor = descriptor.keys_to_samples("species_center")
descriptor = descriptor.keys_to_properties(
["species_neighbor_1", "species_neighbor_2"]
)

return descriptor.block(0).values

assert torch.autograd.gradcheck(
compute,
_compute_power_spectrum,
(species, positions, cell),
fast_mode=True,
)


def test_power_spectrum_cell_grad():
species, positions, cell = create_random_system(n_atoms=75, cell_size=5.0)
species, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)

original_cell = cell.clone()
cell.requires_grad = True

calculator = SoapPowerSpectrum(**HYPERS)

def compute(species, positions, cell):
# modifying the cell for numerical gradients should also displace
# the atoms
fractional = positions @ torch.linalg.inv(original_cell)
positions = fractional @ cell.detach()

system = System(
positions=positions,
species=species,
cell=cell,
)
descriptor = calculator(system)

descriptor = descriptor.keys_to_samples("species_center")
descriptor = descriptor.keys_to_properties(
["species_neighbor_1", "species_neighbor_2"]
)

return descriptor.block(0).values
return _compute_power_spectrum(species, positions, cell)

assert torch.autograd.gradcheck(
compute,
(species, positions, cell),
fast_mode=True,
)


def test_power_spectrum_positions_grad_grad():
species, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

X = _compute_power_spectrum(species, positions, cell)
weights = torch.rand((X.shape[-1], 1), requires_grad=True, dtype=torch.float64)

def compute(weights):
X = _compute_power_spectrum(species, positions, cell)
A = X @ weights

return torch.autograd.grad(
outputs=A,
inputs=positions,
grad_outputs=torch.ones_like(A),
retain_graph=True,
create_graph=True,
)[0]

message = (
"second derivatives with respect to positions are not implemented and "
"will not be accumulated during backward\\(\\) calls"
)
computed = torch.sum(compute(weights))
with pytest.warns(UserWarning, match=message):
computed.backward(retain_graph=True)

# check that double backward still allows for gradients of weights w.r.t. forces
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=message)

assert torch.autograd.gradcheck(
compute,
(weights),
fast_mode=True,
)


def test_power_spectrum_cell_grad_grad():
species, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
cell.requires_grad = True

X = _compute_power_spectrum(species, positions, cell)
weights = torch.rand((X.shape[-1], 1), requires_grad=True, dtype=torch.float64)

def compute(weights):
X = _compute_power_spectrum(species, positions, cell)
A = X @ weights

return torch.autograd.grad(
outputs=A,
inputs=cell,
grad_outputs=torch.ones_like(A),
retain_graph=True,
create_graph=True,
)[0]

message = (
"second derivatives with respect to cell matrix are not implemented and "
"will not be accumulated during backward\\(\\) calls"
)
computed = torch.sum(compute(weights))
with pytest.warns(UserWarning, match=message):
computed.backward(retain_graph=True)

# check that double backward still allows for gradients of weights w.r.t. virial
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=message)

assert torch.autograd.gradcheck(
compute,
(weights),
fast_mode=True,
)
Loading

0 comments on commit f2d9a66

Please sign in to comment.