-
Notifications
You must be signed in to change notification settings - Fork 2
/
distributions.py
42 lines (35 loc) · 1.35 KB
/
distributions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import constants as c
import torch
import math
import os
class Gaussian(object):
    """Diagonal Gaussian variational posterior with a softplus scale,
    sampled via the reparameterization trick.

    ``mu`` and ``rho`` are expected to be tensors of equal shape (typically
    ``nn.Parameter`` leaves of a Bayesian layer — TODO confirm against
    callers); the standard deviation is ``sigma = softplus(rho)``, which is
    strictly positive for any real ``rho``.
    """
    def __init__(self, mu, rho):
        super().__init__()
        self.mu = mu    # mean tensor of the Gaussian
        self.rho = rho  # unconstrained scale parameter; sigma = softplus(rho)
        # Standard normal noise source for reparameterized sampling.
        self.normal = torch.distributions.Normal(0, 1)
    @property
    def sigma(self):
        # softplus(rho) == log1p(exp(rho)) mathematically, but torch's
        # softplus is numerically stable: the original log1p(exp(rho))
        # overflows to inf once exp(rho) exceeds the float range.
        return torch.nn.functional.softplus(self.rho)
    def sample(self):
        # Reparameterization trick: only epsilon ~ N(0, 1) is random, so
        # gradients flow through mu and rho.
        epsilon = self.normal.sample(self.rho.size())
        return self.mu + self.sigma * epsilon
    def log_prob(self, input):
        """Return the log-density of ``input``, summed over all elements."""
        return (-math.log(math.sqrt(2 * math.pi))
                - torch.log(self.sigma)
                - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()
class ScaleMixtureGaussian(object):
    """Two-component scale-mixture-of-Gaussians prior (Blundell et al., 2015):

        p(w) = prod_i [ pi * N(w_i; 0, sigma1^2) + (1 - pi) * N(w_i; 0, sigma2^2) ]

    The mixture is applied elementwise (per weight) and the log-densities are
    summed, because the prior factorizes over the weights.
    """
    def __init__(self, pi, sigma1, sigma2):
        super().__init__()
        self.pi = pi          # mixture weight of the first component
        self.sigma1 = sigma1  # std of the first (wide) component
        self.sigma2 = sigma2  # std of the second (narrow) component
        # Use Normal directly so sigma1/sigma2 are the actual standard
        # deviations. Wrapping them in this file's Gaussian class (as before)
        # softplus-transformed them, giving stds of log1p(exp(sigma_k))
        # instead of sigma_k.
        self.gaussian1 = torch.distributions.Normal(0, sigma1)
        self.gaussian2 = torch.distributions.Normal(0, sigma2)
    def log_prob(self, input):
        """Return the summed log-density of ``input`` under the mixture.

        Normal.log_prob is elementwise, so the mixture is taken per weight
        before summing. The previous version summed each component's log-probs
        first, which mixed the *joint* densities — incorrect for this prior.
        """
        prob1 = torch.exp(self.gaussian1.log_prob(input))
        prob2 = torch.exp(self.gaussian2.log_prob(input))
        return (torch.log(self.pi * prob1 + (1 - self.pi) * prob2)).sum()