Skip to content

Commit

Permalink
fix: 🐛 outlier handling of WAPE
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Dec 10, 2023
1 parent 37e62e8 commit 0217cb5
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion basicts/metrics/wape.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ def masked_wape(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.
mask = ~torch.isclose(labels, torch.tensor(null_val).expand_as(labels).to(labels.device), atol=eps, rtol=0.)
mask = mask.float()
preds, labels = preds * mask, labels * mask
loss = torch.sum(torch.abs(preds-labels)) / torch.sum(torch.abs(labels))
loss = torch.sum(torch.abs(preds-labels)) / (torch.sum(torch.abs(labels)) + 5e-5)
return torch.mean(loss)

0 comments on commit 0217cb5

Please sign in to comment.