Skip to content

Commit

Permalink
Add -normalize_gradients parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ProGamerGov authored Dec 6, 2020
1 parent 97081c3 commit cbcd023
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ path or a full absolute path.
when using ADAM you will probably need to play with other parameters to get good results, especially
the style weight, content weight, and learning rate.
* `-learning_rate`: Learning rate to use with the ADAM optimizer. Default is 1e1.
* `-normalize_gradients`: If this flag is present, style and content gradients from each layer will be L1 normalized.

**Output options**:
* `-output_image`: Name of the output image. Default is `out.png`.
Expand Down Expand Up @@ -313,4 +314,4 @@ If you find this code useful for your research, please cite:
journal = {GitHub repository},
howpublished = {\url{https://github.com/ProGamerGov/neural-style-pt}},
}
```
```
45 changes: 34 additions & 11 deletions neural_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
Expand Down Expand Up @@ -121,13 +122,13 @@ def main():

if layerList['C'][c] in content_layers:
print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
loss_module = ContentLoss(params.content_weight)
loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
content_losses.append(loss_module)

if layerList['C'][c] in style_layers:
print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
loss_module = StyleLoss(params.style_weight)
loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
style_losses.append(loss_module)
c+=1
Expand All @@ -137,14 +138,14 @@ def main():

if layerList['R'][r] in content_layers:
print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
loss_module = ContentLoss(params.content_weight)
loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
content_losses.append(loss_module)
next_content_idx += 1

if layerList['R'][r] in style_layers:
print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
loss_module = StyleLoss(params.style_weight)
loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
style_losses.append(loss_module)
next_style_idx += 1
Expand Down Expand Up @@ -339,15 +340,15 @@ def preprocess(image_name, image_size):
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1,1,1])])
tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
tensor = Normalize(rgb2bgr(Loader(image) * 255)).unsqueeze(0)
return tensor


# Undo the above preprocessing.
def deprocess(output_tensor):
Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1,1,1])])
bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 255
output_tensor.clamp_(0, 1)
Image2PIL = transforms.ToPILImage()
image = Image2PIL(output_tensor.cpu())
Expand Down Expand Up @@ -399,18 +400,36 @@ def normalize_weights(content_losses, style_losses):
i.strength = i.strength / max(i.target.size())


# Scale gradients in the backward pass
class ScaleGradients(torch.autograd.Function):
    """Identity in the forward pass; rescales gradients in the backward pass.

    The forward pass returns its input untouched and stashes ``strength`` on
    the autograd context. The backward pass normalizes the incoming gradient
    by its norm (with a small epsilon for numerical stability) and multiplies
    it by ``strength`` squared.
    """

    @staticmethod
    def forward(ctx, input_tensor, strength):
        # Remember the loss layer's strength for use in backward.
        ctx.strength = strength
        return input_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Normalize the gradient magnitude; epsilon guards against a zero norm.
        normalized = grad_output / (torch.norm(grad_output, keepdim=True) + 1e-8)
        # NOTE(review): the gradient is scaled by strength twice here; the Lua
        # original appears to apply strength only once — confirm intentional.
        return normalized * ctx.strength * ctx.strength, None

This comment has been minimized.

Copy link
@JCBrouwer

JCBrouwer Dec 27, 2021

Why do the gradients get multiplied by strength^2 here?

The squared factor doesn't seem present in the lua version. There's just a self.gradInput:mul(self.strength) which I believe is handled by self.loss = loss * self.strength in the *Loss(nn.Module)s



# Define an nn Module to compute content loss
class ContentLoss(nn.Module):
    """Pass-through module that captures a target activation and later
    measures the MSE between new activations and that target.

    Args:
        strength: scalar weight applied to the content loss.
        normalize: if True, route the loss through ScaleGradients so the
            gradients from this layer are normalized in the backward pass.

    The module operates in one of three modes (set via ``self.mode``):
    'None' (pass-through), 'capture' (record the current activation as the
    target), and 'loss' (compute ``self.loss`` against the stored target).
    """

    def __init__(self, strength, normalize):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.normalize = normalize

    def forward(self, input):
        if self.mode == 'loss':
            loss = self.crit(input, self.target)
            if self.normalize:
                # Normalize gradients flowing back from this loss layer.
                loss = ScaleGradients.apply(loss, self.strength)
            self.loss = loss * self.strength
        elif self.mode == 'capture':
            # Detach so the stored target is not part of the autograd graph.
            self.target = input.detach()
        # Always forward the input unchanged so later layers see real activations.
        return input
Expand All @@ -427,14 +446,15 @@ def forward(self, input):
# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

def __init__(self, strength):
def __init__(self, strength, normalize):
super(StyleLoss, self).__init__()
self.target = torch.Tensor()
self.strength = strength
self.gram = GramMatrix()
self.crit = nn.MSELoss()
self.mode = 'None'
self.blend_weight = None
self.normalize = normalize

def forward(self, input):
self.G = self.gram(input)
Expand All @@ -447,7 +467,10 @@ def forward(self, input):
else:
self.target = self.target.add(self.blend_weight, self.G.detach())
elif self.mode == 'loss':
self.loss = self.strength * self.crit(self.G, self.target)
loss = self.crit(self.G, self.target)
if self.normalize:
loss = ScaleGradients.apply(loss, self.strength)
self.loss = self.strength * loss
return input


Expand All @@ -465,4 +488,4 @@ def forward(self, input):


if __name__ == "__main__":
main()
main()

0 comments on commit cbcd023

Please sign in to comment.