I believe there is an issue with `BlockNeuralAutoregressiveTransform` not forming properly normalised densities when its inverse is used to transform a distribution (the same issue as we have in danielward27/flowjax#102). See below:
```python
from functools import partial

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from numpyro.distributions import Normal, TransformedDistribution
from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform
from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN

if __name__ == "__main__":
    dim = (1,)
    init_fn, apply_fn = BlockNeuralAutoregressiveNN(dim[0])
    params = init_fn(jr.PRNGKey(0), (1,))[1]
    x = jnp.linspace(-30, 30, 10000)[:, None]
    arn = partial(apply_fn, params)
    bnaf = BlockNeuralAutoregressiveTransform(arn)

    # Plot showing bijection (1D) - note not real -> real!
    plt.plot(x, bnaf(x))
    plt.show()

    # Plot transformed normal
    dist = TransformedDistribution(Normal(jnp.zeros(dim)), bnaf.inv)
    probs = jnp.exp(dist.log_prob(x))
    probs, x = jnp.squeeze(probs), jnp.squeeze(x)
    plt.plot(x, probs)
    plt.show()

    # Rough integral
    print(jnp.trapz(probs, x))  # ~0.17
```
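The ~0.17 is consistent with this diagnosis: by change of variables, the integral of the transformed density equals the Normal mass that falls inside the (bounded) image of the bijection. A quick sanity check, reusing the objects from the script above (this assumes the forward map is monotone increasing in 1D, which BNAF's positivity-constrained weights should guarantee):

```python
# Image of the bijection over a wide input range; its endpoints bound
# where base samples can be pulled back from.
y = jnp.squeeze(bnaf(jnp.linspace(-30, 30, 10000)[:, None]))
base = Normal(0.0, 1.0)
# Normal mass inside the image interval - expected to match the ~0.17 integral.
print(base.cdf(y.max()) - base.cdf(y.min()))
```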
Note that the codomain of `BlockNeuralAutoregressiveTransform` is set to `real_vector`, although the output is actually a linear transformation applied after a `Tanh` bijection, so the image is a bounded interval rather than the whole real line. I'm not sure what the best solution is. Maybe implement something like `LeakyTanh` (i.e. tanh, but switching to linear outside some interval like [-3, 3]) and use that inside BNAFs instead, as sketched below?
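For concreteness, here is a minimal sketch of what such a `LeakyTanh` could look like. The names `leaky_tanh`/`leaky_tanh_log_abs_det` and the default interval of 3.0 are illustrative choices, not an existing numpyro API:

```python
import jax.numpy as jnp

def leaky_tanh(x, interval=3.0):
    """Illustrative LeakyTanh: tanh on [-interval, interval], linear with
    matched value and slope outside, so it is a bijection from R to R."""
    boundary = jnp.tanh(interval)
    slope = 1.0 - boundary**2  # tanh'(interval); keeps the map C^1 at the seams
    linear = boundary + slope * (jnp.abs(x) - interval)
    return jnp.where(jnp.abs(x) <= interval, jnp.tanh(x), jnp.sign(x) * linear)

def leaky_tanh_log_abs_det(x, interval=3.0):
    """log|d/dx leaky_tanh(x)|, as needed for the flow's density computation."""
    slope = 1.0 - jnp.tanh(interval) ** 2
    return jnp.where(
        jnp.abs(x) <= interval,
        jnp.log1p(-jnp.tanh(x) ** 2),  # log sech^2(x) inside the interval
        jnp.log(slope),                # constant slope in the linear tails
    )
```

Since the linear tails extend the map to a bijection from R to R, the transform's `real_vector` codomain would then be accurate, and the inverse would be defined for any base sample.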