compute a KL divergence between a Gaussian Mixture model prior and a normal distribution posterior #22

Open
neuronphysics opened this issue Mar 3, 2023 · 0 comments

neuronphysics commented Mar 3, 2023

Hi,

I am trying to compute the KL divergence between a Gaussian mixture model prior and a normal distribution posterior. It is analytically intractable without some approximation, but it can be estimated via Monte Carlo sampling (see the sketch after the model code below). How would you suggest implementing this with your library?

import torch
import torch.nn as nn
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal


class VGMM(nn.Module):
    def __init__(self,
                 u_dim,
                 h_dim,
                 z_dim,
                 n_mixtures,
                 device,
                 batch_norm=False,
                 ):
        super(VGMM, self).__init__()
        self.n_mixtures = n_mixtures
        self.u_dim = u_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.device = device
        self.batch_norm = batch_norm

        # Encoder producing the parameters of the Normal posterior q(z | u).
        encoder_layers = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers += [
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
        ]
        if self.batch_norm:
            encoder_layers.append(nn.BatchNorm1d(self.h_dim))
        encoder_layers.append(nn.ReLU())

        self.enc = nn.Sequential(*encoder_layers)
        self.enc_mean = nn.Linear(self.h_dim, self.z_dim)
        self.enc_logvar = nn.Linear(self.h_dim, self.z_dim)

        self.dist = MixtureSameFamily
        self.comp = Normal
        self.mix = Categorical

        # Prior network producing the parameters of the Gaussian mixture prior p(z | u).
        layers_prior = [nn.Linear(self.u_dim, self.h_dim)]
        if self.batch_norm:
            layers_prior.append(nn.BatchNorm1d(self.h_dim))
        layers_prior.append(nn.ReLU())

        self.prior = nn.Sequential(*layers_prior)

        # One mean/log-variance head per mixture component.
        self.prior_mean = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_logvar = nn.ModuleList(
            [nn.Linear(self.h_dim, self.z_dim) for _ in range(n_mixtures)]
        )
        self.prior_weights = nn.Linear(self.h_dim, n_mixtures)

    def forward(self, u):
        # Posterior parameters (Softplus keeps the predicted log-variance positive).
        encoder_input = self.enc(u)
        enc_mean = self.enc_mean(encoder_input)
        enc_logvar = nn.Softplus()(self.enc_logvar(encoder_input))

        # Mixture prior parameters, stacked along the component dimension.
        prior_input = self.prior(u)
        prior_mean = torch.stack(
            [self.prior_mean[n](prior_input) for n in range(self.n_mixtures)], dim=1
        )
        prior_logvar = torch.stack(
            [self.prior_logvar[n](prior_input) for n in range(self.n_mixtures)], dim=1
        )
        prior_w = self.prior_weights(prior_input)
        prior_sigma = prior_logvar.exp().sqrt()

        prior_dist = self.dist(
            self.mix(logits=prior_w),
            Independent(self.comp(prior_mean, prior_sigma), 1),
        )
        post_dist = self.comp(enc_mean, enc_logvar.exp().sqrt())
        z_t = self.reparametrization(enc_mean, enc_logvar)
        return prior_dist, post_dist, z_t

    def reparametrization(self, mu, log_var):
        # Reparameterization trick: z = mu + std * eps, eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn(std.size(), device=self.device)
        return mu + std * eps + 1e-7
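
For context, here is a minimal sketch of the plain Monte Carlo estimate I have in mind, using only torch.distributions (the model dimensions and the input u below are just illustrative placeholders):

# Hypothetical usage: Monte Carlo estimate of KL(q(z|u) || p(z|u)).
model = VGMM(u_dim=10, h_dim=64, z_dim=8, n_mixtures=4, device="cpu")
u = torch.randn(32, 10)
prior_dist, post_dist, z_t = model(u)

# Wrap the factorized Normal posterior so its event shape matches the mixture's.
post = Independent(post_dist, 1)

n_samples = 16
z = post.rsample((n_samples,))                               # (n_samples, batch, z_dim)
kl_mc = (post.log_prob(z) - prior_dist.log_prob(z)).mean(0)  # (batch,) per-example KL estimate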

How do you suggest I use the library to compute the KL term? Thanks in advance.
