Skip to content

Commit

Permalink
incorporate pr maum-ai#18 [WIP] implement power low compression loss
Browse files Browse the repository at this point in the history
  • Loading branch information
kwikwag committed Mar 2, 2020
1 parent 3d70627 commit 251018b
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,20 @@ def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp,

# output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
# target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
loss = criterion(output, target_mag)
if hp.train.complex_loss_ratio > 0:
# Power-law compression
magnitude_loss = criterion(
torch.pow(torch.abs(output), hp.audio.power),
torch.pow(torch.abs(target_mag), hp.audio.power),
)
complex_loss = criterion(
torch.pow(torch.clamp(output, min=0.0), hp.audio.power),
torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power),
)
loss = magnitude_loss + complex_loss * hp.train.complex_loss_ratio

else:
loss = criterion(output, target_mag)

optimizer.zero_grad()
loss.backward()
Expand Down

0 comments on commit 251018b

Please sign in to comment.