
Commit

Rename labels (#42)
frostedoyster authored Mar 4, 2024
1 parent 44c2cba commit cddc6b6
Showing 9 changed files with 86 additions and 35 deletions.
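
For reference, the renames applied throughout the diff below all follow the same old-to-new mapping of metatensor label dimension names. The sketch below is illustrative only: the OLD_TO_NEW dict and rename_labels helper are not part of the commit, and they assume the Labels.rename / Labels.names API of metatensor.torch that the new rename_old_labels test helper added in this commit also relies on.

import metatensor.torch

# Old -> new label dimension names applied by this commit (illustrative summary, not code from the diff)
OLD_TO_NEW = {
    "a_i": "center_type",
    "species_center": "center_type",
    "species_neighbor": "neighbor_type",
    "alpha_j": "neighbor_type",
    "l": "o3_lambda",
    "lam": "o3_lambda",
    "sigma": "o3_sigma",
    "m": "o3_mu",
    "center": "atom",  # only for per-atom samples; pair samples keep "center"/"neighbor"
    "cartesian_dimension": "xyz",
}

def rename_labels(labels: metatensor.torch.Labels) -> metatensor.torch.Labels:
    # Apply every rename whose old name is present as a dimension of `labels`.
    for old, new in OLD_TO_NEW.items():
        if old in labels.names:
            labels = labels.rename(old, new)
    return labels
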
2 changes: 1 addition & 1 deletion examples/alchemical_model.py
@@ -168,7 +168,7 @@ def forward(self, structure_batch: Dict[str, torch.Tensor], is_training: bool =
atomic_energies = []
structure_indices = []
for ai, layer_ai in self.nu2_model.items():
- block = ps.block({"a_i": int(ai)})
+ block = ps.block({"center_type": int(ai)})
# print(block.values)
features = block.values.squeeze(dim=1)
structure_indices.append(block.samples.column("structure"))
4 changes: 2 additions & 2 deletions examples/power_spectrum.py
@@ -19,7 +19,7 @@ def forward(self, spex: TensorMap):
ps_values_ai = []
for l in range(self.l_max+1):
cg = (2*l+1)**(-0.5)
- block_ai_l = spex.block({"lam": l, "a_i": a_i})
+ block_ai_l = spex.block({"o3_lambda": l, "center_type": a_i})
c_ai_l = block_ai_l.values

# same as this:
@@ -33,7 +33,7 @@ def forward(self, spex: TensorMap):

block = TensorBlock(
values=ps_values_ai,
- samples=spex.block({"lam": 0, "a_i": a_i}).samples,
+ samples=spex.block({"o3_lambda": 0, "center_type": a_i}).samples,
components=[],
properties=Labels(
"property",
2 changes: 1 addition & 1 deletion examples/ps_model.py
@@ -141,7 +141,7 @@ def forward(self, structures: Dict[str, torch.Tensor], is_training: bool = True)
atomic_energies = []
structure_indices = []
for ai, model_ai in self.nu2_model.items():
- block = ps.block({"a_i": int(ai)})
+ block = ps.block({"center_type": int(ai)})
features = block.values.squeeze(dim=1)
structure_indices.append(block.samples.column("structure"))
atomic_energies.append(
6 changes: 3 additions & 3 deletions tests/compare_vs_rascaline.py
@@ -135,15 +135,15 @@ def function_for_splining_derivative(n, l, r):
spherical_expansion_coefficients_rascaline = calculator.compute(structures)

all_neighbor_species = Labels(
- names=["species_neighbor"],
+ names=["neighbor_type"],
values=np.array(all_species, dtype=np.int32).reshape(-1, 1),
)
spherical_expansion_coefficients_rascaline = spherical_expansion_coefficients_rascaline.keys_to_properties(all_neighbor_species)

for a_i in all_species:
for l in range(l_max+1):
- e = spherical_expansion_coefficients_torch_spex.block(lam=l, a_i=a_i).values
- n_max_l = spherical_expansion_coefficients_torch_spex.block(lam=l, a_i=a_i).values.shape[2] // len(all_species)
+ e = spherical_expansion_coefficients_torch_spex.block(o3_lambda=l, center_type=a_i).values
+ n_max_l = spherical_expansion_coefficients_torch_spex.block(o3_lambda=l, center_type=a_i).values.shape[2] // len(all_species)
rascaline_indices = []
for a_i_index in range(len(all_species)):
for n in range(n_max_l):
2 changes: 1 addition & 1 deletion tests/plot_radial_basis.py
@@ -39,7 +39,7 @@ def get_dummy_structures(r_array):
calculator = SphericalExpansion(hypers_spherical_expansion, [1, 6])
spherical_expansion_coefficients = calculator(**structures)

- block_C_0 = spherical_expansion_coefficients.block(a_i = 6, lam = 0)
+ block_C_0 = spherical_expansion_coefficients.block(center_type = 6, o3_lambda = 0)
print("Block shape is", block_C_0.values.shape)

block_C_0_0 = block_C_0.values[:, :, 2].flatten().detach().numpy()
2 changes: 1 addition & 1 deletion tests/test_finite_differences.py
@@ -46,7 +46,7 @@ def forward(self, spherical_expansion_kwargs, is_compute_forces=True):
spherical_expansion_kwargs["positions"].requires_grad = True
if is_compute_forces:
spherical_expansion = self.spherical_expansion_calculator(**spherical_expansion_kwargs)
- tm = metatensor.torch.sum_over_samples(spherical_expansion, sample_names="center").components_to_properties(["m"]).keys_to_properties(["a_i", "lam", "sigma"])
+ tm = metatensor.torch.sum_over_samples(spherical_expansion, sample_names="atom").components_to_properties(["o3_mu"]).keys_to_properties(["center_type", "o3_lambda", "o3_sigma"])
energies = torch.sum(tm.block().values, axis=1)

gradient = torch.autograd.grad(
64 changes: 58 additions & 6 deletions tests/test_spherical_expansions.py
@@ -9,6 +9,57 @@
from torch_spex.structures import InMemoryDataset, TransformerNeighborList, collate_nl
from torch.utils.data import DataLoader


def rename_old_labels(labels: metatensor.torch.Labels):
# The reference values were saved with old names
new_labels = labels
if "l" in labels.names:
new_labels = new_labels.rename("l", "o3_lambda")
if "a_i" in labels.names:
new_labels = new_labels.rename("a_i", "center_type")
if "lam" in new_labels.names:
new_labels = new_labels.rename("lam", "o3_lambda")
if "sigma" in new_labels.names:
new_labels = new_labels.rename("sigma", "o3_sigma")
if "m" in new_labels.names:
new_labels = new_labels.rename("m", "o3_mu")
if "l1" in new_labels.names:
new_labels = new_labels.remove("l1")
if "center" in new_labels.names and "neighbor" not in new_labels.names:
new_labels = new_labels.rename("center", "atom")
if "a1" in new_labels.names:
new_labels = new_labels.rename("a1", "neighbor_type")
if "alphaj" in new_labels.names:
new_labels = new_labels.rename("alphaj", "neighbor_type")
if "alpha_j" in new_labels.names:
new_labels = new_labels.rename("alpha_j", "neighbor_type")
if "n1" in new_labels.names:
new_labels = new_labels.rename("n1", "n")
if "species_center" in new_labels.names:
new_labels = new_labels.rename("species_center", "center_type")
if "species_neighbor" in new_labels.names:
new_labels = new_labels.rename("species_neighbor", "neighbor_type")
if "direction" in new_labels.names:
new_labels = new_labels.rename("direction", "xyz")
return new_labels

def rename_old_tm(tm: metatensor.torch.TensorMap):
# The reference values were saved with old names
keys = rename_old_labels(tm.keys)
blocks = []
for block in tm.blocks():
blocks.append(
metatensor.torch.TensorBlock(
values=block.values,
samples=rename_old_labels(block.samples),
components=[rename_old_labels(component) for component in block.components],
properties=rename_old_labels(block.properties)
)
)

return metatensor.torch.TensorMap(keys=keys, blocks=blocks)


class TestEthanol1SphericalExpansion:
"""
Tests on the ethanol1 dataset
@@ -41,7 +92,7 @@ def test_vector_expansion_coeffs(self):
# Default types are float32 so we cannot get higher accuracy than 1e-7.
# Because the reference values have been calculated using float32 and
# we are now using float64 computation, the accuracy had to be decreased again
- assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)
+ assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5)

vector_expansion_script = torch.jit.script(vector_expansion)
with torch.no_grad():
@@ -58,7 +109,7 @@ def test_spherical_expansion_coeffs(self):
# Default types are float32 so we cannot get higher accuracy than 1e-7.
# Because the reference values have been calculated using float32 and
# we are now using float64 computation, the accuracy had to be decreased again
- assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)
+ assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5)

spherical_expansion_script = torch.jit.script(spherical_expansion_calculator)
with torch.no_grad():
@@ -88,7 +139,7 @@ def test_spherical_expansion_coeffs_alchemical(self):
# Default types are float32 so we cannot get higher accuracy than 1e-7.
# Because the reference values have been calculated using float32 and
# we are now using float64 computation, the accuracy had to be decreased again
- assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)
+ assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5)

class TestArtificialSphericalExpansion:
"""
@@ -116,7 +167,7 @@ def test_vector_expansion_coeffs(self):
vector_expansion = VectorExpansion(self.hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
tm = metatensor.torch.sort(vector_expansion.forward(**self.batch))
- assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)
+ assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5)

def test_spherical_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-artificial-data.npz")
@@ -126,7 +177,7 @@ def test_spherical_expansion_coeffs(self):
tm = spherical_expansion_calculator.forward(**self.batch)
# The absolute accuracy is a bit lower than in the ethanol case
# I presume it is because we use 5 frames instead of just one
- assert metatensor.torch.allclose(tm_ref, tm, atol=3e-5, rtol=1e-5)
+ assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=3e-5, rtol=1e-5)

def test_spherical_expansion_coeffs_artificial(self):
with open("tests/data/expansion_coeffs-artificial-alchemical-hypers.json", "r") as f:
@@ -144,4 +195,5 @@ def test_spherical_expansion_coeffs_artificial(self):
)
with torch.no_grad():
tm = spherical_expansion_calculator.forward(**self.batch)
- assert metatensor.torch.allclose(tm_ref, tm, atol=1e-5, rtol=1e-5)
+ print(rename_old_tm(tm_ref).block(0).properties)
+ assert metatensor.torch.allclose(rename_old_tm(tm_ref), tm, atol=1e-5, rtol=1e-5)
6 changes: 3 additions & 3 deletions torch_spex/radial_basis.py
@@ -49,7 +49,7 @@ def __init__(self, hypers, all_species) -> None:
torch.nn.Linear(len(all_species), self.n_pseudo_species, bias=False)
)
self.species_neighbor_labels = Labels(
- names = ["species_neighbor"],
+ names = ["neighbor_type"],
values = torch.tensor(self.all_species, dtype=torch.int).unsqueeze(1)
)
else:
@@ -86,8 +86,8 @@ def __init__(self, hypers, all_species) -> None:

def radial_transform(self, r, samples_metadata: Labels):
if self.is_physical:
- a_i = samples_metadata.column("species_center")
- a_j = samples_metadata.column("species_neighbor")
+ a_i = samples_metadata.column("center_type")
+ a_j = samples_metadata.column("neighbor_type")
x = r/(0.1+torch.exp(self.lengthscales[a_i])+torch.exp(self.lengthscales[a_j]))
return x
else:
33 changes: 16 additions & 17 deletions torch_spex/spherical_expansions.py
Expand Up @@ -70,9 +70,9 @@ class SphericalExpansion(torch.nn.Module):
>>> expansion = spherical_expansion(**batch)
>>> print(expansion.keys)
Labels(
- a_i lam sigma
- 1 0 1
- 8 0 1
+ center_type o3_lambda o3_sigma
+ 1 0 1
+ 8 0 1
)
"""
@@ -141,7 +141,7 @@ def forward(self,
expanded_vectors = self.vector_expansion_calculator(
positions, cells, species, cell_shifts, centers, pairs, structure_centers, structure_pairs, structure_offsets)

- samples_metadata = expanded_vectors.block({"l": 0}).samples
+ samples_metadata = expanded_vectors.block({"o3_lambda": 0}).samples

n_species = len(self.all_species)
species_to_index = {atomic_number : i_species for i_species, atomic_number in enumerate(self.all_species)}
@@ -156,7 +156,7 @@ def forward(self,
if self.is_alchemical:
density_indices = s_i_metadata_to_unique
for l in range(l_max+1):
- expanded_vectors_l = expanded_vectors.block({"l": l}).values
+ expanded_vectors_l = expanded_vectors.block({"o3_lambda": l}).values
densities_l = torch.zeros(
(n_centers, expanded_vectors_l.shape[1], expanded_vectors_l.shape[2]),
dtype = expanded_vectors_l.dtype,
@@ -167,12 +167,12 @@ def forward(self,
densities.append(densities_l)
unique_species = -torch.arange(self.n_pseudo_species, dtype=torch.int64, device=density_indices.device)
else:
- aj_metadata = samples_metadata.column("species_neighbor")
+ aj_metadata = samples_metadata.column("neighbor_type")
aj_shifts = torch.tensor([species_to_index[int(aj_index)] for aj_index in aj_metadata], dtype=torch.int64, device=aj_metadata.device)
density_indices = s_i_metadata_to_unique*n_species+aj_shifts

for l in range(l_max+1):
- expanded_vectors_l = expanded_vectors.block({"l": l}).values
+ expanded_vectors_l = expanded_vectors.block({"o3_lambda": l}).values
densities_l = torch.zeros(
(n_centers*n_species, expanded_vectors_l.shape[1], expanded_vectors_l.shape[2]),
dtype = expanded_vectors_l.dtype,
@@ -188,7 +188,7 @@ def forward(self,
blocks : List[TensorBlock] = []
for l in range(l_max+1):
densities_l = densities[l]
- vectors_l_block = expanded_vectors.block({"l": l})
+ vectors_l_block = expanded_vectors.block({"o3_lambda": l})
vectors_l_block_components = vectors_l_block.components
vectors_l_block_n = torch.arange(len(torch.unique(vectors_l_block.properties.column("n"))), dtype=torch.int64, device=species.device) # Need to be smarter to optimize
for a_i in self.all_species:
@@ -205,17 +205,16 @@ def forward(self,
TensorBlock(
values = densities_ai_l,
samples = Labels(
- names = ["structure", "center"],
+ names = ["structure", "atom"],
values = unique_s_i_indices[where_ai]
),
components = vectors_l_block_components,
properties = Labels(
- names = ["a1", "n1", "l1"],
+ names = ["neighbor_type", "n"],
values = torch.stack(
[
torch.repeat_interleave(unique_species, vectors_l_block_n.shape[0]),
torch.tile(vectors_l_block_n, (unique_species.shape[0],)),
- l*torch.ones((densities_ai_l.shape[2],), dtype=torch.int, device=densities_ai_l.device)
],
dim=1
)
@@ -225,7 +224,7 @@ def forward(self,

spherical_expansion = TensorMap(
keys = Labels(
- names = ["a_i", "lam", "sigma"],
+ names = ["center_type", "o3_lambda", "o3_sigma"],
values = torch.tensor(labels, dtype=torch.int32, device=species.device)
),
blocks = blocks
@@ -353,7 +352,7 @@ def forward(self,
n_max_l = vector_expansion_l.shape[2]
if self.is_alchemical:
properties = Labels(
- names = ["alpha_j", "n"],
+ names = ["neighbor_type", "n"],
values = torch.stack(
[
torch.repeat_interleave(-torch.arange(self.n_pseudo_species, dtype=torch.int64, device=vector_expansion_l.device), n_max_l),
@@ -372,7 +371,7 @@ def forward(self,
values = vector_expansion_l.reshape(vector_expansion_l.shape[0], 2*l+1, -1),
samples = cartesian_vectors.samples,
components = [Labels(
- names = ("m",),
+ names = ("o3_mu",),
values = torch.arange(start=-l, end=l+1, dtype=torch.int32, device=vector_expansion_l.device).reshape(2*l+1, 1)
)],
properties = properties.to(vector_expansion_l.device)
@@ -382,7 +381,7 @@ def forward(self,
l_max = len(vector_expansion_blocks) - 1
vector_expansion_tmap = TensorMap(
keys = Labels(
- names = ("l",),
+ names = ("o3_lambda",),
values = torch.arange(start=0, end=l_max+1, dtype=torch.int32, device=vector_expansion_blocks[0].values.device).reshape(l_max+1, 1),
),
blocks = vector_expansion_blocks
@@ -421,12 +420,12 @@ def get_cartesian_vectors(positions, cells, species, cell_shifts, centers, pairs
block = TensorBlock(
values = direction_vectors.unsqueeze(dim=-1),
samples = Labels(
- names = ["structure", "center", "neighbor", "species_center", "species_neighbor", "cell_x", "cell_y", "cell_z"],
+ names = ["structure", "center", "neighbor", "center_type", "neighbor_type", "cell_x", "cell_y", "cell_z"],
values = labels
),
components = [
Labels(
- names = ["cartesian_dimension"],
+ names = ["xyz"],
values = torch.tensor([-1, 0, 1], dtype=torch.int32, device=direction_vectors.device).reshape((-1, 1))
)
],
