implement GeneralLinear, test MACE equivariance
szbernat committed Nov 2, 2022
1 parent fd349a4 commit 65bb015
Showing 4 changed files with 148 additions and 48 deletions.
24 changes: 12 additions & 12 deletions macx/models/mace.py
@@ -6,9 +6,8 @@
 
 from ..gnn import GraphNeuralNetwork
 from ..gnn.edge_features import EdgeFeatures
-from ..tools.e3nn_ext import ArrayLinear
+from ..tools.e3nn_ext import GeneralLinear, WeightedTensorProduct, convert_irreps_array
 from .ace import ACELayer, to_onehot
-from .symmetric_contraction import WeightedTensorProduct
 
 
 class MACELayer(ACELayer):
@@ -20,19 +19,20 @@ def __init__(
     ):
         super().__init__(*ace_args, **ace_kwargs)
         embedding_irreps = ace_kwargs["embedding_irreps"]
-        # emb_irreps = [e3nn.Irrep("0e")] if self.first_layer else embedding_irreps
-        self.prev_embed_mixing_layer = ArrayLinear(
-            prev_embed_irreps, prev_embed_irreps, self.embedding_dim
+        self.prev_embed_mixing_layer = convert_irreps_array(prev_embed_irreps)(
+            GeneralLinear(prev_embed_irreps, mix_channels=True)
         )
-        self.message_mixing_layer = ArrayLinear(
-            embedding_irreps, embedding_irreps, self.embedding_dim
+
+        self.message_mixing_layer = convert_irreps_array(embedding_irreps)(
+            GeneralLinear(embedding_irreps, mix_channels=True)
         )
         if not self.first_layer:
-            self.embed_mixing_layer = ArrayLinear(
-                embedding_irreps,
-                embedding_irreps,
-                self.embedding_dim,
-                channel_out=self.n_node_type,
+            self.embed_mixing_layer = convert_irreps_array(embedding_irreps)(
+                GeneralLinear(
+                    embedding_irreps,
+                    mix_channels=True,
+                    new_channel_dim=self.n_node_type,
+                )
             )
         self.wtp = WeightedTensorProduct(
             self.edge_feat_irreps,
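
For reference, a minimal sketch of the call pattern this file switches to: GeneralLinear consumes an e3nn.IrrepsArray, and convert_irreps_array wraps it so the model can keep passing plain arrays. The irreps, shapes, and key values below are illustrative, not taken from the commit.

import e3nn_jax as e3nn
import haiku as hk
import jax

from macx.tools.e3nn_ext import GeneralLinear, convert_irreps_array

irreps = [e3nn.Irrep("0e"), e3nn.Irrep("1o")]  # feature dim 1 + 3 = 4

@hk.without_apply_rng
@hk.transform
def mixing_layer(x):
    # x: plain array of shape [..., channels, 4]; the wrapper converts it to an
    # IrrepsArray before GeneralLinear sees it, and back to an array afterwards
    return convert_irreps_array(irreps)(GeneralLinear(irreps, mix_channels=True))(x)

x = jax.random.normal(jax.random.PRNGKey(0), (7, 16, 4))
params = mixing_layer.init(jax.random.PRNGKey(1), x)
out = mixing_layer.apply(params, x)  # shape (7, 16, 4): channels mixed, irreps kept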
22 changes: 0 additions & 22 deletions macx/models/symmetric_contraction.py
@@ -46,28 +46,6 @@ def __call__(self, w, x, y):
         return weighted_ia_out.array
 
 
-class WeightedTensorProduct(hk.Module):
-    def __init__(
-        self,
-        irreps_x: Sequence[e3nn.Irrep],
-        irreps_y: Sequence[e3nn.Irrep],
-        irreps_out: Sequence[e3nn.Irrep],
-        channels_out: int = 1,
-    ):
-        super().__init__()
-        self.irreps_x = irreps_x
-        self.irreps_y = irreps_y
-        self.irreps_out = irreps_out
-        self.weighted_sum = e3nn.Linear(irreps_out)
-
-    def __call__(self, x, y):
-        ia_x = e3nn.IrrepsArray(self.irreps_x, x)
-        ia_y = e3nn.IrrepsArray(self.irreps_y, y)
-        ia_out = e3nn.tensor_product(ia_x, ia_y, filter_ir_out=self.irreps_out)
-        weighted_ia_out = self.weighted_sum(ia_out)
-        return weighted_ia_out.array
-
-
 class SymmetricContraction(hk.Module):
     r"""
     Create higher body-order tensors transforming according to some irreps.
107 changes: 93 additions & 14 deletions macx/tools/e3nn_ext.py
@@ -1,20 +1,99 @@
-from e3nn_jax import IrrepsArray, Linear
+from functools import wraps
+from typing import Optional, Sequence
 
+import e3nn_jax as e3nn
+import haiku as hk
 
-def add_mul_to_irrep_str(irreps, mul):
-    return "+".join(map(lambda x: f"{mul}x{x}", irreps))
 
-class ArrayLinear(Linear):
-    def __init__(self, irreps_out, irreps_in, embedding_dim, channel_out=1):
-        # base class has attribute irreps_in, don't overwrite it
-        self.input_irreps = irreps_in
-        mult_irreps_out = add_mul_to_irrep_str(irreps_out, embedding_dim)
-        super().__init__(mult_irreps_out, channel_out)
+class GeneralLinear(hk.Module):
+    r"""
+    General equivariant linear layer.
+
+    The input is assumed to be of shape
+    [:math:`...`, :math:`N_\text{channels in}`, :math:`\sum_i m_i(2l_i+1)`], where
+    :math:`i` runs over the input irreps, and :math:`m_i` is the multiplicity of
+    the :math:`i`-th irrep.
+
+    Args:
+        irreps_out (Sequence[e3nn_jax.Irrep]): sequence of output irreps.
+        mix_channels (bool): default False, whether to mix the different input
+            channels. Must be True if either :data:`channels_out` or
+            :data:`new_channel_dim` is given.
+        channels_out (int): optional, the number of output channels, mixed from
+            the input channels. If :data:`None`, it is set to
+            :math:`N_\text{channels in}`.
+        new_channel_dim (int): optional, the dimension of a new channel axis
+            inserted before the input channels axis. If :data:`None`, no new
+            channel axis is inserted. The new channels are mixed from the input
+            channels.
+    """
+
+    def __init__(
+        self,
+        irreps_out: Sequence[e3nn.Irrep],
+        mix_channels: bool = False,
+        channels_out: Optional[int] = None,
+        new_channel_dim: Optional[int] = None,
+    ):
+        if (channels_out or new_channel_dim) and not mix_channels:
+            raise ValueError(
+                "mix_channels has to be True if "
+                "channels_out or new_channel_dim is given"
+            )
+        super().__init__()
+        self.irreps_out = irreps_out
+        self.mix_channels = mix_channels
+        self.channels_out = channels_out
+        self.new_channel_dim = new_channel_dim
 
     def __call__(self, x):
-        *leading_dims, embedding_dim, _ = x.shape
-        x = IrrepsArray(self.input_irreps, x).axis_to_mul()[..., None, :]  # channel dim
-        out = super().__call__(x)
-        out = out.mul_to_axis()
-        return out.array if self.channel_out > 1 else out.array.squeeze(axis=-3)
+        emb_dim = x.array.shape[-2]
+        channels_out = self.channels_out or emb_dim
+        linear = (
+            e3nn.Linear(self.irreps_out)
+            if not self.mix_channels
+            else e3nn.Linear(
+                self.irreps_out, channels_out * (self.new_channel_dim or 1)
+            )
+        )
+        out = linear(x)
+        if self.new_channel_dim:
+            out = out.reshape((*out.shape[:-2], self.new_channel_dim, channels_out, -1))
+        return out
+
+
+class WeightedTensorProduct(hk.Module):
+    def __init__(
+        self,
+        irreps_x: Sequence[e3nn.Irrep],
+        irreps_y: Sequence[e3nn.Irrep],
+        irreps_out: Sequence[e3nn.Irrep],
+    ):
+        super().__init__()
+        self.irreps_x = irreps_x
+        self.irreps_y = irreps_y
+        self.irreps_out = irreps_out
+        self.weighted_sum = GeneralLinear(irreps_out)
+
+    def __call__(self, x, y):
+        ia_x = e3nn.IrrepsArray(self.irreps_x, x)
+        ia_y = e3nn.IrrepsArray(self.irreps_y, y)
+        ia_out = e3nn.tensor_product(ia_x, ia_y, filter_ir_out=self.irreps_out)
+        weighted_ia_out = self.weighted_sum(ia_out)
+        return weighted_ia_out.array
+
+
+def convert_irreps_array(*irreps: Sequence[e3nn.Irrep]):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            assert len(args) == len(irreps)
+            args = (e3nn.IrrepsArray(irrep, arg) for irrep, arg in zip(irreps, args))
+            outs = func(*args, **kwargs)
+            return (
+                tuple(out.array for out in outs)
+                if isinstance(outs, tuple)
+                else outs.array
+            )
+
+        return wrapper
+
+    return decorator
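
As a shape sketch (illustrative sizes, assuming the semantics of the code above): with mix_channels=True and new_channel_dim=N, GeneralLinear maps [..., C_in, irreps_dim] to [..., N, C_out, irreps_dim], where C_out defaults to C_in. This is what MACELayer uses to produce one embedding mixture per node type.

import e3nn_jax as e3nn
import haiku as hk
import jax

from macx.tools.e3nn_ext import GeneralLinear

irreps = [e3nn.Irrep("0e"), e3nn.Irrep("1o")]

@hk.without_apply_rng
@hk.transform
def embed_mixing(x):
    ia = e3nn.IrrepsArray(irreps, x)
    # new_channel_dim=3 inserts a new axis of size 3 before the channel axis
    return GeneralLinear(irreps, mix_channels=True, new_channel_dim=3)(ia).array

x = jax.random.normal(jax.random.PRNGKey(0), (8, 16, 4))  # [batch, channels, dim]
params = embed_mixing.init(jax.random.PRNGKey(1), x)
print(embed_mixing.apply(params, x).shape)  # (8, 3, 16, 4)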
43 changes: 43 additions & 0 deletions tests/test_mace_equivariance.py
@@ -0,0 +1,43 @@
+def test_mace_equivariance():
+    import e3nn_jax as e3nn
+    import haiku as hk
+    import jax
+    import jax.numpy as jnp
+    from macx.models.mace import MACE
+    from macx.tools.state_callback import state_callback
+    from scipy.spatial.transform import Rotation
+
+    rng = jax.random.PRNGKey(23)
+    R = jnp.array(Rotation.from_euler("x", 52, degrees=True).as_matrix())
+
+    @hk.without_apply_rng
+    @hk.transform_with_state
+    def mace(r, r_type):
+        return MACE(
+            5,
+            5,
+            10.0,
+            4,
+            embedding_irreps=[
+                [e3nn.Irrep("0e"), e3nn.Irrep("1o")],
+                [e3nn.Irrep("0e"), e3nn.Irrep("1o")],
+            ],
+            edge_feat_irreps=[e3nn.Irrep("0e"), e3nn.Irrep("1o")],
+            node_types=[0, 1, 2, 3, 4],
+        )(r, r_type)
+
+    r = jax.random.normal(rng, (1000, 5, 3))
+    rotated_r = jnp.einsum("ij,baj->bai", R, r)
+    r_type = jnp.tile(jnp.arange(5)[None], (1000, 1))
+
+    jitted = jax.jit(jax.vmap(mace.apply, (None, 0, 0, 0)))
+
+    params, state = jax.vmap(mace.init, (None, 0, 0), (None, 0))(rng, r, r_type)
+    _, state = jitted(params, state, r, r_type)
+    state, _ = state_callback(state, batch_dim=True)
+
+    B, state = jitted(params, state, r, r_type)
+    B_rot, _ = jitted(params, state, rotated_r, r_type)
+    rot_B = B.at[:, :, :, 1:].set(jnp.einsum("ij,baej->baei", R, B[:, :, :, 1:]))
+    diff = B_rot - rot_B
+    assert jnp.abs(diff).max() < 2.0e-5
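
The final assertion checks equivariance of the 0e + 1o embedding: under the rotation R, the scalar (0e) component is invariant and only the vector (1o) part transforms, which is why just B[..., 1:] is rotated. A minimal sketch of that block-diagonal action on the feature axis (hypothetical helper, not part of the commit):

import jax.numpy as jnp

def rotate_0e_1o(features, R):
    # features[..., :1] is the invariant 0e scalar;
    # features[..., 1:] is the 1o vector, which rotates with R itself
    return jnp.concatenate(
        [features[..., :1], jnp.einsum("ij,...j->...i", R, features[..., 1:])],
        axis=-1,
    )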
