From 251018bdb93ca0907006d6df11c30cd53512292c Mon Sep 17 00:00:00 2001 From: Yuval Date: Mon, 2 Mar 2020 19:19:23 +0000 Subject: [PATCH] incorporate pr #18 [WIP] implement power low compression loss --- utils/train.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/utils/train.py b/utils/train.py index 1fd8d94..531dc1e 100644 --- a/utils/train.py +++ b/utils/train.py @@ -66,7 +66,20 @@ def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp, # output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power) # target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power) - loss = criterion(output, target_mag) + if hp.train.complex_loss_ratio > 0: + # Power-law compression + magnitude_loss = criterion( + torch.pow(torch.abs(output), hp.audio.power), + torch.pow(torch.abs(target_mag), hp.audio.power), + ) + complex_loss = criterion( + torch.pow(torch.clamp(output, min=0.0), hp.audio.power), + torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power), + ) + loss = magnitude_loss + complex_loss * hp.train.complex_loss_ratio + + else: + loss = criterion(output, target_mag) optimizer.zero_grad() loss.backward()