Skip to content

Commit

Permalink
Format everything with black + add pre-commit hook
Browse files Browse the repository at this point in the history
  • Loading branch information
sirmarcel committed Oct 27, 2022
1 parent 00b025d commit 9f00d35
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 48 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.10
4 changes: 2 additions & 2 deletions macx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__="0.1.0"
__version__ = "0.1.0"

from .edge_feature import EdgeFeature
from .edge_feature import EdgeFeature
2 changes: 1 addition & 1 deletion macx/edge_feature/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .edge_feature import EdgeFeature
from .edge_feature import EdgeFeature
37 changes: 18 additions & 19 deletions macx/edge_feature/edge_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,27 @@


class EdgeFeature(hk.Module):
def __init__(self,
radial_fn: str,
n_rbf: int,
r_cut: float,
l_max: int,
z_max: int = 100):
def __init__(
self, radial_fn: str, n_rbf: int, r_cut: float, l_max: int, z_max: int = 100
):

super().__init__()

self.embed_fn = hk.Embed(vocab_size=z_max, embed_dim=1)

if radial_fn == 'bessel':
if radial_fn == "bessel":
self.radial_fn = BesselBasis(n_rbf=n_rbf, r_cut=r_cut)

self.spherical_fn = partial(sh, irreps_out=np.arange(l_max+1).tolist(),
normalize=True,
normalization='integral')

def __call__(self,
z: Array,
idx_i: Array,
idx_j: Array,
r_ij: Array,
*args,
**kwargs):
self.spherical_fn = partial(
sh,
irreps_out=np.arange(l_max + 1).tolist(),
normalize=True,
normalization="integral",
)

def __call__(
self, z: Array, idx_i: Array, idx_j: Array, r_ij: Array, *args, **kwargs
):
"""
Args:
Expand All @@ -59,10 +55,13 @@ def __call__(self,
d_ij = jnp.linalg.norm(r_ij, axis=-1) # shape: (P)
rbf_ij = self.radial_fn(d_ij) # shape: (P,n_rbf)
sph_ij = self.spherical_fn(input=r_ij) # shape: (P,m_tot) m_tot = (l_max+1)^2
A_ij = rbf_ij[:, :, None] * sph_ij[:, None, :] * x_i * x_j # shape: (P,n_rbf,m_tot)
A_ij = (
rbf_ij[:, :, None] * sph_ij[:, None, :] * x_i * x_j
) # shape: (P,n_rbf,m_tot)

return A_ij


# implemented elsewhere
#
# class Aggregation(hk.Module):
Expand Down
21 changes: 11 additions & 10 deletions macx/edge_feature/radial_basis_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Array = Union[np.ndarray, jnp.ndarray]


def safe_mask(mask, fn: Callable, operand: Array, placeholder: float = 0.) -> Array:
def safe_mask(mask, fn: Callable, operand: Array, placeholder: float = 0.0) -> Array:
"""
Safe mask which ensures that gradients flow nicely. See also
https://github.com/google/jax-md/blob/b4bce7ab9b37b6b9b2d0a5f02c143aeeb4e2a560/jax_md/util.py#L67
Expand Down Expand Up @@ -43,24 +43,25 @@ def __call__(self, r: Array) -> Array:
_r = r[..., None] # shape: (P,1)
f = lambda x: jnp.sin(jnp.pi / self.r_cut * self.offsets * x) / x

basis = safe_mask(mask=_r != 0,
fn=f,
operand=_r,
placeholder=0.) # shape: (P, n_rbf)
basis = safe_mask(
mask=_r != 0, fn=f, operand=_r, placeholder=0.0
) # shape: (P, n_rbf)

fc = self.cutoff_fn(r) # shape: (P)

return fc[..., None] * basis # shape: (P, n_rbf)

class CosineCutoff(hk.Module):

class CosineCutoff(hk.Module):
def __init__(self, r_cut: float):
super().__init__()
self.r_cut = jnp.float32(r_cut)

def __call__(self, dR: Array) -> Array:
cutoff_fn = lambda x: 0.5 * (jnp.cos(x * jnp.pi / self.r_cut) + 1.0)
return safe_mask(mask=(dR < self.r_cut),
fn=cutoff_fn,
operand=dR,
placeholder=jnp.float32(0.))
return safe_mask(
mask=(dR < self.r_cut),
fn=cutoff_fn,
operand=dR,
placeholder=jnp.float32(0.0),
)
13 changes: 4 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import setuptools

base_requires = [
'jax>=0.3.23',
'e3nn-jax>=0.10.1'
]
test_requires = [
'pytest',
]
base_requires = ["jax>=0.3.23", "e3nn-jax>=0.10.1"]
test_requires = ["pytest", "black"]

setuptools.setup(
name="macx",
Expand All @@ -15,8 +10,8 @@
author_email="",
license="ASL",
install_requires=base_requires,
extras_require={'test': test_requires},
extras_require={"test": test_requires},
packages=setuptools.find_packages(),
url="https://github.com/sirmarcel/macx",
python_requires='>=3.8',
python_requires=">=3.8",
)
13 changes: 6 additions & 7 deletions tests/test_edge_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
import jax
import jax.numpy as jnp


def test_fire():
n_rbf = 10
l_max = 3

@hk.without_apply_rng
@hk.transform
def edge_feature_fn(z, idx_i, idx_j, r_ij):
return EdgeFeature(radial_fn="bessel",
n_rbf=n_rbf,
r_cut=5.0,
l_max=l_max)(z=z, idx_i=idx_i, idx_j=idx_j, r_ij=r_ij)

return EdgeFeature(radial_fn="bessel", n_rbf=n_rbf, r_cut=5.0, l_max=l_max)(
z=z, idx_i=idx_i, idx_j=idx_j, r_ij=r_ij
)

n = 5
z = jnp.ones(n, dtype=int)
Expand All @@ -25,9 +24,9 @@ def edge_feature_fn(z, idx_i, idx_j, r_ij):

r_ij = jax.vmap(lambda i, j: R[j] - R[i])(idx_i, idx_j)

inputs = {'z': z, 'idx_i': idx_i, 'idx_j': idx_j, 'r_ij': r_ij}
inputs = {"z": z, "idx_i": idx_i, "idx_j": idx_j, "r_ij": r_ij}
params = edge_feature_fn.init(jax.random.PRNGKey(0), **inputs)

out = edge_feature_fn.apply(params, **inputs)

assert out.shape == (len(idx_i), n_rbf, (l_max+1)**2)
assert out.shape == (len(idx_i), n_rbf, (l_max + 1) ** 2)
1 change: 1 addition & 0 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import macx


def test_preprocessing():
# TODO add tests
some_property1 = True
Expand Down

0 comments on commit 9f00d35

Please sign in to comment.