diff --git a/.github/workflows/pet-jax.yml b/.github/workflows/pet-jax.yml
new file mode 100644
index 000000000..cc442f296
--- /dev/null
+++ b/.github/workflows/pet-jax.yml
@@ -0,0 +1,40 @@
+name: PET-JAX tests
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    # Check all PRs
+
+jobs:
+  tests:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        include:
+          - os: ubuntu-22.04
+            python-version: "3.12"
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - run: pip install tox
+
+      - name: Install JAX
+        # JAX does not work as a dependency; it needs to be installed separately
+        run: pip install jax[cpu]
+
+      - name: Run PET-JAX tests
+        run: tox -e pet-jax-tests
+        env:
+          # Use the CPU-only version of torch when building/running the code
+          PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu
+
+      - name: Upload code coverage
+        uses: codecov/codecov-action@v4
+        with:
+          files: ./tests/coverage.xml
diff --git a/docs/src/architectures/pet-jax.rst b/docs/src/architectures/pet-jax.rst
new file mode 100644
index 000000000..844bc6409
--- /dev/null
+++ b/docs/src/architectures/pet-jax.rst
@@ -0,0 +1,47 @@
+.. _architecture-pet-jax:
+
+PET-JAX
+=======
+
+This is a JAX implementation of the PET architecture.
+
+Installation
+------------
+To use PET-JAX within ``metatensor-models``, you should already have JAX
+installed for your platform (see the official JAX installation instructions).
+Then, you can run the following command in the root directory of the
+repository:
+
+.. code-block:: bash
+
+    pip install .[pet-jax]
+
+Following this, it is also necessary to hot-fix a few lines of your torch
+installation to allow PET-JAX models to be exported. This can be achieved by
+running the following Python script:
+
+.. literalinclude:: ../../../src/metatensor/models/experimental/pet_jax/hotfix_torch.py
+   :language: python
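+
+For example, assuming you are running from the root of the repository, the
+hot-fix can be applied with:
+
+.. code-block:: bash
+
+    python src/metatensor/models/experimental/pet_jax/hotfix_torch.py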
+
+Default Hyperparameters
+-----------------------
+The default hyperparameters for the PET-JAX model are:
+
+.. literalinclude:: ../../../src/metatensor/models/cli/conf/architecture/experimental.pet_jax.yaml
+   :language: yaml
+
+
+Tuning Hyperparameters
+----------------------
+To be done.
+
+References
+----------
+.. footbibliography::
diff --git a/pyproject.toml b/pyproject.toml
index 5ddd0aa7c..07cccd10f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,6 +60,12 @@ soap-bpnn = []
 alchemical-model = [
     "torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@fafb0bd",
 ]
+pet-jax = [
+    "jax",
+    "equinox",
+    "optax",
+]
+
 [tool.setuptools.packages.find]
 where = ["src"]
diff --git a/src/metatensor/models/cli/conf/architecture/experimental.pet_jax.yaml b/src/metatensor/models/cli/conf/architecture/experimental.pet_jax.yaml
new file mode 100644
index 000000000..87ed69ba4
--- /dev/null
+++ b/src/metatensor/models/cli/conf/architecture/experimental.pet_jax.yaml
@@ -0,0 +1,16 @@
+# default hyperparameters for the PET-JAX model
+model:
+  cutoff: 5.0
+  d_pet: 128
+  num_heads: 4
+  num_attention_layers: 2
+  num_gnn_layers: 2
+  mlp_dropout_rate: 0.0
+  attention_dropout_rate: 0.0
+
+training:
+  batch_size: 16
+  num_warmup_steps: 1000
+  num_epochs: 10000
+  learning_rate: 3e-4
+  log_interval: 10
diff --git a/src/metatensor/models/experimental/pet_jax/__init__.py b/src/metatensor/models/experimental/pet_jax/__init__.py
new file mode 100644
index 000000000..ff9a77daf
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/__init__.py
@@ -0,0 +1,2 @@
+from .model import Model, DEFAULT_HYPERS  # noqa: F401
+from .train import train  # noqa: F401
diff --git a/src/metatensor/models/experimental/pet_jax/hotfix_torch.py b/src/metatensor/models/experimental/pet_jax/hotfix_torch.py
new file mode 100644
index 000000000..d7ad212e9
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/hotfix_torch.py
@@ -0,0 +1,31 @@
+# This fixes a small bug in the attention implementation in torch that
+# prevents it from being TorchScript-compatible.
+
+import os
+
+import torch
+
+
+file = os.path.join(os.path.dirname(torch.__file__), "nn", "modules", "activation.py")
+
+with open(file, "r") as f:
+    lines = f.readlines()
+    for i, line in enumerate(lines):
+        if (
+            "elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:"  # noqa: E501
+            in line
+        ):
+            lines[i] = line.replace(
+                "elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:",  # noqa: E501
+                "elif self.in_proj_bias is not None:\n"
+                "            if query.dtype != self.in_proj_bias.dtype:",
+            )
+            lines[i + 1] = (
+                '                why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) do not match"\n'  # noqa: E501
+            )
+
+with open(file, "w") as f:
+    f.writelines(lines)
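+
+# Note: re-running this script is a no-op, since the pattern being replaced is
+# no longer present after the first rewrite.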
diff --git a/src/metatensor/models/experimental/pet_jax/model.py b/src/metatensor/models/experimental/pet_jax/model.py
new file mode 100644
index 000000000..f74b2ddfc
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/model.py
@@ -0,0 +1,218 @@
+from typing import Dict, List, Optional
+
+import torch
+from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.atomistic import ModelOutput, NeighborsListOptions, System
+from omegaconf import OmegaConf
+
+from ... import ARCHITECTURE_CONFIG_PATH
+from .pet.pet_torch.corresponding_edges import get_corresponding_edges
+from .pet.pet_torch.encoder import Encoder
+from .pet.pet_torch.nef import edge_array_to_nef, get_nef_indices, nef_array_to_edges
+from .pet.pet_torch.radial_mask import get_radial_mask
+from .pet.pet_torch.structures import concatenate_structures
+from .pet.pet_torch.transformer import Transformer
+
+
+ARCHITECTURE_NAME = "experimental.pet_jax"
+DEFAULT_HYPERS = OmegaConf.to_container(
+    OmegaConf.load(ARCHITECTURE_CONFIG_PATH / f"{ARCHITECTURE_NAME}.yaml")
+)
+DEFAULT_MODEL_HYPERS = DEFAULT_HYPERS["model"]
+
+
+class Model(torch.nn.Module):
+
+    def __init__(self, capabilities, hypers, composition_weights):
+        super().__init__()
+        self.name = ARCHITECTURE_NAME
+        self.hypers = hypers
+
+        self.capabilities = capabilities
+
+        # Handle species
+        self.all_species = capabilities.species
+        n_species = len(self.all_species)
+        self.species_to_species_index = torch.full(
+            (max(self.all_species) + 1,),
+            -1,
+        )
+        for i, species in enumerate(self.all_species):
+            self.species_to_species_index[species] = i
+
+        self.encoder = Encoder(n_species, hypers["d_pet"])
+
+        self.transformer = Transformer(
+            hypers["d_pet"],
+            4 * hypers["d_pet"],
+            hypers["num_heads"],
+            hypers["num_attention_layers"],
+            hypers["mlp_dropout_rate"],
+            hypers["attention_dropout_rate"],
+        )
+        self.readout = torch.nn.Linear(hypers["d_pet"], 1, bias=False)
+
+        self.num_mp_layers = hypers["num_gnn_layers"] - 1
+        gnn_contractions = []
+        gnn_transformers = []
+        for _ in range(self.num_mp_layers):
+            gnn_contractions.append(
+                torch.nn.Linear(2 * hypers["d_pet"], hypers["d_pet"], bias=False)
+            )
+            gnn_transformers.append(
+                Transformer(
+                    hypers["d_pet"],
+                    4 * hypers["d_pet"],
+                    hypers["num_heads"],
+                    hypers["num_attention_layers"],
+                    hypers["mlp_dropout_rate"],
+                    hypers["attention_dropout_rate"],
+                )
+            )
+        self.gnn_contractions = torch.nn.ModuleList(gnn_contractions)
+        self.gnn_transformers = torch.nn.ModuleList(gnn_transformers)
+
+        self.register_buffer("composition_weights", composition_weights)
+
+    def forward(
+        self,
+        systems: List[System],
+        outputs: Dict[str, ModelOutput],
+        selected_atoms: Optional[Labels] = None,
+    ) -> Dict[str, TensorMap]:
+        # Checks on systems (species) and outputs are done in the
+        # MetatensorAtomisticModel wrapper
+
+        if selected_atoms is not None:
+            raise NotImplementedError(
+                "The PET model does not support domain decomposition."
+            )
+
+        n_structures = len(systems)
+        positions, centers, neighbors, species, segment_indices, edge_vectors = (
+            concatenate_structures(systems)
+        )
+        max_edges_per_node = int(torch.max(torch.bincount(centers)))
+
+        # Convert to NEF:
+        nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+            centers, len(positions), max_edges_per_node
+        )
+
+        # Get radial mask
+        r = torch.sqrt(torch.sum(edge_vectors**2, dim=-1))
+        # NOTE: hardcoded to the default cutoff; these should follow hypers["cutoff"]
+        radial_mask = get_radial_mask(r, 5.0, 3.0)
+
+        # Element indices
+        element_indices_nodes = self.species_to_species_index[species]
+        element_indices_centers = element_indices_nodes[centers]
+        element_indices_neighbors = element_indices_nodes[neighbors]
+
+        # Send everything to NEF:
+        edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+        radial_mask = edge_array_to_nef(
+            radial_mask, nef_indices, nef_mask, fill_value=0.0
+        )
+        element_indices_centers = edge_array_to_nef(
+            element_indices_centers, nef_indices
+        )
+        element_indices_neighbors = edge_array_to_nef(
+            element_indices_neighbors, nef_indices
+        )
+
+        features = {
+            "cartesian": edge_vectors,
+            "center": element_indices_centers,
+            "neighbor": element_indices_neighbors,
+        }
+
+        # Encode
+        features = self.encoder(features)
+
+        # Transformer
+        features = self.transformer(features, radial_mask)
+
+        # GNN
+        if self.num_mp_layers > 0:
+            corresponding_edges = get_corresponding_edges(
+                torch.stack([centers, neighbors], dim=-1)
+            )
+            for contraction, transformer in zip(
+                self.gnn_contractions, self.gnn_transformers
+            ):
+                new_features = nef_array_to_edges(
+                    features, centers, nef_to_edges_neighbor
+                )
+                corresponding_new_features = new_features[corresponding_edges]
+                new_features = torch.concatenate(
+                    [new_features, corresponding_new_features], dim=-1
+                )
+                new_features = contraction(new_features)
+                new_features = edge_array_to_nef(new_features, nef_indices)
+                new_features = transformer(new_features, radial_mask)
+                features = features + new_features
+
+        # Readout
+        edge_energies = self.readout(features)
+        edge_energies = edge_energies * radial_mask[:, :, None]
+
+        # Sum over edges
+        atomic_energies = torch.sum(
+            edge_energies, dim=(1, 2)
+        )  # also eliminate singleton dimension 2
+
+        # Sum over centers
+        structure_energies = torch.zeros(
+            n_structures, dtype=atomic_energies.dtype, device=atomic_energies.device
+        )
+        structure_energies.index_add_(0, segment_indices, atomic_energies)
+
+        # TODO: use utils? use composition calculator?
+        composition = torch.zeros(
+            (n_structures, len(self.all_species)), device=atomic_energies.device
+        )
+        for number in self.all_species:
+            where_number = (species == number).to(composition.dtype)
+            composition[:, self.species_to_species_index[number]].index_add_(
+                0, segment_indices, where_number
+            )
+
+        structure_energies = structure_energies + composition @ self.composition_weights
+
+        return {
+            list(outputs.keys())[0]: TensorMap(
+                keys=Labels(
+                    names=["_"],
+                    values=torch.tensor([[0]], device=structure_energies.device),
+                ),
+                blocks=[
+                    TensorBlock(
+                        values=structure_energies.unsqueeze(1),
+                        samples=Labels(
+                            names=["structure"],
+                            values=torch.arange(
+                                n_structures, device=structure_energies.device
+                            ).unsqueeze(1),
+                        ),
+                        components=[],
+                        properties=Labels(
+                            names=["_"],
+                            values=torch.tensor(
+                                [[0]], device=structure_energies.device
+                            ),
+                        ),
+                    )
+                ],
+            )
+        }
+
+    def requested_neighbors_lists(
+        self,
+    ) -> List[NeighborsListOptions]:
+        return [
+            NeighborsListOptions(
+                model_cutoff=self.hypers["cutoff"],
+                full_list=True,
+            )
+        ]
diff --git a/src/metatensor/models/experimental/pet_jax/pet/__init__.py b/src/metatensor/models/experimental/pet_jax/pet/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/metatensor/models/experimental/pet_jax/pet/attention.py b/src/metatensor/models/experimental/pet_jax/pet/attention.py
new file mode 100644
index 000000000..d9ea1393e
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/attention.py
@@ -0,0 +1,70 @@
+from typing import Optional
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+
+
+class AttentionBlock(eqx.Module):
+    """A single transformer attention block."""
+
+    attention: eqx.nn.MultiheadAttention
+    layernorm: eqx.nn.LayerNorm
+    dropout: eqx.nn.Dropout
+    num_heads: int
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout_rate: float,
+        attention_dropout_rate: float,
+        key: jax.random.PRNGKey,
+    ):
+        self.num_heads = num_heads
+        self.attention = eqx.nn.MultiheadAttention(
+            num_heads=num_heads,
+            query_size=hidden_size,
+            dropout_p=attention_dropout_rate,
+            key=key,
+        )
+
+        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
+        self.dropout = eqx.nn.Dropout(dropout_rate)
+
+    def __call__(
+        self,
+        inputs: jnp.ndarray,  # seq_len hidden_size
+        radial_mask: jnp.ndarray,  # seq_len
+        enable_dropout: bool = False,
+        key: Optional[jax.random.PRNGKey] = None,
+    ) -> jnp.ndarray:  # seq_len hidden_size
+
+        attention_key, dropout_key = (
+            (None, None) if key is None else jax.random.split(key)
+        )
+
+        # Apply radial mask
+        inputs = inputs * radial_mask[:, None]
+
+        # Pre-layer normalization
+        normed_inputs = jax.vmap(self.layernorm)(inputs)
+
+        # Attention
+        attention_output = self.attention(
+            query=normed_inputs,
+            key_=normed_inputs,
+            value=normed_inputs,
+            inference=not enable_dropout,
+            key=attention_key,
+        )
+
+        # Apply dropout
+        output = self.dropout(
+            attention_output, inference=not enable_dropout, key=dropout_key
+        )
+
+        # Residual connection
+        output += inputs
+
+        return output
diff --git a/src/metatensor/models/experimental/pet_jax/pet/encoder.py b/src/metatensor/models/experimental/pet_jax/pet/encoder.py
new file mode 100644
index 000000000..3ca2dc01e
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/encoder.py
@@ -0,0 +1,62 @@
+from typing import Dict
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+
+
+class Encoder(eqx.Module):
+
+    cartesian_encoder: eqx.nn.Linear
+    center_encoder: eqx.nn.Embedding
+    neighbor_encoder: eqx.nn.Embedding
+    compressor: eqx.nn.Linear
+
+    def __init__(
+        self,
+        n_species: int,
+        hidden_size: int,
+        key: jax.random.PRNGKey,
+    ):
+        # use four independent keys to avoid re-using `key` for two layers
+        key1, key2, key3, key4 = jax.random.split(key, num=4)
+
+        self.cartesian_encoder = eqx.nn.Linear(
+            in_features=3, out_features=hidden_size, key=key1
+        )
+        self.center_encoder = eqx.nn.Embedding(
+            num_embeddings=n_species, embedding_size=hidden_size, key=key2
+        )
+        self.neighbor_encoder = eqx.nn.Embedding(
+            num_embeddings=n_species, embedding_size=hidden_size, key=key3
+        )
+        self.compressor = eqx.nn.Linear(
+            in_features=3 * hidden_size, out_features=hidden_size, key=key4
+        )
+
+    def __call__(
+        self,
+        features: Dict[str, jnp.ndarray],
+    ):
+        # Encode cartesian coordinates
+        cartesian_features = jax.vmap(jax.vmap(self.cartesian_encoder))(
+            features["cartesian"]
+        )
+
+        # Encode centers
+        center_features = jax.vmap(jax.vmap(self.center_encoder))(features["center"])
+
+        # Encode neighbors
+        neighbor_features = jax.vmap(jax.vmap(self.neighbor_encoder))(
+            features["neighbor"]
+        )
+
+        # Concatenate
+        encoded_features = jnp.concatenate(
+            [cartesian_features, center_features, neighbor_features], axis=-1
+        )
+
+        # Compress
+        compressed_features = jax.vmap(jax.vmap(self.compressor))(encoded_features)
+
+        return compressed_features
diff --git a/src/metatensor/models/experimental/pet_jax/pet/feedforward.py b/src/metatensor/models/experimental/pet_jax/pet/feedforward.py
new file mode 100644
index 000000000..51bd4f4ed
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/feedforward.py
@@ -0,0 +1,57 @@
+from typing import Optional
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+
+
+class FeedForwardBlock(eqx.Module):
+    """A single transformer feed forward block."""
+
+    mlp: eqx.nn.Linear
+    output: eqx.nn.Linear
+    layernorm: eqx.nn.LayerNorm
+    dropout: eqx.nn.Dropout
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        dropout_rate: float,
+        key: jax.random.PRNGKey,
+    ):
+        mlp_key, output_key = jax.random.split(key)
+        self.mlp = eqx.nn.Linear(
+            in_features=hidden_size, out_features=intermediate_size, key=mlp_key
+        )
+        self.output = eqx.nn.Linear(
+            in_features=intermediate_size, out_features=hidden_size, key=output_key
+        )
+
+        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
+        self.dropout = eqx.nn.Dropout(dropout_rate)
+
+    def __call__(
+        self,
+        inputs: jnp.ndarray,  # hidden_size
+        enable_dropout: bool = False,
+        key: Optional[jax.random.PRNGKey] = None,
+    ) -> jnp.ndarray:  # hidden_size
+
+        # Pre-layer normalization
+        normed_inputs = self.layernorm(inputs)
+
+        # Feed-forward
+        hidden = self.mlp(normed_inputs)
+        hidden = jax.nn.gelu(hidden)
+
+        # Project back to input size
+        output = self.output(hidden)
+
+        # Apply dropout
+        output = self.dropout(output, inference=not enable_dropout, key=key)
+
+        # Residual connection
+        output += inputs
+
+        return output
diff --git a/src/metatensor/models/experimental/pet_jax/pet/models.py b/src/metatensor/models/experimental/pet_jax/pet/models.py
new file mode 100644
index 000000000..e50264e33
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/models.py
@@ -0,0 +1,241 @@
+from typing import List
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+
+from .encoder import Encoder
+from .radial_mask import get_radial_mask
+from .transformer import Transformer
+from .utils.corresponding_edges import get_corresponding_edges
+from .utils.jax_batch import JAXBatch
+from .utils.nef import edge_array_to_nef, get_nef_indices, nef_array_to_edges
+
+
+class PET(eqx.Module):
+
+    # note: these are registered in the PyTree in this order
+    all_species: List[int]
+    species_to_species_index: jnp.ndarray
+    encoder: Encoder
+    transformer: Transformer
+    readout: eqx.nn.Linear
+    gnn_contractions: List[eqx.nn.Linear]
+    gnn_transformers: List[Transformer]
+    composition_weights: jnp.ndarray
+
+    def __init__(self, all_species, hypers, composition_weights, key):
+        n_species = len(all_species)
+
+        # Handle species
+        self.all_species = all_species
+        self.species_to_species_index = jnp.full(
+            (jnp.max(all_species) + 1,),
+            -1,
+            dtype=(jnp.int64 if jax.config.jax_enable_x64 else jnp.int32),
+        )
+        for i, species in enumerate(all_species):
+            self.species_to_species_index = self.species_to_species_index.at[
+                species
+            ].set(i)
+
+        key_enc, key_attn, key_readout, key_gnns = jax.random.split(key, 4)
+        self.encoder = Encoder(n_species, hypers["d_pet"], key_enc)
+        self.transformer = Transformer(
+            hypers["d_pet"],
+            4 * hypers["d_pet"],
+            hypers["num_heads"],
+            hypers["num_attention_layers"],
+            hypers["mlp_dropout_rate"],
+            hypers["attention_dropout_rate"],
+            key_attn,
+        )
+        self.readout = eqx.nn.Linear(
+            hypers["d_pet"], 1, use_bias=False, key=key_readout
+        )
+        num_mp_layers = hypers["num_gnn_layers"] - 1
+        gnn_keys = jax.random.split(key_gnns, num_mp_layers)
+        self.gnn_transformers = []
+        self.gnn_contractions = []
+        for i in range(num_mp_layers):
+            contraction_key, transformer_key = jax.random.split(gnn_keys[i])
+            self.gnn_contractions.append(
+                eqx.nn.Linear(
+                    2 * hypers["d_pet"],
+                    hypers["d_pet"],
+                    use_bias=False,
+                    key=contraction_key,
+                )
+            )
+            self.gnn_transformers.append(
+                Transformer(
+                    hypers["d_pet"],
+                    4 * hypers["d_pet"],
+                    hypers["num_heads"],
+                    hypers["num_attention_layers"],
+                    hypers["mlp_dropout_rate"],
+                    hypers["attention_dropout_rate"],
+                    transformer_key,
+                )
+            )
+
+        self.composition_weights = composition_weights
+
+    def __call__(self, structures, max_edges_per_node, is_training, key=None):
+
+        n_structures = len(structures.n_nodes)
+
+        # Convert to NEF:
+        nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
+            structures.centers, len(structures.positions), max_edges_per_node
+        )
+
+        segment_indices = jnp.repeat(
+            jnp.arange(n_structures),
+            structures.n_nodes,
+            total_repeat_length=len(structures.positions),
+        )
+
+        # get edge vectors:
+        edge_vectors = (
+            structures.positions[structures.neighbors]
+            - structures.positions[structures.centers]
+            + jnp.einsum(
+                "ia, iab -> ib",
+                structures.cell_shifts,
+                structures.cells[segment_indices[structures.centers]],
+            )
+        )
+
+        # Get radial mask
+        r = jnp.sqrt(jnp.sum(edge_vectors**2, axis=-1))
+        # NOTE: hardcoded to the default cutoff; these should follow hypers["cutoff"]
+        radial_mask = jax.vmap(get_radial_mask, in_axes=(0, None, None))(r, 5.0, 3.0)
+
+        # Element indices
+        element_indices_nodes = self.species_to_species_index[structures.numbers]
+        element_indices_centers = element_indices_nodes[structures.centers]
+        element_indices_neighbors = element_indices_nodes[structures.neighbors]
+
+        # Send everything to NEF:
+        edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
+        radial_mask = edge_array_to_nef(radial_mask, nef_indices, nef_mask, 0.0)
+        element_indices_centers = edge_array_to_nef(
+            element_indices_centers, nef_indices
+        )
+        element_indices_neighbors = edge_array_to_nef(
+            element_indices_neighbors, nef_indices
+        )
+
+        features = {
+            "cartesian": edge_vectors,
+            "center": element_indices_centers,
+            "neighbor": element_indices_neighbors,
+        }
+
+        # Encode
+        features = self.encoder(features)
+
+        # Transformer
+        features = jax.vmap(self.transformer, in_axes=(0, None, 0, None))(
+            features, is_training, radial_mask, key
+        )
+
+        # GNN
+        num_mp_layers = len(self.gnn_transformers)
+        if num_mp_layers > 0:
+            corresponding_edges = get_corresponding_edges(
+                jnp.stack([structures.centers, structures.neighbors], axis=-1)
+            )
+            for i in range(num_mp_layers):
+                new_features = nef_array_to_edges(
+                    features, structures.centers, nef_to_edges_neighbor
+                )
+                corresponding_new_features = new_features[corresponding_edges]
+                new_features = jax.vmap(self.gnn_contractions[i])(
+                    jnp.concatenate([new_features, corresponding_new_features], axis=-1)
+                )
+                new_features = edge_array_to_nef(new_features, nef_indices)
+                new_features = jax.vmap(
+                    self.gnn_transformers[i], in_axes=(0, None, 0, None)
+                )(new_features, is_training, radial_mask, key)
+                features = features + new_features
+
+        # Readout
+        edge_energies = jax.vmap(jax.vmap(self.readout))(features)
+        edge_energies = edge_energies * radial_mask[:, :, None]
+
+        # Sum over edges
+        atomic_energies = jnp.sum(
+            edge_energies, axis=(1, 2)
+        )  # also eliminate singleton dimension 2
+
+        # Sum over centers
+        structure_energies = jax.ops.segment_sum(
+            atomic_energies,
+            segment_indices,
+            num_segments=n_structures,
+            indices_are_sorted=True,
+        )
+
+        # Add composition weights
+        composition = jnp.empty((n_structures, len(self.all_species)))
+        for number in self.all_species:
+            where_number = (structures.numbers == number).astype(composition.dtype)
+            composition = composition.at[:, self.species_to_species_index[number]].set(
+                jax.ops.segment_sum(
+                    where_number,
+                    segment_indices,
+                    num_segments=n_structures,
+                )
+            )
+
+        # composition weights are not trainable
+        structure_energies = structure_energies + composition @ jax.lax.stop_gradient(
+            self.composition_weights
+        )
+
+        return {"energies": structure_energies}
+
+
+@eqx.filter_grad
+def predict_forces(
+    positions: jax.Array,
+    model: eqx.Module,
+    structures: JAXBatch,
+    max_edges_per_node,
+    is_training,
+    key,
+):
+    structures = structures._replace(positions=positions)
+    return jnp.sum(model(structures, max_edges_per_node, is_training, key)["energies"])
+
+
+class PET_energy_force(eqx.Module):
+
+    pet: PET
+
+    def __init__(self, all_species, hypers, composition_weights, key):
+        self.pet = PET(all_species, hypers, composition_weights, key)
+
+    def __call__(self, structures, max_edges_per_node, is_training, key=None):
+        energies = self.pet(structures, max_edges_per_node, is_training, key)[
+            "energies"
+        ]
+        minus_forces = predict_forces(
+            structures.positions,
+            self.pet,
+            structures,
+            max_edges_per_node,
+            is_training,
+            key,
+        )
+        return {"energies": energies, "forces": -minus_forces}
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/__init__.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/attention.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/attention.py
new file mode 100644
index 000000000..b23f5957a
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/attention.py
@@ -0,0 +1,52 @@
+import torch
+
+
+class AttentionBlock(torch.nn.Module):
+    """A single transformer attention block."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout_rate: float,
+        attention_dropout_rate: float,
+    ):
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.attention = torch.nn.MultiheadAttention(
+            hidden_size,
+            num_heads,
+            dropout=attention_dropout_rate,
+            bias=False,
+            batch_first=True,
+        )
+        self.layernorm = torch.nn.LayerNorm(normalized_shape=hidden_size)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def forward(
+        self,
+        inputs: torch.Tensor,  # seq_len hidden_size
+        radial_mask: torch.Tensor,  # seq_len
+    ) -> torch.Tensor:  # seq_len hidden_size
+
+        # Apply radial mask
+        inputs = inputs * radial_mask[:, :, None]
+
+        # Pre-layer normalization
+        normed_inputs = self.layernorm(inputs)
+
+        # Attention
+        attention_output, _ = self.attention(
+            query=normed_inputs,
+            key=normed_inputs,
+            value=normed_inputs,
+        )
+
+        # Apply dropout
+        output = self.dropout(attention_output)
+
+        # Residual connection
+        output += inputs
+
+        return output
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/corresponding_edges.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/corresponding_edges.py
new file mode 100644
index 000000000..52e0cf3a4
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/corresponding_edges.py
@@ -0,0 +1,12 @@
+import torch
+
+
+def get_corresponding_edges(array):
+    n_edges = len(array)
+    array_inversed = array.flip(1)
+    inverse_indices = torch.empty((n_edges,), dtype=torch.long)
+    for i in range(n_edges):
+        inverse_indices[i] = torch.nonzero(
+            torch.all(array_inversed == array[i], dim=1)
+        )[0][0]
+    return inverse_indices
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/encoder.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/encoder.py
new file mode 100644
index 000000000..7a0d19c3c
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/encoder.py
@@ -0,0 +1,49 @@
+from typing import Dict
+
+import torch
+
+
+class Encoder(torch.nn.Module):
+
+    def __init__(
+        self,
+        n_species: int,
+        hidden_size: int,
+    ):
+        super().__init__()
+
+        self.cartesian_encoder = torch.nn.Linear(
+            in_features=3, out_features=hidden_size
+        )
+        self.center_encoder = torch.nn.Embedding(
+            num_embeddings=n_species, embedding_dim=hidden_size
+        )
+        self.neighbor_encoder = torch.nn.Embedding(
+            num_embeddings=n_species, embedding_dim=hidden_size
+        )
+        self.compressor = torch.nn.Linear(
+            in_features=3 * hidden_size, out_features=hidden_size
+        )
+
+    def forward(
+        self,
+        features: Dict[str, torch.Tensor],
+    ):
+        # Encode cartesian coordinates
+        cartesian_features = self.cartesian_encoder(features["cartesian"])
+
+        # Encode centers
+        center_features = self.center_encoder(features["center"])
+
+        # Encode neighbors
+        neighbor_features = self.neighbor_encoder(features["neighbor"])
+
+        # Concatenate
+        encoded_features = torch.concatenate(
+            [cartesian_features, center_features, neighbor_features], dim=-1
+        )
+
+        # Compress
+        compressed_features = self.compressor(encoded_features)
+
+        return compressed_features
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/feedforward.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/feedforward.py
new file mode 100644
index 000000000..f3bb975d2
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/feedforward.py
@@ -0,0 +1,46 @@
+import torch
+
+
+class FeedForwardBlock(torch.nn.Module):
+    """A single transformer feed forward block."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        dropout_rate: float,
+    ):
+        super().__init__()
+
+        self.mlp = torch.nn.Linear(
+            in_features=hidden_size, out_features=intermediate_size
+        )
+        self.output = torch.nn.Linear(
+            in_features=intermediate_size, out_features=hidden_size
+        )
+
+        self.layernorm = torch.nn.LayerNorm(normalized_shape=hidden_size)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def forward(
+        self,
+        inputs: torch.Tensor,  # hidden_size
+    ) -> torch.Tensor:  # hidden_size
+
+        # Pre-layer normalization
+        normed_inputs = self.layernorm(inputs)
+
+        # Feed-forward
+        hidden = self.mlp(normed_inputs)
+        hidden = torch.nn.functional.gelu(hidden)
+
+        # Project back to input size
+        output = self.output(hidden)
+
+        # Apply dropout
+        output = self.dropout(output)
+
+        # Residual connection
+        output += inputs
+
+        return output
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/nef.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/nef.py
new file mode 100644
index 000000000..1a8511b2f
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/nef.py
@@ -0,0 +1,56 @@
+from typing import Optional
+
+import torch
+
+
+def get_nef_indices(centers, n_nodes: int, n_edges_per_node: int):
+    """Transform the center indices into NEF indices."""
+
+    n_edges = len(centers)
+    edges_to_nef = torch.zeros((n_nodes, n_edges_per_node), dtype=torch.long)
+    nef_to_edges_neighbor = torch.empty((n_edges,), dtype=torch.long)
+    node_counter = torch.zeros((n_nodes,), dtype=torch.long)
+    nef_mask = torch.full((n_nodes, n_edges_per_node), 0, dtype=torch.bool)
+
+    for i in range(n_edges):
+        center = centers[i]
+        edges_to_nef[center, node_counter[center]] = i
+        nef_mask[center, node_counter[center]] = True
+        nef_to_edges_neighbor[i] = node_counter[center]
+        node_counter[center] += 1
+
+    return (edges_to_nef, nef_to_edges_neighbor, nef_mask)
+
+
+def edge_array_to_nef(
+    edge_array,
+    nef_indices,
+    mask: Optional[torch.Tensor] = None,
+    fill_value: float = 0.0,
+):
+    """Converts an edge array to a NEF array."""
+
+    if mask is None:
+        return edge_array[nef_indices]
+    else:
+        return torch.where(
+            mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)),
+            edge_array[nef_indices],
+            fill_value,
+        )
+
+
+def nef_array_to_edges(nef_array, centers, nef_to_edges_neighbor):
+    """Converts a NEF array to an edge array."""
+
+    return nef_array[centers, nef_to_edges_neighbor]
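+
+# Illustrative example (hypothetical values): for three nodes receiving the
+# edges below, `get_nef_indices` lays the edges out per node, and the two
+# helpers convert arrays between the edge layout and the NEF layout:
+#
+#     centers = torch.tensor([0, 0, 1, 2, 2, 2])
+#     nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(centers, 3, 3)
+#     nef = edge_array_to_nef(centers, nef_indices, nef_mask, fill_value=-1)
+#     # nef == [[0, 0, -1], [1, -1, -1], [2, 2, 2]]
+#     assert torch.all(nef_array_to_edges(nef, centers, nef_to_edges_neighbor) == centers)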
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/radial_mask.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/radial_mask.py
new file mode 100644
index 000000000..3346e6ec5
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/radial_mask.py
@@ -0,0 +1,10 @@
+import torch
+
+
+def get_radial_mask(r, r_cut: float, r_transition: float):
+    # All radii are already guaranteed to be smaller than r_cut
+    return torch.where(
+        r < r_transition,
+        torch.ones_like(r),
+        0.5 * (torch.cos(torch.pi * (r - r_transition) / (r_cut - r_transition)) + 1.0),
+    )
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/structures.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/structures.py
new file mode 100644
index 000000000..97eaf4aa8
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/structures.py
@@ -0,0 +1,40 @@
+from typing import List
+
+import torch
+from metatensor.torch.atomistic import System
+
+
+def concatenate_structures(systems: List[System]):
+
+    positions = []
+    centers = []
+    neighbors = []
+    species = []
+    segment_indices = []
+    edge_vectors = []
+    node_counter = 0
+
+    for i, system in enumerate(systems):
+        positions.append(system.positions)
+        species.append(system.species)
+        segment_indices.append(torch.full((len(system.positions),), i))
+
+        assert len(system.known_neighbors_lists()) == 1
+        neighbor_list = system.get_neighbors_list(system.known_neighbors_lists()[0])
+        nl_values = neighbor_list.samples.values
+        edge_vectors_system = neighbor_list.values.reshape(-1, 3)
+
+        centers.append(nl_values[:, 0] + node_counter)
+        neighbors.append(nl_values[:, 1] + node_counter)
+        edge_vectors.append(edge_vectors_system)
+
+        node_counter += len(system.positions)
+
+    positions = torch.cat(positions)
+    centers = torch.cat(centers)
+    neighbors = torch.cat(neighbors)
+    species = torch.cat(species)
+    segment_indices = torch.cat(segment_indices)
+    edge_vectors = torch.cat(edge_vectors)
+
+    return positions, centers, neighbors, species, segment_indices, edge_vectors
diff --git a/src/metatensor/models/experimental/pet_jax/pet/pet_torch/transformer.py b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/transformer.py
new file mode 100644
index 000000000..58f5c2fee
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/pet_torch/transformer.py
@@ -0,0 +1,80 @@
+import torch
+
+from .attention import AttentionBlock
+from .feedforward import FeedForwardBlock
+
+
+class TransformerLayer(torch.nn.Module):
+    """A single transformer layer."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        num_heads: int,
+        dropout_rate: float,
+        attention_dropout_rate: float,
+    ):
+        super().__init__()
+
+        self.attention_block = AttentionBlock(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            dropout_rate=dropout_rate,
+            attention_dropout_rate=attention_dropout_rate,
+        )
+        self.ff_block = FeedForwardBlock(
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            dropout_rate=dropout_rate,
+        )
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        radial_mask: torch.Tensor,
+    ) -> torch.Tensor:
+
+        attention_output = self.attention_block(inputs, radial_mask)
+        output = self.ff_block(attention_output)
+
+        return output
+
+
+class Transformer(torch.nn.Module):
+    """A transformer model."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        num_heads: int,
+        num_layers: int,
+        dropout_rate: float,
+        attention_dropout_rate: float,
+    ):
+        super().__init__()
+
+        self.layers = torch.nn.ModuleList(
+            [
+                TransformerLayer(
+                    hidden_size=hidden_size,
+                    intermediate_size=intermediate_size,
+                    num_heads=num_heads,
+                    dropout_rate=dropout_rate,
+                    attention_dropout_rate=attention_dropout_rate,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        inputs,
+        radial_mask,
+    ):
+
+        x = inputs
+        for layer in self.layers:
+            x = layer(x, radial_mask)
+        return x
diff --git a/src/metatensor/models/experimental/pet_jax/pet/radial_mask.py b/src/metatensor/models/experimental/pet_jax/pet/radial_mask.py
new file mode 100644
index 000000000..9cd5a4ba1
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/radial_mask.py
@@ -0,0 +1,10 @@
+import jax.numpy as jnp
+
+
+def get_radial_mask(r, r_cut, r_transition):
+    # All radii are already guaranteed to be smaller than r_cut
+    return jnp.where(
+        r < r_transition,
+        jnp.ones_like(r),
+        0.5 * (jnp.cos(jnp.pi * (r - r_transition) / (r_cut - r_transition)) + 1.0),
+    )
diff --git a/src/metatensor/models/experimental/pet_jax/pet/transformer.py b/src/metatensor/models/experimental/pet_jax/pet/transformer.py
new file mode 100644
index 000000000..c38e6c327
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/transformer.py
@@ -0,0 +1,106 @@
+from typing import List, Optional
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+
+from .attention import AttentionBlock
+from .feedforward import FeedForwardBlock
+
+
+class TransformerLayer(eqx.Module):
+    """A single transformer layer."""
+
+    attention_block: AttentionBlock
+    ff_block: FeedForwardBlock
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        num_heads: int,
+        dropout_rate: float,
+        attention_dropout_rate: float,
+        key: jax.random.PRNGKey,
+    ):
+        attention_key, ff_key = jax.random.split(key)
+
+        self.attention_block = AttentionBlock(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            dropout_rate=dropout_rate,
+            attention_dropout_rate=attention_dropout_rate,
+            key=attention_key,
+        )
+        self.ff_block = FeedForwardBlock(
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            dropout_rate=dropout_rate,
+            key=ff_key,
+        )
+
+    def __call__(
+        self,
+        inputs: jnp.ndarray,  # seq_len hidden_size
+        radial_mask: jnp.ndarray,  # seq_len
+        enable_dropout: bool = False,
+        key: Optional[jax.random.PRNGKey] = None,
+    ) -> jnp.ndarray:  # seq_len hidden_size
+
+        attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
+        attention_output = self.attention_block(
+            inputs, radial_mask, enable_dropout=enable_dropout, key=attn_key
+        )
+        seq_len = inputs.shape[0]
+        ff_keys = None if ff_key is None else jax.random.split(ff_key, num=seq_len)
+        output = jax.vmap(self.ff_block, in_axes=(0, None, 0))(
+            attention_output, enable_dropout, ff_keys
+        )
+
+        return output
+
+
+class Transformer(eqx.Module):
+    """A transformer model."""
+
+    layers: List[eqx.Module]
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        num_heads: int,
+        num_layers: int,
+        dropout_rate: float,
+        attention_dropout_rate: float,
+        key: jax.random.PRNGKey,
+    ):
+
+        keys = jax.random.split(key, num=num_layers)
+        self.layers = [
+            TransformerLayer(
+                hidden_size=hidden_size,
+                intermediate_size=intermediate_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                attention_dropout_rate=attention_dropout_rate,
+                key=layer_key,
+            )
+            for layer_key in keys
+        ]
+
+    def __call__(
+        self,
+        inputs: jnp.ndarray,  # seq_len hidden_size
+        enable_dropout: bool,
+        radial_mask: jnp.ndarray,  # seq_len
+        key: Optional[jax.random.PRNGKey] = None,
+    ) -> jnp.ndarray:  # seq_len hidden_size
+
+        x = inputs
+
+        for layer in self.layers:
+            current_key, key = (None, None) if key is None else jax.random.split(key)
+            x = layer(x, radial_mask, enable_dropout=enable_dropout, key=current_key)
+
+        return x
diff --git a/src/metatensor/models/experimental/pet_jax/pet/utils/__init__.py b/src/metatensor/models/experimental/pet_jax/pet/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/metatensor/models/experimental/pet_jax/pet/utils/augmentation.py b/src/metatensor/models/experimental/pet_jax/pet/utils/augmentation.py
new file mode 100644
index 000000000..c48739e49
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/utils/augmentation.py
@@ -0,0 +1,36 @@
+import random
+
+from scipy.spatial.transform import Rotation
+
+from .mts_to_structure import Structure
+
+
+def apply_random_augmentation(structure: Structure):
+    """
+    Apply a random augmentation to a ``Structure``.
+
+    :param structure: The structure to augment.
+
+    :return: The augmented structure.
+    """
+
+    transformation = get_random_augmentation()
+    return Structure(
+        positions=structure.positions @ transformation.T,
+        cell=structure.cell @ transformation.T,
+        numbers=structure.numbers,
+        centers=structure.centers,
+        neighbors=structure.neighbors,
+        cell_shifts=structure.cell_shifts,
+        energy=structure.energy,
+        forces=structure.forces @ transformation.T,
+    )
+
+
+def get_random_augmentation():
+
+    transformation = Rotation.random().as_matrix()
+    invert = random.choice([True, False])
+    if invert:
+        transformation *= -1
+    return transformation
diff --git a/src/metatensor/models/experimental/pet_jax/pet/utils/corresponding_edges.py b/src/metatensor/models/experimental/pet_jax/pet/utils/corresponding_edges.py
new file mode 100644
index 000000000..32ed12c0c
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/pet/utils/corresponding_edges.py
@@ -0,0 +1,29 @@
+import jax
+import jax.numpy as jnp
+
+
+def loop_body(i, carry):
+    array, array_inversed, inverse_indices = carry
+    inverse_indices = inverse_indices.at[i].set(
+        jnp.nonzero(jnp.all(array_inversed == array[i], axis=1), size=1)[0][0]
+    )
+    return array, array_inversed, inverse_indices
+
+
+def get_corresponding_edges(array):
+    n_edges = len(array)
+    int_dtype = jnp.int64 if jax.config.jax_enable_x64 else jnp.int32
+    array_inversed = array[:, ::-1]
+    return jax.lax.fori_loop(
+        0,
+        n_edges,
+        loop_body,
+        (array, array_inversed, jnp.empty((n_edges,), dtype=int_dtype)),
+    )[2]
+ """ + n_nodes = jnp.array([len(structure.positions) for structure in structures]) + + # concatenate after shifting + shifted_centers = [] + shifted_neighbors = [] + shift = 0 + for structure in structures: + shifted_centers.append(structure.centers + shift) + shifted_neighbors.append(structure.neighbors + shift) + shift += len(structure.positions) + centers = jnp.concatenate(shifted_centers) + neighbors = jnp.concatenate(shifted_neighbors) + + return JAXBatch( + positions=jnp.concatenate([structure.positions for structure in structures]), + cells=jnp.stack([structure.cell for structure in structures]), + numbers=jnp.concatenate([structure.numbers for structure in structures]), + centers=centers, + neighbors=neighbors, + cell_shifts=jnp.concatenate( + [structure.cell_shifts for structure in structures] + ), + n_nodes=n_nodes, + energies=jnp.stack([structure.energy for structure in structures]), + forces=jnp.concatenate([structure.forces for structure in structures]), + ) + + +def calculate_padding_sizes(batch: JAXBatch): + """Calculate the padding sizes for a batch. Works in powers of two. + + :param structures: A batch of structures. + + :return: A tuple with the padding sizes: nodes, edges, edges per node. + """ + n_nodes = batch.positions.shape[0] + n_edges = batch.neighbors.shape[0] + n_edges_per_node = jnp.bincount(batch.neighbors).max() + return ( + 2 ** int(jnp.ceil(jnp.log2(n_nodes + 1))), + 2 ** int(jnp.ceil(jnp.log2(n_edges))), + 2 ** int(jnp.ceil(jnp.log2(n_edges_per_node))), + ) + + +def pad_batch(batch: JAXBatch, n_nodes: int, n_edges: int): + """Pad a batch to the given sizes. + + :param batch: The batch to pad. + :param n_nodes: The number of nodes to pad to. + :param n_edges: The number of edges to pad to. + + :return: The padded batch. + """ + + # note: for node arrays, n_nodes - 1 is always + # a padding value (see calculate_padding_sizes above) + + return JAXBatch( + positions=jnp.pad( + batch.positions, ((0, n_nodes - len(batch.positions)), (0, 0)) + ), + cells=jnp.pad(batch.cells, ((0, n_nodes - len(batch.cells)), (0, 0), (0, 0))), + numbers=jnp.pad(batch.numbers, (0, n_nodes - len(batch.numbers))), + centers=jnp.pad( + batch.centers, + (0, n_edges - len(batch.centers)), + mode="constant", + constant_values=n_nodes - 1, + ), + neighbors=jnp.pad( + batch.neighbors, + (0, n_edges - len(batch.neighbors)), + mode="constant", + constant_values=n_nodes - 1, + ), + cell_shifts=jnp.pad( + batch.cell_shifts, ((0, n_edges - len(batch.cell_shifts)), (0, 0)) + ), + n_nodes=batch.n_nodes, + energies=jnp.pad(batch.energies, (0, len(batch.n_nodes) - len(batch.energies))), + forces=jnp.pad(batch.forces, (0, n_edges - len(batch.forces))), + ) diff --git a/src/metatensor/models/experimental/pet_jax/pet/utils/jax_structure.py b/src/metatensor/models/experimental/pet_jax/pet/utils/jax_structure.py new file mode 100644 index 000000000..2d5440854 --- /dev/null +++ b/src/metatensor/models/experimental/pet_jax/pet/utils/jax_structure.py @@ -0,0 +1,38 @@ +from collections import namedtuple + +import jax.numpy as jnp +import numpy as np + +from .mts_to_structure import Structure + + +JAXStructure = namedtuple( + "JAXStructure", + "positions, cell, numbers, centers, neighbors, cell_shifts, energy, forces", +) + + +def structure_to_jax(structure: Structure): + """Converts a Structure to a JAX dictionary. + + :param structure: The structure to convert. + + :return: The same named tuple, but with jnp arrays. 
+ """ + + if not np.all(structure.centers[1:] >= structure.centers[:-1]): + raise ValueError( + "centers array of the neighbor list is not sorted. " + "This is required for the JAX implementation." + ) + + return JAXStructure( + positions=jnp.array(structure.positions), + cell=jnp.array(structure.cell), + numbers=jnp.array(structure.numbers), + centers=jnp.array(structure.centers), + neighbors=jnp.array(structure.neighbors), + cell_shifts=jnp.array(structure.cell_shifts), + energy=jnp.array(structure.energy), + forces=jnp.array(structure.forces), + ) diff --git a/src/metatensor/models/experimental/pet_jax/pet/utils/mts_to_structure.py b/src/metatensor/models/experimental/pet_jax/pet/utils/mts_to_structure.py new file mode 100644 index 000000000..f355a8e5d --- /dev/null +++ b/src/metatensor/models/experimental/pet_jax/pet/utils/mts_to_structure.py @@ -0,0 +1,41 @@ +from collections import namedtuple + +import ase.neighborlist +import numpy as np +from metatensor.torch.atomistic import System + + +Structure = namedtuple( + "Structure", + "positions, cell, numbers, centers, neighbors, cell_shifts, energy, forces", +) + + +def mts_to_structure( + system: System, energy: float, forces: np.ndarray, cutoff: float +) -> Structure: + """Converts a `metatensor.torch.atomistic.System` to a `Structure`.""" + positions = system.positions.numpy() + numbers = system.species.numpy() + cell = system.cell[:].numpy() + + centers, neighbors, cell_shifts = ase.neighborlist.primitive_neighbor_list( + quantities="ijS", + positions=system.positions.numpy(), + cell=system.cell.numpy(), + pbc=[not np.all(system.cell.numpy() == 0)] * 3, + cutoff=cutoff, + self_interaction=False, + use_scaled_positions=False, + ) + + return Structure( + positions=positions, + cell=cell, + numbers=numbers, + centers=centers, + neighbors=neighbors, + cell_shifts=cell_shifts, + energy=energy, + forces=forces, + ) diff --git a/src/metatensor/models/experimental/pet_jax/pet/utils/nef.py b/src/metatensor/models/experimental/pet_jax/pet/utils/nef.py new file mode 100644 index 000000000..2096269c7 --- /dev/null +++ b/src/metatensor/models/experimental/pet_jax/pet/utils/nef.py @@ -0,0 +1,47 @@ +import jax +import jax.numpy as jnp + + +def loop_body(i, carry): + centers, edges_to_nef, nef_to_edges_neighbor, nef_mask, node_counter = carry + center = centers[i] + edges_to_nef = edges_to_nef.at[center, node_counter[center]].set(i) + nef_mask = nef_mask.at[center, node_counter[center]].set(True) + nef_to_edges_neighbor = nef_to_edges_neighbor.at[i].set(node_counter[center]) + node_counter = node_counter.at[center].add(1) + return centers, edges_to_nef, nef_to_edges_neighbor, nef_mask, node_counter + + +def get_nef_indices(centers, n_nodes: int, n_edges_per_node: int): + int_dtype = jnp.int64 if jax.config.jax_enable_x64 else jnp.int32 + n_edges = len(centers) + edges_to_nef = jnp.zeros((n_nodes, n_edges_per_node), dtype=int_dtype) + nef_to_edges_neighbor = jnp.empty((n_edges,), dtype=int_dtype) + node_counter = jnp.zeros((n_nodes,), dtype=int_dtype) + nef_mask = jnp.full((n_nodes, n_edges_per_node), False, dtype=bool) + # returns edges_to_nef, nef_to_edges_neighbor, nef_mask + # edges_to_nef can be used to index an edge array to get the corresponding nef array + # nef_to_edges_neighbor can be used to index the second dimension of a nef array + # to get the corresponding edge array (the first dimension is indexed by `centers`) + # nef_mask masks out the padding values in the nef array + return jax.lax.fori_loop( + 0, + n_edges, + loop_body, + 
(centers, edges_to_nef, nef_to_edges_neighbor, nef_mask, node_counter), + )[1:4] + + +def edge_array_to_nef(edge_array, nef_indices, mask=None, fill_value=0.0): + if mask is None: + return edge_array[nef_indices] + else: + return jnp.where( + mask.reshape(mask.shape + (1,) * (len(edge_array.shape) - 1)), + edge_array[nef_indices], + fill_value, + ) + + +def nef_array_to_edges(nef_array, centers, nef_to_edges_neighbor): + return nef_array[centers, nef_to_edges_neighbor] diff --git a/src/metatensor/models/experimental/pet_jax/pet/utils/to_torch.py b/src/metatensor/models/experimental/pet_jax/pet/utils/to_torch.py new file mode 100644 index 000000000..c12e2a253 --- /dev/null +++ b/src/metatensor/models/experimental/pet_jax/pet/utils/to_torch.py @@ -0,0 +1,69 @@ +import jax +import jax.numpy as jnp +import numpy as np +import torch +from metatensor.torch.atomistic import ModelCapabilities + +from ...model import Model as PET_torch +from ..models import PET as PET_jax + + +def pet_to_torch(pet_jax: PET_jax, hypers: dict, capabilities: ModelCapabilities): + """Convert a pet-jax model to a torch model""" + + jax_device = pet_jax.composition_weights.device_buffer.device() + if jax_device.platform == "cpu": + torch_device_type = "cpu" + elif jax_device.platform == "gpu": + torch_device_type = "cuda" + else: + raise ValueError( + f"Failed to convert device {jax_device.platform} " + "during jax-to-torch conversion of PET-JAX" + ) + device = torch.device(torch_device_type) + + pet_torch = PET_torch( + capabilities=capabilities, + hypers=hypers, + composition_weights=torch.tensor( + np.array(pet_jax.composition_weights), device=device + ), + ) + + # skip the species list (in both atomic numbers indices) and composition weights + jax_params = [ + x for x in jax.tree_util.tree_leaves(pet_jax) if isinstance(x, jax.Array) + ][2:-1] + torch_params = list(pet_torch.parameters()) + + torch_counter = 0 + jax_counter = 0 + while True: + torch_param = torch_params[torch_counter] + jax_param = jax_params[jax_counter] + + if torch_param.shape != jax_param.shape: + if ( + torch_param.shape[0] == 3 * jax_param.shape[0] + and torch_param.shape[1:] == jax_param.shape[1:] + ): + # we're dealing with the attention weights + jax_param = [jax_param] + jax_param.append(jax_params[jax_counter + 1]) + jax_param.append(jax_params[jax_counter + 2]) + jax_counter += 2 + jax_param = jnp.concatenate(jax_param) + else: + raise ValueError( + f"Failed to convert parameter {torch_param.shape} " + f"to {jax_param.shape} during jax-to-torch conversion of PET-JAX" + ) + torch_param.data = torch.tensor(np.array(jax_param), device=device) + jax_counter += 1 + torch_counter += 1 + if jax_counter == len(jax_params): + assert torch_counter == len(torch_params) + break + + return pet_torch diff --git a/src/metatensor/models/experimental/pet_jax/tests/__init__.py b/src/metatensor/models/experimental/pet_jax/tests/__init__.py new file mode 100644 index 000000000..b6aa045b3 --- /dev/null +++ b/src/metatensor/models/experimental/pet_jax/tests/__init__.py @@ -0,0 +1,6 @@ +from pathlib import Path + +DATASET_PATH = str( + Path(__file__).parent.resolve() + / "../../../../../../tests/resources/qm9_reduced_100.xyz" +) diff --git a/src/metatensor/models/experimental/pet_jax/tests/test_functionality.py b/src/metatensor/models/experimental/pet_jax/tests/test_functionality.py new file mode 100644 index 000000000..cea0c0096 --- /dev/null +++ b/src/metatensor/models/experimental/pet_jax/tests/test_functionality.py @@ -0,0 +1,98 @@ +import equinox as eqx 
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import torch
+from metatensor.torch.atomistic import (
+    ModelCapabilities,
+    ModelOutput,
+    NeighborsListOptions,
+)
+
+from metatensor.models.experimental.pet_jax import DEFAULT_HYPERS
+from metatensor.models.experimental.pet_jax.model import Model as PET_torch
+from metatensor.models.experimental.pet_jax.pet.models import PET as PET_jax
+from metatensor.models.experimental.pet_jax.pet.utils.jax_batch import (
+    calculate_padding_sizes,
+    jax_structures_to_batch,
+)
+from metatensor.models.experimental.pet_jax.pet.utils.jax_structure import (
+    structure_to_jax,
+)
+from metatensor.models.experimental.pet_jax.pet.utils.mts_to_structure import (
+    mts_to_structure,
+)
+from metatensor.models.utils.data.readers.structures import read_structures_ase
+from metatensor.models.utils.neighbors_lists import get_system_with_neighbors_lists
+
+from . import DATASET_PATH
+
+
+def test_pet_jax():
+    """Checks that the PET-JAX model can train and that its
+    composition features are not being trained."""
+
+    all_species = [1, 6, 7, 8]
+    composition_weights = jnp.array([0.1, 0.2, 0.3, 0.4])
+    pet_jax = PET_jax(
+        jnp.array(all_species),
+        DEFAULT_HYPERS["model"],
+        composition_weights,
+        key=jax.random.PRNGKey(0),
+    )
+
+    systems = read_structures_ase(DATASET_PATH, dtype=torch.get_default_dtype())
+    systems = systems[:5]
+    jax_structures = [
+        structure_to_jax(mts_to_structure(system, 0.0, np.zeros((0, 3)), 4.0))
+        for system in systems
+    ]
+    # the structures are already converted; do not convert them a second time
+    jax_batch = jax_structures_to_batch(jax_structures)
+    _, _, n_edges_per_node = calculate_padding_sizes(jax_batch)
+
+    def loss_fn(pet, batch, n_edges_per_node):
+        output = pet(batch, n_edges_per_node, is_training=True)
+        return jnp.sum((output["energies"] - jnp.zeros_like(output["energies"])) ** 2)
+
+    grad_fn = eqx.filter_grad(loss_fn)
+    gradients = grad_fn(pet_jax, jax_batch, n_edges_per_node)
+
+    optimizer = optax.adam(learning_rate=1.0)
+    optimizer_state = optimizer.init(eqx.filter(pet_jax, eqx.is_inexact_array))
+    updates, optimizer_state = optimizer.update(gradients, optimizer_state, pet_jax)
+    pet_jax = eqx.apply_updates(pet_jax, updates)
+
+    assert jnp.allclose(pet_jax.composition_weights, composition_weights)
+
+
+def test_pet_torch():
+    """Tests that the torch version can predict successfully."""
+
+    capabilities = ModelCapabilities(
+        length_unit="Angstrom",
+        species=[1, 6, 7, 8],
+        outputs={
+            "energy": ModelOutput(
+                quantity="energy",
+                unit="eV",
+            )
+        },
+    )
+
+    composition_weights = [0.1, 0.2, 0.3, 0.4]
+    pet_torch = PET_torch(
+        capabilities=capabilities,
+        hypers=DEFAULT_HYPERS["model"],
+        composition_weights=torch.tensor(composition_weights),
+    )
+
+    systems = read_structures_ase(DATASET_PATH, dtype=torch.get_default_dtype())
+
+    nl_options = NeighborsListOptions(model_cutoff=4.0, full_list=True)
+    systems = [
+        get_system_with_neighbors_lists(system, [nl_options]) for system in systems
+    ]
+    pet_torch(systems, {"energy": ModelOutput()})
diff --git a/src/metatensor/models/experimental/pet_jax/tests/test_internals.py b/src/metatensor/models/experimental/pet_jax/tests/test_internals.py
new file mode 100644
index 000000000..69d00c0ba
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/tests/test_internals.py
@@ -0,0 +1,54 @@
+import jax
+import jax.numpy as jnp
+
+from metatensor.models.experimental.pet_jax.pet.utils.corresponding_edges import (
+    get_corresponding_edges,
+)
+from metatensor.models.experimental.pet_jax.pet.utils.nef import (
+    edge_array_to_nef,
+    get_nef_indices,
+    nef_array_to_edges,
+)
+
+
+def test_corresponding_edges():
+    """Tests the get_corresponding_edges function, needed for message passing."""
+
+    get_corresponding_edges_jit = jax.jit(get_corresponding_edges)
+    arr = jnp.array([[0, 1]] * 500 + [[1, 0]] * 500)
+    corresponding_edges = get_corresponding_edges_jit(arr)
+    expected = jnp.array([500] * 500 + [0] * 500)
+    assert jnp.all(corresponding_edges == expected)
+
+
+def test_nef_indices():
+    """Tests the NEF indexing, needed to feed edges to a transformer."""
+
+    get_nef_indices_jit = jax.jit(get_nef_indices, static_argnums=(1, 2))
+    edge_array_to_nef_jit = jax.jit(edge_array_to_nef)
+    nef_array_to_edges_jit = jax.jit(nef_array_to_edges)
+
+    centers = jnp.array([0, 4, 3, 1, 0, 0, 3, 3, 3, 4])
+    nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices_jit(centers, 5, 4)
+
+    expected_nef_mask = jnp.array(
+        [
+            [True, True, True, False],
+            [True, False, False, False],
+            [False, False, False, False],
+            [True, True, True, True],
+            [True, True, False, False],
+        ]
+    )
+    assert jnp.all(nef_mask == expected_nef_mask)
+
+    nef_centers = edge_array_to_nef_jit(centers, nef_indices)
+
+    expected_nef_centers = jnp.array(
+        [[0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0], [3, 3, 3, 3], [4, 4, 0, 0]]
+    )
+
+    assert jnp.all(nef_centers == expected_nef_centers)
+
+    centers_again = nef_array_to_edges_jit(nef_centers, centers, nef_to_edges_neighbor)
+    assert jnp.all(centers == centers_again)
diff --git a/src/metatensor/models/experimental/pet_jax/tests/test_to_torch.py b/src/metatensor/models/experimental/pet_jax/tests/test_to_torch.py
new file mode 100644
index 000000000..ee513c281
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/tests/test_to_torch.py
@@ -0,0 +1,82 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+import torch
+from metatensor.torch.atomistic import (
+    ModelCapabilities,
+    ModelOutput,
+    NeighborsListOptions,
+)
+
+from metatensor.models.experimental.pet_jax import DEFAULT_HYPERS
+from metatensor.models.experimental.pet_jax.pet.models import PET as PET_jax
+from metatensor.models.experimental.pet_jax.pet.utils.jax_batch import (
+    calculate_padding_sizes,
+    jax_structures_to_batch,
+)
+from metatensor.models.experimental.pet_jax.pet.utils.jax_structure import (
+    structure_to_jax,
+)
+from metatensor.models.experimental.pet_jax.pet.utils.mts_to_structure import (
+    mts_to_structure,
+)
+from metatensor.models.experimental.pet_jax.pet.utils.to_torch import pet_to_torch
+from metatensor.models.utils.data.readers.structures import read_structures_ase
+from metatensor.models.utils.neighbors_lists import get_system_with_neighbors_lists
+
+
+
+def test_pet_to_torch():
+    """Tests that the model can be converted to torch and predict the same output."""
+
+    all_species = [1, 6, 7, 8]
+    composition_weights = [0.1, 0.2, 0.3, 0.4]
+    pet_jax = PET_jax(
+        jnp.array(all_species),
+        DEFAULT_HYPERS["model"],
+        jnp.array(composition_weights),
+        key=jax.random.PRNGKey(0),
+    )
+
+    systems = read_structures_ase(DATASET_PATH, dtype=torch.get_default_dtype())
+    systems = systems[:5]
+
+    # jax evaluation
+    jax_structures = [
+        structure_to_jax(mts_to_structure(system, 0.0, np.zeros((0, 3)), 4.0))
+        for system in systems
+    ]
+    # the structures above are already jax structures, so they can be
+    # batched directly without a second conversion
+    jax_batch = jax_structures_to_batch(jax_structures)
+    _, _, n_edges_per_node = calculate_padding_sizes(jax_batch)
+    output_jax = pet_jax(jax_batch, n_edges_per_node, is_training=False)
+
+    # convert to torch
+    capabilities = ModelCapabilities(
+        length_unit="Angstrom",
+        species=all_species,
+        outputs={
+            "energy": ModelOutput(
+                quantity="energy",
+                unit="eV",
+            )
+        },
+    )
+    pet_torch = pet_to_torch(pet_jax, DEFAULT_HYPERS["model"], capabilities)
+
+    # neighbor lists
+    nl_options = NeighborsListOptions(model_cutoff=4.0, full_list=True)
+    systems = [
+        get_system_with_neighbors_lists(system, [nl_options]) for system in systems
+    ]
+
+    # torch evaluation
+    output_torch = pet_torch(systems, {"energy": ModelOutput()})
+
+    assert torch.allclose(
+        torch.tensor(np.array(output_jax["energies"])),
+        output_torch["energy"].block().values.squeeze(-1),
+        atol=1e-4,
+        rtol=1e-4,
+    )
diff --git a/src/metatensor/models/experimental/pet_jax/tests/test_torchscript.py b/src/metatensor/models/experimental/pet_jax/tests/test_torchscript.py
new file mode 100644
index 000000000..660480254
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/tests/test_torchscript.py
@@ -0,0 +1,45 @@
+import torch
+from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
+
+from metatensor.models.experimental.pet_jax import DEFAULT_HYPERS, Model
+
+
+def test_torchscript():
+    """Tests that the model can be jitted."""
+
+    capabilities = ModelCapabilities(
+        length_unit="Angstrom",
+        species=[1, 6, 7, 8],
+        outputs={
+            "energy": ModelOutput(
+                quantity="energy",
+                unit="eV",
+            )
+        },
+    )
+    pet = Model(
+        capabilities, DEFAULT_HYPERS["model"], torch.tensor([0.1, 0.2, 0.3, 0.4])
+    )
+    torch.jit.script(pet, {"energy": pet.capabilities.outputs["energy"]})
+
+
+def test_torchscript_save():
+    """Tests that the model can be jitted and saved."""
+
+    capabilities = ModelCapabilities(
+        length_unit="Angstrom",
+        species=[1, 6, 7, 8],
+        outputs={
+            "energy": ModelOutput(
+                quantity="energy",
+                unit="eV",
+            )
+        },
+    )
+    pet = Model(
+        capabilities, DEFAULT_HYPERS["model"], torch.tensor([0.1, 0.2, 0.3, 0.4])
+    )
+    torch.jit.save(
+        torch.jit.script(pet, {"energy": pet.capabilities.outputs["energy"]}),
+        "pet.pt",
+    )
diff --git a/src/metatensor/models/experimental/pet_jax/train.py b/src/metatensor/models/experimental/pet_jax/train.py
new file mode 100644
index 000000000..ec9393e56
--- /dev/null
+++ b/src/metatensor/models/experimental/pet_jax/train.py
@@ -0,0 +1,301 @@
+import logging
+from typing import Dict, List, Optional, Union
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import torch
+from metatensor.learn.data.dataset import _BaseDataset
+from metatensor.torch.atomistic import ModelCapabilities
+
+from ...utils.composition import calculate_composition_weights
+from ...utils.data import check_datasets, get_all_targets
+from ...utils.logging import MetricLogger
+from .model import DEFAULT_HYPERS
+from .pet.models import PET, PET_energy_force
+from .pet.utils.augmentation import apply_random_augmentation
+from .pet.utils.dataloader import dataloader
+from .pet.utils.jax_batch import calculate_padding_sizes, jax_structures_to_batch
+from .pet.utils.jax_structure import structure_to_jax
+from .pet.utils.mts_to_structure import mts_to_structure
+from .pet.utils.to_torch import pet_to_torch
+
+
+logger = logging.getLogger(__name__)
+
+
+def train(
+    train_datasets: List[Union[_BaseDataset, torch.utils.data.Subset]],
+    validation_datasets: List[Union[_BaseDataset, torch.utils.data.Subset]],
+    requested_capabilities: ModelCapabilities,
+    hypers: Dict = DEFAULT_HYPERS,
+    continue_from: Optional[str] = None,
+    output_dir: str = ".",
+    device_str: str = "cpu",
+):
+    logger.info(
+        "This is a JAX version of the PET architecture. "
+        "It does not support message passing yet."
+    )
+
+    # Random seed
+    logger.warning(
+        "The random seed cannot be set from the outside yet; it is hardcoded for now."
+    )
+    key = jax.random.PRNGKey(1337)
+
+    # Device
+    if device_str == "gpu":
+        device_str = "cuda"
+    jax.config.update("jax_platform_name", device_str)
+    logger.info(
+        "Running on device "
+        f"{list(jnp.array([1, 2, 3]).addressable_data(0).devices())[0]}"
+    )
+
+    # Dtype
+    if torch.get_default_dtype() == torch.float64:
+        jax.config.update("jax_enable_x64", True)
+    elif torch.get_default_dtype() == torch.float32:
+        pass
+    else:
+        raise ValueError(f"Unsupported dtype {torch.get_default_dtype()} in PET-JAX.")
+
+    if len(train_datasets) != 1:
+        raise NotImplementedError(
+            "Only one training dataset is supported in PET-JAX for the moment."
+        )
+    if len(validation_datasets) != 1:
+        raise NotImplementedError(
+            "Only one validation dataset is supported in PET-JAX for the moment."
+        )
+
+    if continue_from is not None:
+        raise NotImplementedError(
+            "Continuing from a previous run is not supported yet in PET-JAX."
+        )
+    model_capabilities = requested_capabilities
+    # TODO: implement restarting
+
+    # Perform checks on the datasets:
+    logger.info("Checking datasets for consistency")
+    check_datasets(
+        train_datasets,
+        validation_datasets,
+        model_capabilities,
+    )
+
+    # Check capabilities:
+    if len(model_capabilities.outputs) != 1:
+        raise NotImplementedError(
+            "Only one output is supported in PET-JAX for the moment."
+        )
+    if next(iter(model_capabilities.outputs.values())).quantity != "energy":
+        raise NotImplementedError(
+            "Only energy outputs are supported in PET-JAX for the moment."
+        )
+
+    # Extract whether we're also training on forces
+    do_forces = next(iter(train_datasets[0]))[1].block(0).has_gradient("positions")
+
+    # Calculate and set the composition weights for all targets:
+    logger.info("Calculating composition weights")
+    target_name = next(iter(model_capabilities.outputs.keys()))
+    train_datasets_with_target = []
+    for dataset in train_datasets:
+        if target_name in get_all_targets(dataset):
+            train_datasets_with_target.append(dataset)
+    if len(train_datasets_with_target) == 0:
+        raise ValueError(
+            f"Target {target_name} in the model's new capabilities is not "
+            "present in any of the training datasets."
+        )
+    composition_weights = calculate_composition_weights(
+        train_datasets_with_target, target_name
+    )
+    composition_weights_jax = jnp.array(composition_weights.numpy())
+
+    # Extract the training and validation sets from metatensor format
+    cutoff = hypers["model"]["cutoff"]
+    training_set = [
+        mts_to_structure(
+            structure,
+            float(targets.block().values),
+            (
+                -targets.block().gradient("positions").values.reshape(-1, 3).numpy()
+                if do_forces
+                else np.zeros((0, 3))
+            ),
+            cutoff,
+        )
+        for structure, targets in train_datasets[0]
+    ]
+    valid_set = [
+        mts_to_structure(
+            structure,
+            float(targets.block().values),
+            (
+                -targets.block().gradient("positions").values.reshape(-1, 3).numpy()
+                if do_forces
+                else np.zeros((0, 3))
+            ),
+            cutoff,
+        )
+        for structure, targets in validation_datasets[0]
+    ]
+
+    def loss_fn(model, structures, n_edges_per_node, do_forces, force_weight, key):
+        predictions = model(structures, n_edges_per_node, is_training=True, key=key)
+        loss = jnp.sum((predictions["energies"] - structures.energies) ** 2)
+        if do_forces:
+            loss += force_weight * jnp.sum(
+                (predictions["forces"] - structures.forces) ** 2
+            )
+        return loss
+
+    grad_loss_fn = eqx.filter_value_and_grad(loss_fn)
+
+    @eqx.filter_jit
+    def train_step(
+        model,
+        structures,
+        n_edges_per_node_array,
+        optimizer,
+        opt_state,
+        do_forces,
+        force_weight,
+        key,
+    ):
+        # the padding size is passed as the length of a dummy array: shapes are
+        # static under jit, so this yields a plain Python int inside the function
+        n_edges_per_node = len(n_edges_per_node_array)
+        loss, grads = grad_loss_fn(
+            model, structures, n_edges_per_node, do_forces, force_weight, key
+        )
+        updates, opt_state = optimizer.update(grads, opt_state, model)
+        model = eqx.apply_updates(model, updates)
+        return loss, model, opt_state
+
+    # Initialize the model
+    all_species = jnp.array(model_capabilities.species)
+    if do_forces:
+        model = PET_energy_force(
+            all_species, hypers["model"], composition_weights_jax, key=key
+        )
+    else:
+        model = PET(all_species, hypers["model"], composition_weights_jax, key=key)
+
+    training_hypers = hypers["training"]
+    learning_rate = training_hypers["learning_rate"]
+    force_weight = 1.0  # TODO: pass this in
+    num_epochs = training_hypers["num_epochs"]
+    batch_size = training_hypers["batch_size"]
+    num_warmup_steps = training_hypers["num_warmup_steps"]
+
+    schedule = optax.linear_schedule(0.0, learning_rate, num_warmup_steps)
+    optimizer = optax.chain(
+        optax.clip_by_global_norm(10.0),
+        optax.adamw(learning_rate=schedule),
+    )
+
+    opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
+
+    @eqx.filter_jit
+    def _evaluate_model(model, jax_batch, n_edges_per_node_array):
+        # as in train_step, the dummy array's length carries the static padding size
+        n_edges_per_node = len(n_edges_per_node_array)
+        return model(jax_batch, n_edges_per_node, is_training=False)
+
+    def evaluate_model(model, dataset, force_weight, do_forces):
+        energy_sse = 0.0
+        energy_sae = 0.0
+        if do_forces:
+            force_sse = 0.0
+            force_sae = 0.0
+            number_of_forces = 0
+        for batch in dataloader(dataset, batch_size, shuffle=False):
+            jax_batch = jax_structures_to_batch(
+                [structure_to_jax(structure) for structure in batch]
+            )
+            n_nodes, n_edges, n_edges_per_node = calculate_padding_sizes(jax_batch)
+            # TODO: pad the batch
+            # jax_batch = pad_batch(jax_batch, n_nodes, n_edges)
+            predictions = _evaluate_model(
+                model, jax_batch, jnp.zeros((n_edges_per_node,))
+            )
+            energy_sse += jnp.sum((predictions["energies"] - jax_batch.energies) ** 2)
+            energy_sae += jnp.sum(jnp.abs(predictions["energies"] - jax_batch.energies))
+            if do_forces:
+                force_sse += jnp.sum((predictions["forces"] - jax_batch.forces) ** 2)
+                force_sae += jnp.sum(jnp.abs(predictions["forces"] - jax_batch.forces))
+                number_of_forces += 3 * len(jax_batch.forces)
+        energy_mse = energy_sse / len(dataset)
+        # energy_mae = energy_sae / len(dataset)
+        energy_rmse = jnp.sqrt(energy_mse)
+        result_dict = {}
+        result_dict["loss"] = energy_sse
+        result_dict[target_name] = energy_rmse
+        # result_dict["energy_mae"] = energy_mae
+        if do_forces:
+            force_mse = force_sse / number_of_forces
+            # force_mae = force_sae / number_of_forces
+            force_rmse = jnp.sqrt(force_mse)
+            result_dict["loss"] += force_weight * force_sse
+            result_dict[target_name + "_positions_gradients"] = force_rmse
+            # result_dict["force_mae"] = force_mae
+        return result_dict
+
+    train_metrics = evaluate_model(model, training_set, force_weight, do_forces)
+    valid_metrics = evaluate_model(model, valid_set, force_weight, do_forces)
+    train_loss = train_metrics.pop("loss")
+    valid_loss = valid_metrics.pop("loss")
+    metric_logger = MetricLogger(
+        model_capabilities,
+        train_loss,
+        valid_loss,
+        train_metrics,
+        valid_metrics,
+    )
+    metric_logger.log(0, train_loss, valid_loss, train_metrics, valid_metrics)
+
+    for epoch in range(1, num_epochs):
+        train_loss = 0.0
+        for batch in dataloader(training_set, batch_size, shuffle=True):
+            jax_batch = jax_structures_to_batch(
+                [
+                    structure_to_jax(apply_random_augmentation(structure))
+                    for structure in batch
+                ]
+            )
+            n_nodes, n_edges, n_edges_per_node = calculate_padding_sizes(jax_batch)
+            # TODO: pad the batch
+            # jax_batch = pad_batch(jax_batch, n_nodes, n_edges)
+            subkey, key = jax.random.split(key)
+            loss, model, opt_state = train_step(
+                model,
+                jax_batch,
+                jnp.zeros((n_edges_per_node,)),
+                optimizer,
+                opt_state,
+                do_forces,
+                force_weight,
+                subkey,
+            )
+            train_loss += loss
+
+        if epoch % training_hypers["log_interval"] == 0:
+            train_metrics = evaluate_model(model, training_set, force_weight, do_forces)
+            valid_metrics = evaluate_model(model, valid_set, force_weight, do_forces)
+            train_loss = train_metrics.pop("loss")
+            valid_loss = valid_metrics.pop("loss")
+            metric_logger.log(
+                epoch, train_loss, valid_loss, train_metrics, valid_metrics
+            )
+        # TODO: implement checkpoints
+
+    # Convert to torch
+    model = pet_to_torch(model, hypers["model"], requested_capabilities)
+
+    return model
diff --git a/tests/utils/test_output_gradient.py b/tests/utils/test_output_gradient.py
index 8794802e6..37ff249ce 100644
--- a/tests/utils/test_output_gradient.py
+++ b/tests/utils/test_output_gradient.py
@@ -190,7 +190,6 @@ def test_both(is_training):
     jitted_model = torch.jit.script(model)
 
     output = jitted_model(systems, {"energy": model.capabilities.outputs["energy"]})
-    print(output["energy"].block().values.requires_grad)
     jitted_gradients = compute_gradient(
         output["energy"].block().values,
         [system.positions for system in systems] + strains,
diff --git a/tox.ini b/tox.ini
index 0842d840c..6d7e1032b 100644
--- a/tox.ini
+++ b/tox.ini
@@ -90,6 +90,17 @@ deps =
 commands =
     pytest --import-mode=append {posargs} src/metatensor/models/experimental/alchemical_model/tests/
 
+[testenv:pet-jax-tests]
+description = Run PET-JAX tests with pytest
+passenv = *
+deps =
+    pytest
+    equinox
+    optax
+commands =
+    python src/metatensor/models/experimental/pet_jax/hotfix_torch.py
+    pytest --import-mode=append {posargs} src/metatensor/models/experimental/pet_jax/tests/
+
 [testenv:docs]
 description = builds the documentation with sphinx
 deps =
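
Note on the NEF layout: test_internals.py above exercises a "node-edge-feature"
(NEF) layout in which the ragged per-node edge lists are grouped by center node
and padded to a fixed width, so that a transformer can consume them as one
rectangular array. The following is a minimal, loop-based illustrative sketch
of the same convention in plain NumPy; the `nef_indices` helper is a
hypothetical stand-in for the real (vectorized) `get_nef_indices`, written only
to make the layout explicit:

import numpy as np

def nef_indices(centers, n_nodes, max_edges_per_node):
    """For each node, collect the indices of the edges centered on it,
    padded with zeros to a fixed width, plus a validity mask."""
    indices = np.zeros((n_nodes, max_edges_per_node), dtype=int)
    mask = np.zeros((n_nodes, max_edges_per_node), dtype=bool)
    counts = np.zeros(n_nodes, dtype=int)
    for edge, node in enumerate(centers):
        indices[node, counts[node]] = edge
        mask[node, counts[node]] = True
        counts[node] += 1
    return indices, mask

# same example as in test_nef_indices: 10 edges distributed over 5 nodes
centers = np.array([0, 4, 3, 1, 0, 0, 3, 3, 3, 4])
indices, mask = nef_indices(centers, n_nodes=5, max_edges_per_node=4)
# gathering any per-edge array through `indices` (and zeroing the padded
# slots with `mask`) reproduces the expected arrays in the test:
assert (centers[indices] * mask == [
    [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0], [3, 3, 3, 3], [4, 4, 0, 0]
]).all()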
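
Note on the `jnp.zeros((n_edges_per_node,))` arguments passed to `train_step`
and `_evaluate_model` in train.py: under `jax.jit` (and `eqx.filter_jit`),
array values are traced but array shapes are static, so encoding an integer as
the length of a dummy array makes it available as a plain Python value inside
the compiled function, at the cost of one recompilation per distinct padding
size. A minimal self-contained sketch of the trick, independent of the PET
code:

import jax
import jax.numpy as jnp

@jax.jit
def padded_sum(values, size_carrier):
    # len() of a traced array reads its (static) leading dimension,
    # so `n` is an ordinary Python int here, usable e.g. in slices
    n = len(size_carrier)
    return jnp.sum(values[:n])

values = jnp.arange(10.0)
print(padded_sum(values, jnp.zeros((4,))))  # 6.0 (= 0 + 1 + 2 + 3)
# calling again with jnp.zeros((7,)) would trigger one new compilation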