# vae.py (forked from ethanluoyc/pytorch-vae)
# A minimal variational autoencoder trained on MNIST.
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision
from torch import nn
from torchvision import transforms
import matplotlib.pyplot as plt


class Normal(object):
    """Container for diagonal-Gaussian parameters. Note: unused by the script below."""

    def __init__(self, mu, sigma, log_sigma, v=None, r=None):
        self.mu = mu
        self.sigma = sigma  # either the stdev diagonal itself, or the stdev diagonal from a decomposition
        self.logsigma = log_sigma
        dim = mu.size()  # the original called mu.get_shape(), which is TensorFlow API, not PyTorch
        if v is None:
            v = torch.FloatTensor(*dim)
        if r is None:
            r = torch.FloatTensor(*dim)
        self.v = v
        self.r = r


class Encoder(torch.nn.Module):
    """Two-layer MLP mapping a flattened image to a hidden feature vector."""

    def __init__(self, D_in, H, D_out):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))


class Decoder(torch.nn.Module):
    """Two-layer MLP mapping a latent code back to a flattened image.

    The final ReLU keeps outputs in [0, inf); since ToTensor() scales MNIST
    pixels to [0, 1], a sigmoid output layer is a common variant.
    """

    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))


class VAE(torch.nn.Module):
    latent_dim = 8

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Project the encoder's 100-d features to the mean and log-stdev
        # of an 8-d diagonal Gaussian posterior.
        self._enc_mu = torch.nn.Linear(100, 8)
        self._enc_log_sigma = torch.nn.Linear(100, 8)

    def _sample_latent(self, h_enc):
        """Return a latent sample z ~ N(mu, sigma^2)."""
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.randn_like(sigma)  # unit Gaussian noise; carries no gradient
        self.z_mean = mu
        self.z_sigma = sigma
        return mu + sigma * std_z  # reparameterization trick

    def forward(self, state):
        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)
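

# A quick illustration (not part of the original script): one forward pass wires
# the pieces together, and the posterior statistics are stashed on the module
# for the loss term below. A minimal sketch using the layer sizes assumed above.
def _demo_forward():
    vae = VAE(Encoder(28 * 28, 100, 100), Decoder(8, 100, 28 * 28))
    x = torch.randn(2, 28 * 28)
    recon = vae(x)
    assert recon.shape == (2, 28 * 28)
    assert vae.z_mean.shape == (2, 8) and vae.z_sigma.shape == (2, 8)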


def latent_loss(z_mean, z_stddev):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1),
    averaged over all elements of the batch."""
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
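

# A sanity check (not part of the original script): the closed-form KL above
# should agree with torch.distributions.kl_divergence between a diagonal
# Gaussian and a standard normal; the .mean() mirrors torch.mean in latent_loss.
# A sketch assuming a PyTorch version that ships torch.distributions.
def _check_latent_loss():
    mu = torch.randn(4, 8)
    sigma = torch.rand(4, 8) + 0.1  # keep the stdev strictly positive
    ours = latent_loss(mu, sigma)
    p = torch.distributions.Normal(mu, sigma)
    q = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(sigma))
    reference = torch.distributions.kl_divergence(p, q).mean()
    assert torch.allclose(ours, reference, atol=1e-5)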


if __name__ == '__main__':
    input_dim = 28 * 28
    batch_size = 32

    transform = transforms.Compose([transforms.ToTensor()])
    mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,
                                             shuffle=True, num_workers=2)
    print('Number of samples: ', len(mnist))

    encoder = Encoder(input_dim, 100, 100)
    decoder = Decoder(8, 100, input_dim)
    vae = VAE(encoder, decoder)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(vae.parameters(), lr=0.0001)
    l = None
    for epoch in range(100):
        for i, data in enumerate(dataloader, 0):
            inputs, classes = data
            inputs = inputs.view(-1, input_dim)  # flatten the images; the labels are unused
            optimizer.zero_grad()
            dec = vae(inputs)
            ll = latent_loss(vae.z_mean, vae.z_sigma)
            loss = criterion(dec, inputs) + ll  # reconstruction + KL
            loss.backward()
            optimizer.step()
            l = loss.item()  # the original loss.data[0] errors on PyTorch >= 0.5
        print(epoch, l)

    plt.imshow(vae(inputs).detach()[0].numpy().reshape(28, 28), cmap='gray')
    plt.show(block=True)
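

# A possible extension (not part of the original script): because the prior is
# N(0, I), new digits can be generated by decoding random latent draws. A
# minimal sketch assuming the trained decoder from the loop above.
def sample_digits(decoder, n=4, latent_dim=8):
    with torch.no_grad():
        z = torch.randn(n, latent_dim)     # draw z ~ N(0, I)
        imgs = decoder(z).view(n, 28, 28)  # decode into image space
    for k in range(n):
        plt.subplot(1, n, k + 1)
        plt.imshow(imgs[k].numpy(), cmap='gray')
        plt.axis('off')
    plt.show(block=True)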