diff --git a/bgflow/nn/flow/transformer/jax.py b/bgflow/nn/flow/transformer/jax.py index df08cb65..838a6769 100644 --- a/bgflow/nn/flow/transformer/jax.py +++ b/bgflow/nn/flow/transformer/jax.py @@ -2,7 +2,7 @@ try: import jax - import jax.numpy as jnp + from jax import lax, numpy as jnp except ImportError: jax = None jnp = None @@ -28,16 +28,10 @@ def affine_transform(x, a, b): def smooth_ramp(x, logalpha, power=1, eps=1e-9): """Smooth ramp.""" assert power > 0 - assert isinstance(power, int) assert eps > 0 alpha = jnp.exp(logalpha) - # double `where` trick to avoid NaN in backward pass - z = jnp.where(x > eps, x, jnp.ones_like(x) * eps) - normalizer = jnp.exp(-alpha * 1.) - return jnp.where( - x > eps, - jnp.exp(-alpha * jnp.power(z, -power)) / normalizer, - jnp.zeros_like(z)) + x = jnp.clip(x, a_min=eps) + return jnp.exp(alpha * (1-lax.integer_pow(x, -power))) def monomial_ramp(x, order=2):