-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement GeneralLinear, test MACE equivariance
- Loading branch information
Showing
4 changed files
with
148 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,99 @@ | ||
from e3nn_jax import IrrepsArray, Linear | ||
from functools import wraps | ||
from typing import Optional, Sequence | ||
|
||
import e3nn_jax as e3nn | ||
import haiku as hk | ||
|
||
def add_mul_to_irrep_str(irreps, mul): | ||
return "+".join(map(lambda x: f"{mul}x{x}", irreps)) | ||
|
||
class GeneralLinear(hk.Module): | ||
r""" | ||
General equivariant linear layer. | ||
class ArrayLinear(Linear): | ||
def __init__(self, irreps_out, irreps_in, embedding_dim, channel_out=1): | ||
# base class has attribute irreps_in, don't overwrite it | ||
self.input_irreps = irreps_in | ||
mult_irreps_out = add_mul_to_irrep_str(irreps_out, embedding_dim) | ||
super().__init__(mult_irreps_out, channel_out) | ||
The input is assumed to be of shape | ||
[:math:`...`, :math:`N_\text{channels in}`, :math:`\sum_i m_i(2l_i+1)`] where | ||
:math:`i` runs over the input irreps, and :math:`m_i` is the multiplicity of the | ||
:math:`i`th irrep. | ||
Args: | ||
irreps_out (Sequence[e3nn_jax.Irrep]): sequence of output irreps. | ||
mix_channels (bool): default False, whether to mix the different input channels. | ||
Must be true if either :data:`channels_out` or :data:`new_channel_dim` | ||
is given. | ||
channels_out (int): optional, the number of output channels, mixed from | ||
the input channels. If :data:`None`, it is set to :math:`N_\text{channels in}. | ||
new_channel_dim (int): optional, the dimension of a new channel axes inserted | ||
before the input channels axis. If :data:`None`, no new channel axis | ||
is inserted. The new channels are mixed from the input channels. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
irreps_out: Sequence[e3nn.Irrep], | ||
mix_channels: bool = False, | ||
channels_out: Optional[int] = None, | ||
new_channel_dim: Optional[int] = None, | ||
): | ||
if (channels_out or new_channel_dim) and not mix_channels: | ||
raise ValueError( | ||
"mix_channels has to be True if " | ||
"channels_out or new_channel_dim is given" | ||
) | ||
super().__init__() | ||
self.irreps_out = irreps_out | ||
self.mix_channels = mix_channels | ||
self.channels_out = channels_out | ||
self.new_channel_dim = new_channel_dim | ||
|
||
def __call__(self, x): | ||
*leading_dims, embedding_dim, _ = x.shape | ||
x = IrrepsArray(self.input_irreps, x).axis_to_mul()[..., None, :] # channel dim | ||
out = super().__call__(x) | ||
out = out.mul_to_axis() | ||
return out.array if self.channel_out > 1 else out.array.squeeze(axis=-3) | ||
emb_dim = x.array.shape[-2] | ||
channels_out = self.channels_out or emb_dim | ||
linear = ( | ||
e3nn.Linear(self.irreps_out) | ||
if not self.mix_channels | ||
else e3nn.Linear( | ||
self.irreps_out, channels_out * (self.new_channel_dim or 1) | ||
) | ||
) | ||
out = linear(x) | ||
if self.new_channel_dim: | ||
out = out.reshape((*out.shape[:-2], self.new_channel_dim, channels_out, -1)) | ||
return out | ||
|
||
|
||
class WeightedTensorProduct(hk.Module): | ||
def __init__( | ||
self, | ||
irreps_x: Sequence[e3nn.Irrep], | ||
irreps_y: Sequence[e3nn.Irrep], | ||
irreps_out: Sequence[e3nn.Irrep], | ||
): | ||
super().__init__() | ||
self.irreps_x = irreps_x | ||
self.irreps_y = irreps_y | ||
self.irreps_out = irreps_out | ||
self.weighted_sum = GeneralLinear(irreps_out) | ||
|
||
def __call__(self, x, y): | ||
ia_x = e3nn.IrrepsArray(self.irreps_x, x) | ||
ia_y = e3nn.IrrepsArray(self.irreps_y, y) | ||
ia_out = e3nn.tensor_product(ia_x, ia_y, filter_ir_out=self.irreps_out) | ||
weighted_ia_out = self.weighted_sum(ia_out) | ||
return weighted_ia_out.array | ||
|
||
|
||
def convert_irreps_array(*irreps: Sequence[e3nn.Irrep]): | ||
def decorator(func): | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
assert len(args) == len(irreps) | ||
args = (e3nn.IrrepsArray(irrep, arg) for irrep, arg in zip(irreps, args)) | ||
outs = func(*args, **kwargs) | ||
return ( | ||
tuple(out.array for out in outs) | ||
if isinstance(outs, tuple) | ||
else outs.array | ||
) | ||
|
||
return wrapper | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
def test_mace_equivariance(): | ||
import e3nn_jax as e3nn | ||
import haiku as hk | ||
import jax | ||
import jax.numpy as jnp | ||
from macx.models.mace import MACE | ||
from macx.tools.state_callback import state_callback | ||
from scipy.spatial.transform import Rotation | ||
|
||
rng = jax.random.PRNGKey(23) | ||
R = jnp.array(Rotation.from_euler("x", 52, degrees=True).as_matrix()) | ||
|
||
@hk.without_apply_rng | ||
@hk.transform_with_state | ||
def mace(r, r_type): | ||
return MACE( | ||
5, | ||
5, | ||
10.0, | ||
4, | ||
embedding_irreps=[ | ||
[e3nn.Irrep("0e"), e3nn.Irrep("1o")], | ||
[e3nn.Irrep("0e"), e3nn.Irrep("1o")], | ||
], | ||
edge_feat_irreps=[e3nn.Irrep("0e"), e3nn.Irrep("1o")], | ||
node_types=[0, 1, 2, 3, 4], | ||
)(r, r_type) | ||
|
||
r = jax.random.normal(rng, (1000, 5, 3)) | ||
rotated_r = jnp.einsum("ij,baj->bai", R, r) | ||
r_type = jnp.tile(jnp.arange(5)[None], (1000, 1)) | ||
|
||
jitted = jax.jit(jax.vmap(mace.apply, (None, 0, 0, 0))) | ||
|
||
params, state = jax.vmap(mace.init, (None, 0, 0), (None, 0))(rng, r, r_type) | ||
_, state = jitted(params, state, r, r_type) | ||
state, _ = state_callback(state, batch_dim=True) | ||
|
||
B, state = jitted(params, state, r, r_type) | ||
B_rot, _ = jitted(params, state, rotated_r, r_type) | ||
rot_B = B.at[:, :, :, 1:].set(jnp.einsum("ij,baej->baei", R, B[:, :, :, 1:])) | ||
diff = B_rot - rot_B | ||
assert jnp.abs(diff).max() < 2.0e-5 |