generated from reichlab/reichlab-python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from reichlab/energy_score
add energy score metric, reorganize tests
- Loading branch information
Showing
12 changed files
with
181 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,17 @@ | ||
# This file was autogenerated by uv via the following command: | ||
# uv pip compile pyproject.toml -o requirements/requirements.txt | ||
joblib==1.4.2 | ||
# via scikit-learn | ||
numpy==2.1.2 | ||
# via postpredict (pyproject.toml) | ||
# via | ||
# postpredict (pyproject.toml) | ||
# scikit-learn | ||
# scipy | ||
polars==1.9.0 | ||
# via postpredict (pyproject.toml) | ||
scikit-learn==1.5.2 | ||
# via postpredict (pyproject.toml) | ||
scipy==1.14.1 | ||
# via scikit-learn | ||
threadpoolctl==3.5.0 | ||
# via scikit-learn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import numpy as np | ||
import polars as pl | ||
from sklearn.metrics import pairwise_distances | ||
|
||
|
||
def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
                 key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str],
                 reduce_mean: bool = True) -> float | pl.DataFrame:
    """
    Compute the energy score for a collection of predictive samples.

    Parameters
    ----------
    model_out_wide: pl.DataFrame
        DataFrame of model outputs where each row corresponds to one
        (multivariate) sample from a multivariate distribution for one
        observational unit.
    obs_data_wide: pl.DataFrame
        DataFrame of observed values where each row corresponds to one
        (multivariate) observed outcome for one observational unit.
    key_cols: list[str]
        Columns that appear in both `model_out_wide` and `obs_data_wide` that
        identify observational units. Must be a non-empty list; `None` is
        rejected with a `ValueError`.
    pred_cols: list[str]
        Columns that appear in `model_out_wide` and identify predicted (sampled)
        values. The order of these should match the order of `obs_cols`.
    obs_cols: list[str]
        Columns that appear in `obs_data_wide` and identify observed values. The
        order of these should match the order of `pred_cols`.
    reduce_mean: bool = True
        Indicator of whether to return a numeric mean energy score (default) or
        a pl.DataFrame with one row per observational unit.

    Returns
    -------
    Either the mean energy score across all observational units (default) or a
    pl.DataFrame with one row per observational unit and scores stored in a
    column named `energy_score`. Rows of the DataFrame follow the order in
    which observational units first appear in the joined data.

    Raises
    ------
    ValueError
        If `key_cols` is None or empty, or if `pred_cols` and `obs_cols` do not
        have the same length.

    Notes
    -----
    We perform the energy score calculation of Eq. (7), p. 223 in
    Gneiting, T., Stanberry, L.I., Grimit, E.P. et al. Assessing probabilistic
    forecasts of multivariate quantities, with an application to ensemble
    predictions of surface winds. TEST 17, 211–235 (2008).
    https://doi.org/10.1007/s11749-008-0114-x
    https://link.springer.com/article/10.1007/s11749-008-0114-x
    """
    # Fail early with a clear message instead of an opaque error from the
    # join / group_by calls below.
    if not key_cols:
        raise ValueError("key_cols must be a non-empty list of column names")
    if len(pred_cols) != len(obs_cols):
        raise ValueError(
            "pred_cols and obs_cols must have the same length, got "
            f"{len(pred_cols)} and {len(obs_cols)}"
        )

    def energy_score_one_unit(df: pl.DataFrame) -> pl.DataFrame:
        """
        Compute energy score for one observational unit based on a collection of
        samples. Note, we define this function here so that key_cols, pred_cols
        and obs_cols are in scope.

        See the Notes section of `energy_score` for the formula reference.
        """
        samples = df.select(pred_cols).to_numpy()
        # After the join the observed values are repeated on every sample row
        # of the unit, so the first row is sufficient.
        observed = df.select(obs_cols).head(1).to_numpy()
        score = (
            np.mean(pairwise_distances(samples, observed))
            - 0.5 * np.mean(pairwise_distances(samples))
        )
        return (
            df.select(key_cols)
            .head(1)
            .with_columns(energy_score=pl.lit(float(score)))
        )

    scores_by_unit = (
        model_out_wide
        .join(obs_data_wide, on=key_cols)
        # maintain_order=True keeps the output row order deterministic;
        # polars does not otherwise guarantee the order of groups.
        .group_by(*key_cols, maintain_order=True)
        .map_groups(energy_score_one_unit)
    )

    if not reduce_mean:
        return scores_by_unit

    return scores_by_unit["energy_score"].mean()
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Tests for postpredict.metrics.energy_score | ||
|
||
from datetime import datetime | ||
|
||
import numpy as np | ||
import polars as pl | ||
import pytest | ||
from polars.testing import assert_frame_equal | ||
from postpredict.metrics import energy_score | ||
|
||
|
||
def test_energy_score():
    """
    Check `energy_score` against reference values from R's scoringRules
    package, both per observational unit and averaged across units.
    """
    model_out_wide = pl.concat([
        pl.DataFrame({
            "location": "a",
            "date": datetime.strptime("2024-10-01", "%Y-%m-%d"),
            "output_type": "sample",
            "output_type_id": [1.0, 2.0, 3.0, 4.0],
            "horizon1": [5.0, 7.7, 18.0, 10.0],
            "horizon2": [3.0, 4.0, 10.0, 6.0],
            "horizon3": [4.4, 1.0, 12.0, 9.0]
        }),
        pl.DataFrame({
            "location": "b",
            "date": datetime.strptime("2024-10-08", "%Y-%m-%d"),
            "output_type": "sample",
            "output_type_id": [5.0, 6.0, 7.0, 8.0],
            "horizon1": [6.0, 4.0, 5.0, 2.0],
            "horizon2": [12.0, 0.0, 15.0, 6.0],
            "horizon3": [16.6, 21.0, 32.0, -1.0]
        })
    ])
    obs_data_wide = pl.DataFrame({
        "location": ["a", "a", "b", "b"],
        "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d"),
                 datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d")],
        "value": [3.0, 4.0, 0.0, 7.2],
        "value_lead1": [4.0, 10.0, 7.2, 9.6],
        "value_lead2": [10.0, 5.0, 9.6, 10.0],
        "value_lead3": [5.0, 2.0, 10.0, 14.1]
    })

    # expected scores calculated in R using the scoringRules package:
    # library(scoringRules)
    # X <- matrix(
    #     data = c(5.0, 7.7, 18.0, 10.0, 3.0, 4.0, 10.0, 6.0, 4.4, 1.0, 12.0, 9.0),
    #     nrow = 3, ncol = 4,
    #     byrow = TRUE
    # )
    # y <- c(4.0, 10.0, 5.0)
    # print(es_sample(y, X), digits = 20)
    # X <- matrix(
    #     data = c(6.0, 4.0, 5.0, 2.0, 12.0, 0.0, 15.0, 6.0, 16.6, 21.0, 32.0, -1.0),
    #     nrow = 3, ncol = 4,
    #     byrow = TRUE
    # )
    # y <- c(9.6, 10.0, 14.1)
    # print(es_sample(y, X), digits = 20)
    expected_scores = [5.8560677725938221627, 5.9574451598773787708]
    expected_scores_df = pl.DataFrame({
        "location": ["a", "b"],
        "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d")],
        "energy_score": expected_scores
    })

    # Arguments shared by both calls; only `reduce_mean` differs.
    common_args = {
        "model_out_wide": model_out_wide,
        "obs_data_wide": obs_data_wide,
        "key_cols": ["location", "date"],
        "pred_cols": ["horizon1", "horizon2", "horizon3"],
        "obs_cols": ["value_lead1", "value_lead2", "value_lead3"],
    }

    actual_scores_df = energy_score(**common_args, reduce_mean=False)

    # The order of per-unit rows is not part of the contract being tested, so
    # compare without regard to row order. An absolute tolerance of 1e-19 is
    # below float64 resolution for values of this magnitude, so use an explicit
    # relative tolerance (the same as pytest.approx's default) instead.
    assert_frame_equal(actual_scores_df, expected_scores_df,
                       check_row_order=False, rtol=1e-6, atol=0.0)

    expected_mean_score = np.mean(expected_scores)
    actual_mean_score = energy_score(**common_args, reduce_mean=True)
    assert actual_mean_score == pytest.approx(expected_mean_score)