-
Notifications
You must be signed in to change notification settings - Fork 7
/
GPDM.py
141 lines (114 loc) · 5.12 KB
/
GPDM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
from math import sqrt
from torchvision.utils import save_image
from tqdm import tqdm
import torch
from torchvision.transforms import Resize as tv_resize
from utils import plot_loss, load_image
def generate(reference_images,
             criteria,
             init_from = 'zeros',
             pyramid_scales=(32, 64, 128, 256),
             lr: float = 0.01,
             num_steps: int = 300,
             aspect_ratio=(1, 1),
             num_outputs=1,
             additive_noise_sigma=0.0,
             device: str = 'cuda:1',
             debug_dir=None):
    """
    Run the GPDM model to generate image(s) with a patch distribution similar to
    reference_images under the given criteria, optimizing coarse-to-fine over the pyramid.

    :param reference_images: tensor of shape (b, C, H, W) with the target images
    :param criteria: patch-distribution loss module; moved to `device` before use
    :param init_from: 'zeros' | 'mean' | 'target' | path to an image file (see get_fist_initial_guess)
    :param pyramid_scales: iterable of pyramid level sizes, coarse to fine
    :param lr: Adam learning rate used at every level
    :param num_steps: optimization steps per pyramid level
    :param aspect_ratio: (h, w) multipliers applied to each level's output shape
    :param num_outputs: number of images to synthesize (a singleton init is tiled)
    :param additive_noise_sigma: std of Gaussian noise added to the initial guess (0 disables)
    :param device: torch device string the optimization runs on
    :param debug_dir: if set, intermediate images and the loss curve are saved there
    :return: (synthesized images at the finest level, references resized to the finest level)
    """
    if debug_dir:
        os.makedirs(debug_dir, exist_ok=True)
    criteria = criteria.to(device)
    reference_images = reference_images.to(device)
    synthesized_images = get_fist_initial_guess(reference_images, init_from, additive_noise_sigma).to(device)
    synthesized_images = ensure_size(synthesized_images, num_outputs)
    original_image_shape = synthesized_images.shape[-2:]
    print(f"Matching the patches of {len(synthesized_images)} generated images to the patches of {len(reference_images)} reference images")
    pbar = GPDMLogger(num_steps, len(pyramid_scales))
    if debug_dir:
        nrow = int(sqrt(len(synthesized_images)))
        # BUGFIX: was f'init.png' -- an f-string with no placeholders
        save_image(synthesized_images, os.path.join(debug_dir, 'init.png'), normalize=True, nrow=nrow)
    all_losses = []
    # BUGFIX: ensure the returned lvl_references is defined even if pyramid_scales is empty
    lvl_references = reference_images
    for scale in pyramid_scales:
        pbar.new_lvl()
        # Bring both the references and the current synthesis to this level's resolution
        lvl_references = tv_resize(scale, antialias=True)(reference_images)
        lvl_output_shape = get_output_shape(original_image_shape, scale, aspect_ratio)
        synthesized_images = tv_resize(lvl_output_shape, antialias=True)(synthesized_images)
        synthesized_images, losses = _match_patch_distributions(synthesized_images, lvl_references, criteria, num_steps, lr, pbar)
        all_losses += losses
        if debug_dir:
            save_image(lvl_references, os.path.join(debug_dir, f'references-lvl-{pbar.lvl}.png'), normalize=True, nrow=nrow)
            save_image(synthesized_images, os.path.join(debug_dir, f'outputs-lvl-{pbar.lvl}.png'), normalize=True, nrow=nrow)
            plot_loss(all_losses, os.path.join(debug_dir, 'train_losses.png'))
    pbar.pbar.close()
    return synthesized_images, lvl_references
def _match_patch_distributions(synthesized_images, reference_images, criteria, num_steps, lr, pbar):
"""
Minimizes criteria(synthesized_images, reference_images) for num_steps SGD steps by differentiating self.synthesized_images
:param reference_images: tensor of shape (b, C, H1, W1)
:param synthesized_images: tensor of shape (b, C, H2, W2)
:param debug_dir:
"""
synthesized_images.requires_grad_(True)
optim = torch.optim.Adam([synthesized_images], lr=lr)
losses = []
for i in range(num_steps):
# Optimize image
optim.zero_grad()
loss = criteria(synthesized_images, reference_images)
loss.backward()
optim.step()
# Update staus
losses.append(loss.item())
pbar.step()
pbar.print()
return torch.clip(synthesized_images.detach(), -1, 1), losses
class GPDMLogger:
    """Tracks pyramid levels and per-level optimization steps, reporting progress via TQDM."""

    def __init__(self, n_steps, n_lvls):
        self.n_steps = n_steps   # steps per pyramid level
        self.n_lvls = n_lvls     # total number of pyramid levels
        self.lvl = -1            # current level index; new_lvl() must be called before stepping
        self.lvl_step = 0        # steps taken within the current level
        self.steps = 0           # total steps across all levels
        self.pbar = tqdm(total=self.n_lvls * self.n_steps, desc='Starting')

    def step(self):
        """Advance one optimization step: bump the bar and both counters."""
        self.pbar.update(1)
        self.steps += 1
        self.lvl_step += 1

    def new_lvl(self):
        """Move to the next pyramid level and reset the per-level counter."""
        self.lvl += 1
        self.lvl_step = 0

    def print(self):
        """Refresh the progress-bar description with the current level/step."""
        self.pbar.set_description(f'Lvl {self.lvl}/{self.n_lvls - 1}, step {self.lvl_step}/{self.n_steps}')
def get_fist_initial_guess(reference_images, init_from, additive_noise_sigma):
    """
    Build the initial synthesized batch for the coarsest pyramid level.

    :param reference_images: tensor of shape (b, C, H, W)
    :param init_from: 'zeros' | 'mean' | 'target' | path to an image file on disk
    :param additive_noise_sigma: std of Gaussian noise added to the guess (0 disables)
    :return: tensor with the initial guess (caller is responsible for moving it to device)
    :raises ValueError: if init_from is none of the modes above and not an existing path
    """
    if init_from == "zeros":
        # Single all-zero image with the references' channel/spatial dims
        result = torch.zeros(1, *reference_images.shape[1:])
    elif init_from == "mean":
        # Pixel-wise average over the reference batch
        result = torch.mean(reference_images, dim=0, keepdim=True)
    elif init_from == "target":
        # Start from a heavily blurred copy of the references themselves
        result = reference_images.clone()
        import torchvision
        result = torchvision.transforms.GaussianBlur(7, sigma=7)(result)
    elif os.path.exists(init_from):
        # Treat init_from as a path to an image on disk
        result = load_image(init_from)
    else:
        raise ValueError("Bad init mode", init_from)
    if additive_noise_sigma:
        result = result + torch.randn_like(result) * additive_noise_sigma
    return result
def ensure_size(batch, num_outputs):
    """Tile a singleton batch up to num_outputs copies; batches with more than one image pass through unchanged."""
    needs_tiling = num_outputs > 1 and batch.shape[0] == 1
    return batch.repeat(num_outputs, 1, 1, 1) if needs_tiling else batch
def get_output_shape(initial_image_shape, size, aspect_ratio):
    """
    Compute the (h, w) of a pyramid level: height is `size` scaled by aspect_ratio[0];
    width keeps the original h:w proportion at that size, scaled by aspect_ratio[1].
    """
    orig_h, orig_w = initial_image_shape
    new_h = int(size * aspect_ratio[0])
    new_w = int((orig_w * size / orig_h) * aspect_ratio[1])
    return new_h, new_w