criterion.py
"""Loss function and gradient penalty for MuseGAN."""
from torch import Tensor
import torch
from torch import nn
class WassersteinLoss(nn.Module):
"""WassersteinLoss."""
def __init__(self) -> None:
"""Initialize."""
super().__init__()
def forward(self, y_pred: Tensor, y_target: Tensor) -> Tensor:
"""Calculate Wasserstein loss.
Parameters
----------
y_pred: Tensor
Prediction.
y_target: Tensor
Target.
Returns
-------
Tensor:
Loss value.
"""
loss = - torch.mean(y_pred * y_target)
return loss
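

# Worked example of the loss arithmetic (illustrative, not in the original
# file): scoring real_pred against a target of +1 and fake_pred against a
# target of -1, the two terms sum to
#
#     -mean(real_pred) + mean(fake_pred),
#
# i.e. the usual WGAN critic objective to be minimized.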


class GradientPenalty(nn.Module):
    """Gradient penalty (WGAN-GP).

    Penalizes deviations of the critic's per-sample gradient norm from 1,
    evaluated at the given inputs.
    """

    def __init__(self) -> None:
        """Initialize."""
        super().__init__()

    def forward(self, inputs: Tensor, outputs: Tensor) -> Tensor:
        """Calculate the gradient penalty.

        Parameters
        ----------
        inputs : Tensor
            Inputs with respect to which the gradient is taken; must have
            ``requires_grad=True``.
        outputs : Tensor
            Critic outputs computed from ``inputs``.

        Returns
        -------
        Tensor
            Penalty value.
        """
        # Gradient of the critic outputs with respect to the inputs.
        grad = torch.autograd.grad(
            outputs=outputs,
            inputs=inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True,  # keep the graph so the penalty is differentiable
            retain_graph=True,
        )[0]
        # Per-sample L2 norm of the flattened gradient.
        grad_norm = torch.norm(grad.view(grad.size(0), -1), p=2, dim=1)
        # Mean squared deviation of the gradient norm from 1.
        return torch.mean((1.0 - grad_norm) ** 2)
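

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: a single WGAN-GP
    # critic evaluation on random data. The toy linear critic, tensor
    # shapes, and penalty weight (10.0) are illustrative assumptions, not
    # MuseGAN's actual training configuration.
    torch.manual_seed(0)
    critic = nn.Linear(16, 1)
    w_loss = WassersteinLoss()
    gradient_penalty = GradientPenalty()

    real = torch.randn(8, 16)
    fake = torch.randn(8, 16)

    # Interpolate between real and fake samples; the penalty pushes the
    # critic's gradient norm toward 1 along these interpolations.
    alpha = torch.rand(real.size(0), 1)
    interp = (alpha * real + (1.0 - alpha) * fake).requires_grad_(True)

    real_pred = critic(real)
    fake_pred = critic(fake)
    interp_pred = critic(interp)

    # Targets of +1 (real) and -1 (fake) recover the usual critic
    # objective: mean(fake_pred) - mean(real_pred).
    loss = (
        w_loss(real_pred, torch.ones_like(real_pred))
        + w_loss(fake_pred, -torch.ones_like(fake_pred))
        + 10.0 * gradient_penalty(interp, interp_pred)
    )
    print(f"critic loss: {loss.item():.4f}")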