add docstring to ACE and MACE, fix style
szbernat committed Nov 2, 2022
1 parent e85debf commit 4f26f3a
Showing 8 changed files with 222 additions and 100 deletions.
3 changes: 1 addition & 2 deletions macx/gnn/edge_features.py
@@ -1,6 +1,5 @@
from collections.abc import Sequence
from functools import partial
from typing import Literal, Optional
from typing import Optional

import e3nn_jax as e3nn
import jax.numpy as jnp
5 changes: 3 additions & 2 deletions macx/gnn/gnn.py
@@ -122,7 +122,7 @@ def init_state(self, shape, dtype):
raise NotImplementedError

def node_factory(self, node_attrs):
r"""Return the initial embeddings as a :class:`GraphNodes` instance."""
r"""Return the initial node representations."""
raise NotImplementedError

def edge_feature_callback(self, pos_sender, pos_receiver, sender_idx, receiver_idx):
@@ -144,7 +144,7 @@ def edge_feature_callback(pos_sender, pos_receiver, sender_idx, receiver_i
raise NotImplementedError

def edge_factory(self, r, occupancies, custom_mask):
r"""Return a function that builds all the edges used in the GNN."""
r"""Return a function that builds the edges used in the GNN."""
mask_val = r.shape[0] + 1
edge_factory = GraphEdgeBuilder(
self.cutoff,
@@ -167,6 +167,7 @@ def __call__(self, r, node_attrs):
Args:
r (float, (:math:`N_\text{nodes}`, 3)): coordinates of the graph nodes.
node_attrs (Any): additional information about the nodes (like atom type).
Returns:
float, (:math:`N_\text{nodes}`, :data:`embedding_dim`):
13 changes: 11 additions & 2 deletions macx/gnn/graph.py
@@ -43,6 +43,15 @@ def mask_self_edges(idx):


def mask_custom_edges(idx, mask):
r"""
Mask the edges according to :data:`mask`.
Args:
idx (int, (:math:`N_\text{nodes}`, :math:`N_\text{nodes}`)): matrix of
receiving node indices.
mask (bool, (:math:`N_\text{nodes}`, :math:`N_\text{nodes}`)): mask definition,
entries where the mask is :data:`True` are masked out (replaced by an out-of-range index).
"""
return jnp.where(mask, idx.shape[1], idx)
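
A minimal usage sketch of mask_custom_edges (the index matrix and mask below are made-up values for illustration, not taken from the repository):

import jax.numpy as jnp

idx = jnp.array([[0, 1], [0, 1]])                 # receiver indices for a 2-node graph
mask = jnp.array([[True, False], [False, True]])  # True marks edges to mask out
mask_custom_edges(idx, mask)
# -> [[2, 1], [0, 2]]: masked entries are replaced by idx.shape[1], an out-of-range index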


@@ -81,7 +90,7 @@ def prune_graph_edges(
returns some data (features) computed for the edges.
Returns:
~jax.types.GraphEdges: object containing the indeces of the edge
GraphEdges: object containing the indices of the edge
sending and edge receiving nodes, along with the features associated
with the edges.
"""
@@ -141,7 +150,7 @@ def dist(sender, receiver):


def difference_callback(pos_sender, pos_receiver, sender_idx, receiver_idx):
r"""feature_callback computing the Euclidian difference vector for each edge."""
r"""Feature_callback computing the Euclidian difference vector for each edge."""
if len(pos_sender) == 0 or len(pos_receiver) == 0:
return jnp.zeros((len(sender_idx), 3))
diffs = pos_receiver[receiver_idx] - pos_sender[sender_idx]
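
For orientation, a hedged usage sketch of difference_callback (coordinates and edge indices are made up; the function body is only partially shown above, so the return value is assumed to be the raw difference vectors):

import jax.numpy as jnp

pos = jnp.array([[0.0, 0.0, 0.0],
                 [1.0, 0.0, 0.0],
                 [0.0, 2.0, 0.0]])
sender_idx = jnp.array([0, 0])
receiver_idx = jnp.array([1, 2])
diffs = difference_callback(pos, pos, sender_idx, receiver_idx)
# diffs = pos_receiver[receiver_idx] - pos_sender[sender_idx],
# i.e. [[1., 0., 0.], [0., 2., 0.]] for the two edges above
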
66 changes: 47 additions & 19 deletions macx/models/ace.py
@@ -1,23 +1,44 @@
from typing import Sequence

import e3nn_jax as e3nn
import haiku as hk
import jax.numpy as jnp
from jax import ops

from ..gnn import GraphNeuralNetwork, MessagePassingLayer
from ..gnn.edge_features import EdgeFeatures
from .symmetric_contraction import SymmetricContraction
from ..tools.e3nn_ext import GeneralLinear, convert_irreps_array
from ..tools.symmetric_contraction import SymmetricContraction


def to_onehot(features, node_types):
r"""
Create one-hot encoded vectors from :data:`features`.
Args:
features (int, jnp.ndarray): the types of the nodes.
node_types (Sequence[int]): list of possible node types.
"""
ones = []
for i, e in enumerate(node_types):
ones.append(jnp.where(features == e, jnp.ones(1), jnp.zeros(1))[..., None])
return jnp.concatenate(ones, axis=-1)
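
A small usage sketch of to_onehot (the node types below are an illustrative assumption, e.g. atomic numbers for H, C and O):

import jax.numpy as jnp

node_types = [1, 6, 8]              # possible node types
features = jnp.array([6, 1, 8, 6])  # type of each of the four nodes
to_onehot(features, node_types)
# -> shape (4, 3); row i is the one-hot encoding of node i's type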


class ACELayer(MessagePassingLayer):
r"""
Compute the ACE interaction.
The ACE interaction consists of constructing the edge features, summing them
on each receiver node, and creating the symmetrized, many-body node embeddings.
Args:
max_body_order (int): the maximum body order up to which node embeddings are
constructed.
embedding_irreps (Sequence[e3nn.Irrep]): the irreps of the node embeddings.
mix_atomic_basis (bool): whether to apply a linear layer to the initial node
embeddings before symmetrizing them; defaults to :data:`True`.
"""

def __init__(
self,
ilayer,
@@ -37,16 +58,9 @@ def __init__(
self.n_node_type,
)
if mix_atomic_basis:
self.atomic_basis_weights = hk.get_parameter(
"atomic_basis_weights",
[len(self.edge_feat_irreps), self.embedding_dim, self.embedding_dim],
init=hk.initializers.VarianceScaling(),
self.atomic_basis_layer = convert_irreps_array(embedding_irreps)(
GeneralLinear(embedding_irreps, mix_channels=True)
)
acc = 0
self.edge_split_idxs = []
for ir in self.edge_feat_irreps[:-1]:
acc += 2 * ir.l + 1
self.edge_split_idxs.append(acc)

def get_update_edges_fn(self):
return None
@@ -59,14 +73,8 @@ def aggregate_edges_for_nodes(nodes, edges):
num_segments=self.n_nodes,
)
if self.mix_atomic_basis:
As = jnp.split(A, self.edge_split_idxs, axis=-1)
A = jnp.concatenate(
[
jnp.einsum("kj,bji->bki", weight, A)
for weight, A in zip(self.atomic_basis_weights, As)
],
axis=-1,
)
A = self.atomic_basis_layer(A)

return A

return aggregate_edges_for_nodes
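
The aggregation above sums the edge features on each receiver node; a minimal standalone sketch of that step (shapes and values are assumptions, not the commit's exact data layout):

import jax.numpy as jnp
from jax import ops

n_nodes = 3
edge_features = jnp.ones((4, 8))        # 4 edges with 8 feature channels each
receiver_idx = jnp.array([0, 0, 1, 2])  # receiving node of each edge
A = ops.segment_sum(edge_features, receiver_idx, num_segments=n_nodes)
# A has shape (3, 8): the per-node atomic basis before the optional linear mixing
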
@@ -80,6 +88,26 @@ def update_nodes(nodes, A):


class ACE(GraphNeuralNetwork):
r"""
The ACE model.
Args:
n_nodes (int): the maximum number of nodes in the graph.
embedding_dim (int): the embedding dimension; it should equal the number of
radial basis functions.
cutoff (float): distance cutoff, beyond which interactions are not considered.
max_body_order (int): the maximum body order up to which node embeddings are
constructed.
embedding_irreps (Sequence[e3nn.Irrep]): the irreps of the node embeddings.
edge_feat_irreps (Sequence[e3nn.Irrep]): the irreps of the edge features.
node_types (Sequence[int]): the list of possible node types.
edge_feat_factory (Optional[Callable]): the class that constructs the edge
features; defaults to :class:`~gnn.edge_features.EdgeFeatures`.
edge_feat_kwargs (Optional[dict]): extra arguments to be passed to
:data:`edge_feat_factory`.
layer_kwargs (dict): optional keyword arguments passed to the layers.
"""

def __init__(
self,
n_nodes: int,
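
A hedged sketch of how the documented ACE constructor might be called inside a Haiku transform (all argument values, the node attributes, and the hk.transform wrapping are illustrative assumptions, not taken from the repository):

import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp

def forward(r, node_attrs):
    model = ACE(
        n_nodes=8,
        embedding_dim=8,
        cutoff=5.0,
        max_body_order=3,
        embedding_irreps=[e3nn.Irrep("0e"), e3nn.Irrep("1o")],
        edge_feat_irreps=[e3nn.Irrep("0e"), e3nn.Irrep("1o"), e3nn.Irrep("2e")],
        node_types=[1, 6, 8],  # e.g. atomic numbers for H, C, O
    )
    return model(r, node_attrs)

model = hk.without_apply_rng(hk.transform(forward))
r = jax.random.normal(jax.random.PRNGKey(0), (8, 3))  # node coordinates
node_attrs = jnp.array([6, 1, 1, 1, 8, 1, 1, 6])      # node types, one per node
params = model.init(jax.random.PRNGKey(1), r, node_attrs)
embeddings = model.apply(params, r, node_attrs)
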
47 changes: 41 additions & 6 deletions macx/models/mace.py
@@ -2,7 +2,6 @@

import e3nn_jax as e3nn
import jax.numpy as jnp
from jax import ops

from ..gnn import GraphNeuralNetwork
from ..gnn.edge_features import EdgeFeatures
@@ -11,12 +10,28 @@


class MACELayer(ACELayer):
r"""
Compute a single MACE interaction layer.
Args:
max_body_order (int): the maximum body order up to which node embeddings are
constructed.
embedding_irreps (Sequence[e3nn.Irrep]): the irreps of the node embeddings.
mix_atomic_basis (bool): whether to apply a linear layer to the initial node
embeddings before symmetrizing them; defaults to :data:`True`.
convolution_weight_factory (Optional[Callable]): a callable returning the
linear weights used in the tensor product of the edge features and the
node embeddings.
update_weight_factory (Optional[Callable]): a callable returning the linear
weights used in the embedding-update tensor product.
"""

def __init__(
self,
*ace_args,
prev_embed_irreps: Sequence[e3nn.Irrep],
convolution_weight_factory: Optional[Callable] = None,
residual_weight_factory: Optional[Callable] = None,
update_weight_factory: Optional[Callable] = None,
**ace_kwargs,
):
super().__init__(*ace_args, **ace_kwargs)
@@ -46,9 +61,9 @@ def __init__(
)
)
if not self.first_layer:
self.residual_tp = convert_irreps_array(
embedding_irreps, prev_embed_irreps
)(WeightedTensorProduct(embedding_irreps, residual_weight_factory))
self.update_tp = convert_irreps_array(embedding_irreps, prev_embed_irreps)(
WeightedTensorProduct(embedding_irreps, update_weight_factory)
)

def get_update_edges_fn(self):
return None
@@ -80,14 +95,34 @@ def update_nodes(nodes, A):
self.embed_mixing_layer(nodes["embedding"]),
nodes["node_type"],
)
nodes["embedding"] = self.residual_tp(update, residual, update)
nodes["embedding"] = self.update_tp(update, residual, update)

return nodes["embedding"] if self.last_layer else nodes

return update_nodes


class MACE(GraphNeuralNetwork):
r"""
The MACE model.
Args:
n_nodes (int): the maximum number of nodes in the graph.
embedding_dim (int): the embedding dimension; it should equal the number of
radial basis functions.
cutoff (float): distance cutoff, beyond which interactions are not considered.
max_body_order (int): the maximum body order up to which node embeddings are
constructed.
embedding_irreps (Sequence[e3nn.Irrep]): the irreps of the node embeddings.
edge_feat_irreps (Sequence[e3nn.Irrep]): the irreps of the edge features.
node_types (Sequence[int]): the list of possible node types.
edge_feat_factory (Optional[Callable]): the class that constructs the edge
features; defaults to :class:`~gnn.edge_features.EdgeFeatures`.
edge_feat_kwargs (Optional[dict]): extra arguments to be passed to
:data:`edge_feat_factory`.
layer_kwargs (dict): optional keyword arguments passed to the layers.
"""

def __init__(
self,
n_nodes: int,
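
Analogously, a hedged sketch of constructing MACE from the documented arguments (values are illustrative assumptions; MACE-specific per-layer options such as the weight factories documented for MACELayer would presumably be supplied through layer_kwargs):

import e3nn_jax as e3nn

# assumed to be called inside a Haiku transform, as in the ACE sketch above
model = MACE(
    n_nodes=8,
    embedding_dim=8,
    cutoff=5.0,
    max_body_order=3,
    embedding_irreps=[e3nn.Irrep("0e"), e3nn.Irrep("1o")],
    edge_feat_irreps=[e3nn.Irrep("0e"), e3nn.Irrep("1o"), e3nn.Irrep("2e")],
    node_types=[1, 6, 8],
)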
