-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
82 lines (69 loc) · 2.92 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import numpy as np
# def collect_outputs(outputs, multi_gpu):
# log_dict = {}
# for loss_type in outputs[0]:
# if multi_gpu:
# collect = []
# for output in outputs:
# for v in output[loss_type]:
# if v == v:
# collect.append(v)
# else:
# collect = [v[loss_type] for v in outputs if v[loss_type] == v[loss_type]]
# if collect:
# log_dict[loss_type] = torch.stack(collect).mean()
# else:
# log_dict[loss_type] = float('nan')
# return log_dict
def define_loss_fn(lat2d):
    """Build a latitude-weighted MSE loss from a 2-d latitude grid.

    Args:
        lat2d: 2-d array of latitudes in degrees (one value per grid cell).

    Returns:
        Tuple ``(weights_lat, loss)`` where ``weights_lat`` is the
        mean-normalized cosine-latitude weighting and ``loss(pred, truth)``
        computes the latitude-weighted MSE using those weights.
    """
    weights_lat = compute_latitude_weighting(lat2d)

    # Named closure instead of a lambda bound to a name (PEP 8 / E731);
    # same call signature and behavior as the original lambda.
    def loss(pred, truth):
        return compute_weighted_mse(pred, truth, weights_lat)

    return weights_lat, loss
def compute_latitude_weighting(lat):
    """Return cosine-of-latitude weights normalized to a mean of 1.

    Args:
        lat: Array of latitudes in degrees.

    Returns:
        Array of the same shape holding ``cos(lat) / mean(cos(lat))``.
    """
    raw = np.cos(np.deg2rad(lat))
    return raw / raw.mean()
def compute_weighted_mse(pred, truth, weights_lat, flat_weights=False):
    """Compute the mean squared error with latitude weighting.

    Args:
        pred: Forecast. Torch tensor.
        truth: Truth. Torch tensor, same shape as pred.
        weights_lat: Latitude weighting; 2-d array or tensor whose shape is
            broadcastable to truth's trailing dimensions (unless
            flat_weights is True, in which case it is multiplied as-is).
        flat_weights: If True, weights_lat is assumed to already match
            truth and is used directly; if False it is converted to
            truth's dtype/device and expanded to truth's shape.

    Returns:
        Scalar tensor: the latitude-weighted mean squared error.
        (The original docstring called this "rmse"; no square root is
        taken, so it is an MSE.)
    """
    if not flat_weights:
        # torch.as_tensor replaces the deprecated Tensor.new(...) API while
        # still matching truth's dtype and device before broadcasting.
        weights_lat = torch.as_tensor(
            weights_lat, dtype=truth.dtype, device=truth.device
        ).expand_as(truth)
    error = (pred - truth) ** 2
    return (error * weights_lat).mean()
# def eval_loss(pred, output, lts, loss_function, possible_lead_times, phase='val', target_v=None, normalizer=None):
# results = {}
# # Unpick which of the batch samples contain which lead_time
# lead_time_dist = {t: lts == t for t in possible_lead_times}
# results[f'{phase}_loss'] = loss_function(pred, output)
# # Caclulate loss per lead_time
# for t, cond in lead_time_dist.items():
# if any(cond):
# results[f'{phase}_loss_{t}hrs'] = loss_function(pred[cond], output[cond])
# else:
# results[f'{phase}_loss_{t}hrs'] = pred.new([float('nan')])[0]
# # Undo normalization
# if normalizer:
# scaled_pred_v = (torch.exp(pred[:, 0, :, :]) - 1 ) * normalizer[target_v]['std']
# scaled_output_v = (torch.exp(output[:, 0, :, :]) - 1) * normalizer[target_v]['std']
# results[f'{phase}_loss_' + target_v] = loss_function(scaled_pred_v, scaled_output_v)
# # Caclulate loss per lead_time
# for t, cond in lead_time_dist.items():
# if any(cond):
# results[f'{phase}_loss_{target_v}_{t}hrs'] = loss_function(scaled_pred_v[cond], scaled_output_v[cond])
# else:
# results[f'{phase}_loss_{target_v}_{t}hrs'] = scaled_pred_v.new([float('nan')])[0]
# return results
# def convert_precip_to_mm(output, target_v, normalizer):
# converted = (np.exp(output) - 1) * normalizer[target_v]['std']
# if target_v == 'tp':
# converted *= 1e3
# return converted