Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PET-JAX #82

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/pet-jax.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: PET-JAX tests

on:
push:
branches: [main]
pull_request:
# Check all PR

jobs:
tests:
runs-on: ${{ matrix.os }}
strategy:
matrix:
include:
- os: ubuntu-22.04
python-version: "3.12"

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- run: pip install tox

- name: Install JAX
# JAX does not work as a dependency; it needs to be installed separately
run: pip install jax[cpu]

- name: run PET-JAX tests
run: tox -e pet-jax-tests
env:
# Use the CPU only version of torch when building/running the code
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu

- name: Upload codecoverage
uses: codecov/codecov-action@v4
with:
files: ./tests/coverage.xml
39 changes: 39 additions & 0 deletions docs/src/architectures/pet-jax.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
.. _architecture-pet-jax:

PET-JAX
=========

This is a JAX implementation of the PET architecture.

Installation
------------
To use PET-JAX within ``metatensor-models``, you should already have
JAX installed for your platform (see the official JAX installation instructions).
Then, you can run the following command in the root directory of the repository:

.. code-block:: bash

pip install .[pet-jax]

Following this, it is also necessary to hot-fix a few lines of your torch installation
to allow PET-JAX models to be exported. This can be achieved by running the following
Python script:

.. literalinclude:: ../../../src/metatensor/models/experimental/pet_jax/hotfix_torch.py
:language: python

Default Hyperparameters
-----------------------
The default hyperparameters for the PET-JAX model are:

.. literalinclude:: ../../../src/metatensor/models/cli/conf/architecture/experimental.pet_jax.yaml
:language: yaml


Tuning Hyperparameters
----------------------
To be done.

References
----------
.. footbibliography::
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ soap-bpnn = []
alchemical-model = [
"torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@fafb0bd",
]
pet-jax = [
"jax",
"equinox",
"optax",
]


[tool.setuptools.packages.find]
where = ["src"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# default hyperparameters for the PET-JAX model
model:
cutoff: 5.0
d_pet: 128
num_heads: 4
num_attention_layers: 2
num_gnn_layers: 2
mlp_dropout_rate: 0.0
attention_dropout_rate: 0.0

training:
batch_size: 16
num_warmup_steps: 1000
num_epochs: 10000
learning_rate: 3e-4
log_interval: 10
2 changes: 2 additions & 0 deletions src/metatensor/models/experimental/pet_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .model import Model, DEFAULT_HYPERS # noqa: F401
from .train import train # noqa: F401
28 changes: 28 additions & 0 deletions src/metatensor/models/experimental/pet_jax/hotfix_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# This is fixing a small bug in the attention implementation
# in torch that prevents it from being torchscriptable.

import os

import torch


file = os.path.join(os.path.dirname(torch.__file__), "nn", "modules", "activation.py")

with open(file, "r") as f:
lines = f.readlines()
for i, line in enumerate(lines):
if (
"elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:" # noqa: E501
in line
):
lines[i] = line.replace(
"elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:", # noqa: E501
"elif self.in_proj_bias is not None:\n"
" if query.dtype != self.in_proj_bias.dtype:",
)
lines[i + 1] = (
' why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) do not match"\n' # noqa: E501
)

with open(file, "w") as f:
f.writelines(lines)
219 changes: 219 additions & 0 deletions src/metatensor/models/experimental/pet_jax/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
from typing import Dict, List, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import ModelOutput, NeighborsListOptions, System
from omegaconf import OmegaConf

from ... import ARCHITECTURE_CONFIG_PATH
from .pet.pet_torch.corresponding_edges import get_corresponding_edges
from .pet.pet_torch.encoder import Encoder
from .pet.pet_torch.nef import edge_array_to_nef, get_nef_indices, nef_array_to_edges
from .pet.pet_torch.radial_mask import get_radial_mask
from .pet.pet_torch.structures import concatenate_structures
from .pet.pet_torch.transformer import Transformer


ARCHITECTURE_NAME = "experimental.pet_jax"
DEFAULT_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / f"{ARCHITECTURE_NAME}.yaml")
)
DEFAULT_MODEL_HYPERS = DEFAULT_HYPERS["model"]


class Model(torch.nn.Module):

def __init__(self, capabilities, hypers, composition_weights):
super().__init__()
self.name = ARCHITECTURE_NAME
self.hypers = hypers

self.capabilities = capabilities

# Handle species
self.all_species = capabilities.species
n_species = len(self.all_species)
self.species_to_species_index = torch.full(
(max(self.all_species) + 1,),
-1,
)
for i, species in enumerate(self.all_species):
self.species_to_species_index[species] = i
print("Species indices:", self.species_to_species_index)
print("Number of species:", n_species)

self.encoder = Encoder(n_species, hypers["d_pet"])

self.transformer = Transformer(
hypers["d_pet"],
4 * hypers["d_pet"],
hypers["num_heads"],
hypers["num_attention_layers"],
hypers["mlp_dropout_rate"],
hypers["attention_dropout_rate"],
)
self.readout = torch.nn.Linear(hypers["d_pet"], 1, bias=False)

self.num_mp_layers = hypers["num_gnn_layers"] - 1
gnn_contractions = []
gnn_transformers = []
for _ in range(self.num_mp_layers):
gnn_contractions.append(
torch.nn.Linear(2 * hypers["d_pet"], hypers["d_pet"], bias=False)
)
gnn_transformers.append(
Transformer(
hypers["d_pet"],
4 * hypers["d_pet"],
hypers["num_heads"],
hypers["num_attention_layers"],
hypers["mlp_dropout_rate"],
hypers["attention_dropout_rate"],
)
)
self.gnn_contractions = torch.nn.ModuleList(gnn_contractions)
self.gnn_transformers = torch.nn.ModuleList(gnn_transformers)

self.register_buffer("composition_weights", composition_weights)

def forward(
self,
systems: List[System],
outputs: Dict[str, ModelOutput],
selected_atoms: Optional[Labels] = None,
) -> Dict[str, TensorMap]:
# Checks on systems (species) and outputs are done in the
# MetatensorAtomisticModel wrapper

if selected_atoms is not None:
raise NotImplementedError(
"The PET model does not support domain decomposition."
)

n_structures = len(systems)
positions, centers, neighbors, species, segment_indices, edge_vectors = (
concatenate_structures(systems)
)
max_edges_per_node = int(torch.max(torch.bincount(centers)))

# Convert to NEF:
nef_indices, nef_to_edges_neighbor, nef_mask = get_nef_indices(
centers, len(positions), max_edges_per_node
)

# Get radial mask
r = torch.sqrt(torch.sum(edge_vectors**2, dim=-1))
radial_mask = get_radial_mask(r, 5.0, 3.0)

# Element indices
element_indices_nodes = self.species_to_species_index[species]
element_indices_centers = element_indices_nodes[centers]
element_indices_neighbors = element_indices_nodes[neighbors]

# Send everything to NEF:
edge_vectors = edge_array_to_nef(edge_vectors, nef_indices)
radial_mask = edge_array_to_nef(
radial_mask, nef_indices, nef_mask, fill_value=0.0
)
element_indices_centers = edge_array_to_nef(
element_indices_centers, nef_indices
)
element_indices_neighbors = edge_array_to_nef(
element_indices_neighbors, nef_indices
)

features = {
"cartesian": edge_vectors,
"center": element_indices_centers,
"neighbor": element_indices_neighbors,
}

# Encode
features = self.encoder(features)

# Transformer
features = self.transformer(features, radial_mask)

# GNN
if self.num_mp_layers > 0:
corresponding_edges = get_corresponding_edges(
torch.stack([centers, neighbors], dim=-1)
)
for contraction, transformer in zip(
self.gnn_contractions, self.gnn_transformers
):
new_features = nef_array_to_edges(
features, centers, nef_to_edges_neighbor
)
corresponding_new_features = new_features[corresponding_edges]
new_features = torch.concatenate(
[new_features, corresponding_new_features], dim=-1
)
new_features = contraction(new_features)
new_features = edge_array_to_nef(new_features, nef_indices)
new_features = transformer(new_features, radial_mask)
features = features + new_features

# Readout
edge_energies = self.readout(features)
edge_energies = edge_energies * radial_mask[:, :, None]

# Sum over edges
atomic_energies = torch.sum(
edge_energies, dim=(1, 2)
) # also eliminate singleton dimension 2

# Sum over centers
structure_energies = torch.zeros(
n_structures, dtype=atomic_energies.dtype, device=atomic_energies.device
)
structure_energies.index_add_(0, segment_indices, atomic_energies)

# TODO: use utils? use composition calculator?
composition = torch.zeros(
(n_structures, len(self.all_species)), device=atomic_energies.device
)
for number in self.all_species:
where_number = (species == number).to(composition.dtype)
composition[:, self.species_to_species_index[number]].index_add_(
0, segment_indices, where_number
)

structure_energies = structure_energies + composition @ self.composition_weights

return {
list(outputs.keys())[0]: TensorMap(
keys=Labels(
names=["_"],
values=torch.tensor([[0]], device=structure_energies.device),
),
blocks=[
TensorBlock(
values=structure_energies.unsqueeze(1),
samples=Labels(
names=["structure"],
values=torch.arange(
n_structures, device=structure_energies.device
).unsqueeze(1),
),
components=[],
properties=Labels(
names=["_"],
values=torch.tensor(
[[0]], device=structure_energies.device
),
),
)
],
)
}

def requested_neighbors_lists(
self,
) -> List[NeighborsListOptions]:
return [
NeighborsListOptions(
model_cutoff=self.hypers["cutoff"],
full_list=True,
)
]
Empty file.
Loading
Loading