
Block Neural Autoregressive Flow density not properly normalised #1655

Open
danielward27 opened this issue Sep 29, 2023 · 1 comment · May be fixed by #1897

@danielward27 (Contributor)

Hi all,

I believe there is an issue with BlockNeuralAutoregressiveTransform not forming properly normalised densities when the inverse is used to transform a distribution (the same issue as danielward27/flowjax#102). See below:

from functools import partial

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from numpyro.distributions import Normal, TransformedDistribution
from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform
from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN

if __name__ == "__main__":
    dim = (1,)
    init_fn, apply_fn = BlockNeuralAutoregressiveNN(dim[0])

    params = init_fn(jr.PRNGKey(0), (1,))[1]  # init_fn returns (output_shape, params)
    x = jnp.linspace(-30, 30, 10000)[:, None]
    arn = partial(apply_fn, params)
    bnaf = BlockNeuralAutoregressiveTransform(arn)

    # Plot showing bijection (1D) - note not real -> real!
    plt.plot(x, bnaf(x))
    plt.show()

    # Plot transformed normal
    dist = TransformedDistribution(Normal(jnp.zeros(dim)), bnaf.inv)
    probs = jnp.exp(dist.log_prob(x))
    probs, x = jnp.squeeze(probs), jnp.squeeze(x)

    plt.plot(x, probs)
    plt.show()

    # Rough integral
    print(jnp.trapz(probs, x))  # ~0.17

Note the codomain of BlockNeuralAutoregressiveTransform is set to real_vector, but the output is actually a linear transformation applied after a Tanh bijection, so it maps onto a bounded interval rather than the whole real line. Because the transform is not surjective onto the real line, pulling the base Normal back through its inverse loses probability mass, which is why the density above integrates to ~0.17 rather than 1. I'm not sure what the best solution is. Maybe implement something like LeakyTanh (i.e. tanh, but switching to linear outside some interval like [-3, 3]) and use that inside BNAFs instead?
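
For illustration, here is a minimal sketch of the LeakyTanh idea, assuming a cutoff c = 3.0. The names leaky_tanh and leaky_tanh_log_abs_det (and the parameter c) are made up here for the sketch, not existing numpyro API. The linear pieces reuse the boundary slope tanh'(c) = 1 - tanh(c)^2, so the map is continuously differentiable and surjective onto the real line:

import jax.numpy as jnp

def leaky_tanh(x, c=3.0):
    # Hypothetical sketch: tanh on [-c, c], linear extension outside with
    # slope tanh'(c) = 1 - tanh(c)^2, so the pieces join smoothly and the
    # bijection maps the real line onto the real line.
    slope = 1 - jnp.tanh(c) ** 2
    linear = jnp.sign(x) * jnp.tanh(c) + slope * (x - jnp.sign(x) * c)
    return jnp.where(jnp.abs(x) <= c, jnp.tanh(x), linear)

def leaky_tanh_log_abs_det(x, c=3.0):
    # Elementwise log|f'(x)|: log(1 - tanh(x)^2) inside, constant slope outside.
    slope = 1 - jnp.tanh(c) ** 2
    return jnp.where(jnp.abs(x) <= c, jnp.log1p(-jnp.tanh(x) ** 2), jnp.log(slope))

Swapping something like this in for Tanh inside the BNAF layers (with the matching log-det contribution) would make the overall transform real -> real, so the transformed density should then integrate to ~1.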

@fehiepsi (Member) commented Oct 2, 2023

Very interesting! Thanks for the detailed explanation. Unfortunately, I'm not sure what the best solution here is. :(
