# Akanimax's code from https://github.com/akanimax/T2F
"""
----Currently not used----
Module implementing Conditioning Augmentation
"""
import torch
from torch.nn.functional import relu


class ConditionAugmentor(torch.nn.Module):
    """ Performs conditioning augmentation
        from the paper -> https://arxiv.org/abs/1710.10916 (StackGAN++),
        using the reparameterization trick from the VAE paper
        (https://arxiv.org/abs/1312.6114).
    """

    def __init__(self, input_size, latent_size, use_eql=True, device=torch.device("cpu")):
        """
        constructor of the class
        :param input_size: input size to the augmentor
        :param latent_size: required output size
        :param use_eql: boolean for whether to use equalized learning rate
        :param device: device on which to run the Module
        """
        super(ConditionAugmentor, self).__init__()

        assert latent_size % 2 == 0, "Latent manifold has an odd number of dimensions"

        # state of the object
        self.device = device
        self.input_size = input_size
        self.latent_size = latent_size

        # required modules:
        # a single linear layer predicts both the means and the stds (2 * latent_size)
        if use_eql:
            from pro_gan_pytorch.CustomLayers import _equalized_linear
            self.transformer = _equalized_linear(self.input_size, 2 * self.latent_size).to(device)
        else:
            self.transformer = torch.nn.Linear(self.input_size, 2 * self.latent_size).to(device)

    def forward(self, x, epsilon=1e-12):
        """
        forward pass (computations)
        :param x: input text embeddings
        :param epsilon: a small constant added to the stds for numerical stability
        :return: c_not_hat, mus, sigmas => augmented text embeddings, means, stds
        """
        # apply the feed-forward layer:
        combined = self.transformer(x)

        # split the output into means and stds for the reparameterization trick
        mid_point = self.latent_size
        mus, sigmas = combined[:, :mid_point], combined[:, mid_point:]

        # mus don't need to be transformed, but sigmas cannot be negative,
        # so apply a ReLU on top of sigmas
        sigmas = relu(sigmas)  # the network should learn a good sigma mapping
        sigmas = sigmas + epsilon  # small constant added for stability

        # sample the augmented conditioning code: c_not_hat = mu + sigma * noise
        noise = torch.randn(*mus.shape).to(self.device)
        c_not_hat = (noise * sigmas) + mus

        return c_not_hat, mus, sigmas
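

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original T2F module). The embedding,
# latent, and batch sizes below are arbitrary illustrative values, and
# `kl_regularizer` is a hypothetical helper: StackGAN++ trains conditioning
# augmentation alongside a KL(N(mu, sigma^2) || N(0, I)) penalty, but no such
# loss term is implemented in this file.
# ---------------------------------------------------------------------------
def kl_regularizer(mus, sigmas):
    """ hypothetical KL(N(mu, sigma^2) || N(0, I)) penalty for diagonal
        Gaussians, averaged over the batch; assumes `sigmas` are standard
        deviations (as returned by ConditionAugmentor), not log-variances """
    kl = 0.5 * (mus ** 2 + sigmas ** 2 - 2 * torch.log(sigmas) - 1)
    return kl.sum(dim=-1).mean()


if __name__ == "__main__":
    # smoke test: 4 fake text embeddings of size 256 -> 128-D augmented codes
    augmentor = ConditionAugmentor(input_size=256, latent_size=128,
                                   use_eql=False)  # skip the pro_gan_pytorch dependency
    embeddings = torch.randn(4, 256)
    c_not_hat, mus, sigmas = augmentor(embeddings)
    print(c_not_hat.shape)                     # torch.Size([4, 128])
    print(kl_regularizer(mus, sigmas).item())  # scalar regularization loss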