
Example to marginalize discrete latent variables (WIP) #646

Open · wants to merge 3 commits into base: enum-messengers

Conversation

rtbs-dev

Initial structure and outline, with lead-up to inference. Currently hits some kind of broadcasting error.

@fehiepsi (Member) commented Jun 24, 2020

@tbsexton Wonderful tutorial!!! I'm looking forward to its final version. ;)

About the shape error, I think you need to declare some dimensions as event dimensions. Alternatively, you can use plate('n_nodes') to declare them as batch dimensions.

    # beta hyperpriors
    u = ny.sample("u", dist.Uniform(0, 1).expand([n_edges]).to_event(1))
    v = ny.sample("v", dist.Gamma(1, 20).expand([n_edges]).to_event(1))
    Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v).to_event(1))
    s_ij = squareform(Λ)  # adjacency matrix to recover via inference
    
    with ny.plate("n_cascades", n_cascades, dim=-1):
        # infer infection source node
        x0 = ny.sample("x0", dist.Categorical(probs=ϕ))
        src = one_hot(x0, n_nodes)
        
        # simulate ode and realize
        p_infectious = batched_diffuse(s_ij, 5, src)
        real = dist.Bernoulli(probs=p_infectious).to_event(1)
        return ny.sample("obs", real, obs=infections)

This still does not work, though, because batched_diffuse cannot handle a batch of x0, I guess. When enumerated, x0 will have shape (n_nodes, n_cascades), so I think we need to modify batched_diffuse to support batching. I'll try to make a fix for it tomorrow.

@fehiepsi (Member) commented Jun 24, 2020

@tbsexton If you also modify batched_diffuse to incorporate enumeration, the shape issue will be resolved:

def batched_diffuse(p, T, u_obs):
    org_shape = u_obs.shape
    u_obs = np.reshape(u_obs, (-1, org_shape[-1]))
    out = vmap(diffuse, in_axes=(None, None, 0))(p, T, u_obs)
    return np.reshape(out, org_shape)

However, I prefer to do scan(vmap(...)) rather than vmap(scan(...)). For the former, you can just replace

u_add = lax.tanh(np.matmul(p, u))

by

u_add = np.tanh(np.matmul(p, u[..., None]).squeeze(-1))

and use diffuse instead of batched_diffuse in your model. Using diffuse directly is also much faster than using batched_diffuse.
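As a quick sanity check of that trick (plain NumPy here standing in for jax.numpy; the shapes and values are invented for illustration), the u[..., None] reshape makes the one-step update work for both unbatched and batched states:

```python
import numpy as np

p = np.full((3, 3), 0.1)   # hypothetical 3-node transmission matrix
u_single = np.ones(3)      # one state vector, shape (n_nodes,)
u_batch = np.ones((5, 3))  # a batch of states, shape (n_cascades, n_nodes)

def one_step(p, u):
    # matmul over a trailing singleton dim handles any leading batch dims
    return np.tanh(np.matmul(p, u[..., None]).squeeze(-1))

print(one_step(p, u_single).shape)  # (3,)
print(one_step(p, u_batch).shape)   # (5, 3)
```

Because the batch dims ride along in `...`, the same `one_step` can be used inside a scan without an outer vmap.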

However, MCMC can't find valid initial parameters for your model. This is probably an issue of our Bernoulli implementation so I will try to debug it.

Should @jit?

No, you don't need to jit any part of your model. MCMC will jit the whole log-density computation, which already includes those small utilities that are decorated by @jit in your script.

@rtbs-dev (Author)

Interesting tricks, thanks! Implementing them at the moment, though, I'm getting a funsor-related error:

...
~/Documents/Code/pr/numpyro/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs)
    397 
    398     if has_enumerate_support:
--> 399         from numpyro.contrib.funsor import config_enumerate, enum
    400 
    401         if not isinstance(model, enum):

~/Documents/Code/pr/numpyro/numpyro/contrib/funsor/__init__.py in <module>
      7     raise ImportError("`funsor` package is missing. You can install it with `pip install funsor`.")
      8 
----> 9 from numpyro.contrib.funsor.enum_messenger import (enum, infer_config, markov, plate,
     10                                                    to_data, to_funsor, trace)
     11 from numpyro.contrib.funsor.infer_util import config_enumerate, log_density

~/Documents/Code/pr/numpyro/numpyro/contrib/funsor/enum_messenger.py in <module>
     14 from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack
     15 
---> 16 funsor.set_backend("jax")
     17 
     18 

AttributeError: module 'funsor' has no attribute 'set_backend'

@fehiepsi (Member)

With u_end = np.clip(u_end, a_max=1-1e-6), MCMC runs. So I think your diffuse_one_step implementation is forcing u to 1... I think you will have a better idea of what is going on here. :D
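A minimal illustration of why that clip matters (toy numbers, not from the actual model): once the diffusion pushes a probability to exactly 1, the Bernoulli log-prob of a non-infected observation is log(1 - p) = -inf, and clipping restores a finite density:

```python
import numpy as np

u_end = np.array([0.3, 0.999999, 1.0])   # toy diffusion output saturating at 1.0
clipped = np.clip(u_end, None, 1 - 1e-6)  # a_min=None, a_max=1 - 1e-6

print(np.log1p(-u_end))    # last entry is -inf
print(np.log1p(-clipped))  # all finite
```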

@fehiepsi (Member)

getting a funsor-related error:

Oops... sorry, you will need to install the master branch of funsor:

pip uninstall funsor
pip install https://github.com/pyro-ppl/funsor/archive/master.zip

@rtbs-dev (Author)

Hmmm, different error now:

... 


~/Documents/Code/pr/numpyro/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    153     substituted_model = substitute(model, substitute_fn=partial(_unconstrain_reparam, params))
    154     # no param is needed for log_density computation because we already substitute
--> 155     log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
    156     return - log_joint
    157 

~/Documents/Code/pr/numpyro/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
     79     model = substitute(model, param_map=params)
     80     with plate_to_enum_plate():
---> 81         model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
     82     log_factors = []
     83     sum_vars, prod_vars = frozenset(), frozenset()

~/Documents/Code/pr/numpyro/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    145         :return: `OrderedDict` containing the execution trace.
    146         """
--> 147         self(*args, **kwargs)
    148         return self.trace
    149 

~/Documents/Code/pr/numpyro/numpyro/primitives.py in __call__(self, *args, **kwargs)
     65     def __call__(self, *args, **kwargs):
     66         with self:
---> 67             return self.fn(*args, **kwargs)
     68 
     69 

~/Documents/Code/pr/numpyro/numpyro/primitives.py in __exit__(self, *args, **kwargs)
     54 
     55     def __exit__(self, *args, **kwargs):
---> 56         assert _PYRO_STACK[-1] is self
     57         _PYRO_STACK.pop()
     58 

re: forcing to 1 - yes, in the limit all nodes would go to 1, though that shouldn't happen after a single iteration. There are two fixes for this:

  • introduce a time-step parameter, T, and an "infectiousness" coefficient for diffuse, both to be inferred over. I'm trying to avoid this and poke around with the simpler model first (it adds complexity), but biasing it with e.g. a Laplace distribution might add needed regularization?

A bit more insight into the "how many time-steps should we assume" problem, i.e. how far should we propagate the ODE in each plate before stopping to "measure" the infections?

  • In essence, we can ask the reverse problem in the loop: given N infections, what does my current belief about the network tell me about how long it should take to get there?
  • R_0 in network epidemiology is basically that: the reproduction rate of infection over the nodes. This gets modified a bit for network topologies, but it's still pretty basic given how I've defined edges as transmission probabilities.
  • Since each transmission is a Bernoulli trial $P \sim \mathrm{Ber}(p)$, the expected number of successes is just $p$, so the R_0 of each node is the sum of $p$ over its neighbors. Therefore there's a prior distribution over nodes for R_0 (i.e. infections in one time step), from which we can back out the expected time to N infections.

Still working on this, but I was hoping to sidestep the problem with a simpler model at first ^_^
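That last bullet can be made concrete with a toy adjacency matrix (the numbers here are made up for illustration): since each edge transmits with Bernoulli probability p, a node's expected number of one-step infections is just the sum of its edge probabilities:

```python
import numpy as np

# hypothetical symmetric matrix of edge transmission probabilities
s_ij = np.array([[0.0, 0.2, 0.5],
                 [0.2, 0.0, 0.1],
                 [0.5, 0.1, 0.0]])

# E[Ber(p)] = p, so each node's one-step R_0 is its row sum over neighbors
R0 = s_ij.sum(axis=-1)
print(R0)  # [0.7 0.3 0.6]
```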

@rtbs-dev (Author)

@fehiepsi Oops! A couple of things from your suggestions I had not implemented right... I missed the to_event()s in the beta hyperpriors. I'm still not quite sure I understand what that is doing.

I also forgot to change the diffuse call in the plate to run the same number of iterations as the synthetic data (2, not 5). Big difference. I'll commit the changes that got it to run shortly.

But it runs now! Going to be a bit I guess, since it only seems to be able to process 1-2 it/sec....

@fehiepsi (Member) commented Jun 24, 2020

Yeah, the distribution shape system in Pyro/NumPyro (which corresponds 1-1 with the plate graph, hence enabling advanced inference mechanisms) is flexible, but it is not straightforward to keep things in sync. I think the best resource is this tensor shapes tutorial.
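The mechanical effect of .to_event(1) can be mimicked in plain NumPy (a sketch of the semantics, not NumPyro's actual implementation): the rightmost batch dimension is reinterpreted as an event dimension, so log-probabilities are summed over it:

```python
import numpy as np

# element-wise log-probs for 4 parallel draws of 7 independent sites,
# i.e. batch_shape (4, 7), event_shape ()
logp_elem = np.log(np.full((4, 7), 0.5))

# after .to_event(1): batch_shape (4,), event_shape (7,),
# and log_prob sums over the event dimension
logp_event = logp_elem.sum(axis=-1)
print(logp_event.shape)  # (4,)
```

Declaring the same dimension with plate(...) instead would keep it as a batch dimension, which is what allows enumeration machinery to broadcast over it.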

be able to process 1-2 it/sec

It will probably get better once the chain moves into a useful region of the posterior. Currently, it takes 1000 leapfrog steps per sample. I'm not sure if a GPU would help, because I can't access one for a few weeks.

@rtbs-dev (Author)

@fehiepsi a few updates:

  • tried to speed up inference using SVI and NeuTra, but neither was successful (see bottom)
  • testing out a new diffusion model, since JAX has odeint now. It should work out better than the old heuristic.
  • would I be able to add a log-factor to this model even with the new functionality? Something like a Mahalanobis distance between observed infections and simulated ones? I'm not sure the HMM example you show would translate with the ny.plate.
  • attempting to split out sparsity and weight inference (see the notebook for references), but now I'm getting another NotImplementedError related to funsor. Not sure what changed, but it seems related to shapes:
~/Documents/Code/pr/numpyro/numpyro/primitives.py in apply_stack(msg)
     21     pointer = 0
     22     for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 23         handler.process_message(msg)
     24         # When a Messenger sets the "stop" field of a message,
     25         # it prevents any Messengers above it on the stack from being applied.

~/Documents/Code/pr/numpyro/numpyro/contrib/funsor/enum_messenger.py in process_message(self, msg)
    506             raise NotImplementedError("expand=True not implemented")
    507 
--> 508         size = msg["fn"].enumerate_support(expand=False).shape[0]
    509         raw_value = np.arange(0, size)
    510         funsor_value = funsor.Tensor(

~/Documents/Code/pr/numpyro/numpyro/distributions/distribution.py in enumerate_support(self, expand)
    261         containing all values in the support.
    262         """
--> 263         raise NotImplementedError
    264 
    265     def expand(self, batch_shape):

@fehiepsi (Member) commented Jun 29, 2020

@tbsexton That error is a bug! I'll push a fix soon.

About the slowness: that is because your model has many latent variables. But I believe you can use AutoDiagonalNormal to fit your model. I think normalizing flows work for tricky posteriors but won't be useful for models with high-dimensional latent variables. AutoLowRankMVN would be helpful in case there are correlations between variables. For large models, GPU will also help.

the HMM example you guys show would translate with the ny.plate.

I think it is compatible with plate because, under the hood, we use a Unit distribution to store that log_factor. Just make sure that log_factor.shape is compatible with your plate shapes. For example,

with plate('N', 10, dim=-2):
    with plate('M', 5, dim=-1):
        factor('factor', np.ones((10, 5)))

Does that solve your use case?

@fehiepsi (Member) commented Jul 1, 2020

@tbsexton Unfortunately, we don't support enumerating sites with .to_event(...). You can enumerate that portion of the code with:

    with ny.plate("n_edges", n_edges, dim=-1):
        u = ny.sample("u", dist.Uniform(0, 1).expand([n_edges]))
        v = ny.sample("v", dist.Gamma(1, 20).expand([n_edges]))

        ρ = ny.sample("ρ", dist.Beta(u * v, (1 - u) * v))
        A = ny.sample("A", dist.Bernoulli(probs=ρ))
        # resolve the issue: `squareform` does not support batching
        s_ij = squareform(ρ * A)

You will need to make squareform work with batching. One way to do it is

def squareform(edgelist):
    """edgelist to adj. matrix"""
    from numpyro.distributions.util import vec_to_tril_matrix
    half = vec_to_tril_matrix(edgelist, diagonal=-1)
    full = half + np.swapaxes(half, -2, -1)
    return full
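To check the batching behaviour without pulling in NumPyro, here is a plain-NumPy stand-in for vec_to_tril_matrix (a hypothetical re-implementation valid only for diagonal=-1; NumPyro's own utility should be preferred in the model, and its exact element ordering may differ):

```python
import numpy as np

def vec_to_tril_matrix(x, diagonal=-1):
    """NumPy sketch of numpyro.distributions.util.vec_to_tril_matrix (diagonal=-1 only)."""
    m = x.shape[-1]
    # recover n from m = n * (n - 1) / 2
    n = int(round((1 + np.sqrt(1 + 8 * m)) / 2))
    out = np.zeros(x.shape[:-1] + (n, n))
    rows, cols = np.tril_indices(n, k=diagonal)
    out[..., rows, cols] = x  # fill strict lower triangle, row by row
    return out

def squareform_batched(edgelist):
    """Edge list to symmetric adjacency matrix; supports leading batch dims."""
    half = vec_to_tril_matrix(edgelist, diagonal=-1)
    return half + np.swapaxes(half, -2, -1)

A = squareform_batched(np.array([[0.1, 0.2, 0.3],
                                 [0.4, 0.5, 0.6]]))
print(A.shape)  # (2, 3, 3) - one symmetric 3x3 matrix per batch element
```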

However, for the code that follows, it is a bit complicated to handle batch dimensions. It took me a while to realize that s_ij and src have different (though broadcastable) batch dimensions, which causes scan to fail. So we need to broadcast src before feeding it into scan. I got a working version as follows:

    with ny.plate("n_cascades", n_cascades, dim=-2):
        x0 = ny.sample("x0", dist.Categorical(probs=ϕ))
        src = one_hot(x0, n_nodes)
        # we can broadcast src using shapes of `src.shape[:-2]` and `s_ij.shape[:-2]`
        # but I am lazy to do that so I use `matmul` here
        src = (np.broadcast_to(np.eye(s_ij.shape[-1]), s_ij.shape) @ np.swapaxes(src, -2, -1)).squeeze(-1)
        p_infectious = diffuse(s_ij, 1, src)
        p_infectious = np.clip(p_infectious, a_max=1 - 1e-6, a_min=1e-6)
        with ny.plate("n_nodes", n_nodes, dim=-1):
            real = dist.Bernoulli(probs=p_infectious)
            ny.sample("obs", real, obs=infections)

MCMC seems to be pretty fast with the above code and gives high n_eff samples, so I guess it works. But you should double-check to see if the model learns something useful. :)
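The matmul broadcast above can be seen on toy shapes (the dimensions here are invented for illustration): multiplying by an identity matrix broadcast to s_ij's shape pulls src up to the broadcast of both batch shapes:

```python
import numpy as np

n_nodes = 3
s_ij = np.zeros((2, 1, n_nodes, n_nodes))        # e.g. an enumeration batch dim (2, 1)
src = np.eye(n_nodes)[[0, 2, 1, 0]][:, None, :]  # one-hot sources, shape (4, 1, n_nodes)

eye = np.broadcast_to(np.eye(n_nodes), s_ij.shape)
out = (eye @ np.swapaxes(src, -2, -1)).squeeze(-1)
print(out.shape)  # (2, 4, 3): batch dims (2, 1) and (4,) broadcast together
```

Since the identity leaves the one-hot vectors unchanged, the only effect is the shape broadcast, which is exactly what scan needs.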

@fehiepsi (Member)

@tbsexton FYI, we just released version 0.3, which supports enumeration over discrete latent variables. I really like this topic and I also have time now, so I will look at your tutorial in more detail (mainly to study it and to make it work with NumPyro). If you have further references for going through the notebook, please let me know - I would greatly appreciate it!

@rtbs-dev (Author) commented Jul 29, 2020

Awesome! I will get this PR onto v0.3 soon.

I took a bit of a break on this after realizing the inference was not passing some basic benchmark tests for recovering network structure correctly. I just finished getting a pure-optimization version working in plain JAX, and I'll get an example notebook live on one of my own repos soon. Happy to link it here once it's up to help with this!
