Skip to content

Commit

Permalink
Merge pull request #7 from reichlab/energy_score
Browse files Browse the repository at this point in the history
add energy score metric, reorganize tests
  • Loading branch information
elray1 authored Oct 17, 2024
2 parents 95ea603 + b148118 commit 9da50db
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 6 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/pythonapp-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ jobs:
- name: lint
run: |
ruff check .
- name: type check
run: |
mypy .
- name: run tests
run: |
pytest
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ dynamic = ["version"]

dependencies = [
"numpy",
"polars"
"polars",
"scikit-learn"
]

[project.optional-dependencies]
Expand Down
13 changes: 12 additions & 1 deletion requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@ identify==2.5.36
# via pre-commit
iniconfig==2.0.0
# via pytest
joblib==1.4.2
# via scikit-learn
mypy==1.10.0
# via postpredict (pyproject.toml)
mypy-extensions==1.0.0
# via mypy
nodeenv==1.8.0
# via pre-commit
numpy==2.1.2
# via postpredict (pyproject.toml)
# via
# postpredict (pyproject.toml)
# scikit-learn
# scipy
packaging==24.0
# via pytest
platformdirs==4.2.1
Expand All @@ -40,8 +45,14 @@ pyyaml==6.0.1
# via pre-commit
ruff==0.4.3
# via postpredict (pyproject.toml)
scikit-learn==1.5.2
# via postpredict (pyproject.toml)
scipy==1.14.1
# via scikit-learn
setuptools==75.1.0
# via nodeenv
threadpoolctl==3.5.0
# via scikit-learn
toml==0.10.2
# via postpredict (pyproject.toml)
types-toml==0.10.8.20240310
Expand Down
13 changes: 12 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements/requirements.txt
joblib==1.4.2
# via scikit-learn
numpy==2.1.2
# via postpredict (pyproject.toml)
# via
# postpredict (pyproject.toml)
# scikit-learn
# scipy
polars==1.9.0
# via postpredict (pyproject.toml)
scikit-learn==1.5.2
# via postpredict (pyproject.toml)
scipy==1.14.1
# via scikit-learn
threadpoolctl==3.5.0
# via scikit-learn
71 changes: 71 additions & 0 deletions src/postpredict/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import polars as pl
from sklearn.metrics import pairwise_distances


def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
                 key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str],
                 reduce_mean: bool = True) -> float | pl.DataFrame:
    """
    Compute the energy score for a collection of predictive samples.

    Parameters
    ----------
    model_out_wide: pl.DataFrame
        DataFrame of model outputs where each row corresponds to one
        (multivariate) sample from a multivariate distribution for one
        observational unit.
    obs_data_wide: pl.DataFrame
        DataFrame of observed values where each row corresponds to one
        (multivariate) observed outcome for one observational unit.
    key_cols: list[str]
        Columns that appear in both `model_out_wide` and `obs_data_wide` that
        identify observational units. Must not be None: the join and the
        grouping below both require key columns.
    pred_cols: list[str]
        Columns that appear in `model_out_wide` and identify predicted (sampled)
        values. The order of these should match the order of `obs_cols`.
    obs_cols: list[str]
        Columns that appear in `obs_data_wide` and identify observed values. The
        order of these should match the order of `pred_cols`.
    reduce_mean: bool = True
        Indicator of whether to return a numeric mean energy score (default) or
        a pl.DataFrame with one row per observational unit.

    Returns
    -------
    Either the mean energy score across all observational units (default) or a
    pl.DataFrame with one row per observational unit and scores stored in a
    column named `energy_score`. Rows of the DataFrame follow the order in which
    observational units first appear in the joined data.

    Raises
    ------
    ValueError
        If `key_cols` is None, or if `pred_cols` and `obs_cols` have different
        lengths.

    Notes
    -----
    We perform the energy score calculation of Eq. (7), p. 223 in
    Gneiting, T., Stanberry, L.I., Grimit, E.P. et al. Assessing probabilistic
    forecasts of multivariate quantities, with an application to ensemble
    predictions of surface winds. TEST 17, 211–235 (2008).
    https://doi.org/10.1007/s11749-008-0114-x
    https://link.springer.com/article/10.1007/s11749-008-0114-x
    """
    if key_cols is None:
        raise ValueError("key_cols must be a list of column names identifying "
                         "observational units; got None")
    if len(pred_cols) != len(obs_cols):
        raise ValueError("pred_cols and obs_cols must have the same length; got "
                         f"{len(pred_cols)} and {len(obs_cols)}")

    def energy_score_one_unit(df: pl.DataFrame) -> pl.DataFrame:
        """
        Compute energy score for one observational unit based on a collection of
        samples. Note, we define this function here so that key_cols, pred_cols
        and obs_cols are in scope.

        See the Notes section of `energy_score` for the reference defining the
        calculation.
        """
        samples = df[pred_cols]
        # Only the first row's observed values are used; this assumes
        # obs_data_wide holds one row per observational unit, so the join
        # repeats the same observation on every sample row.
        observed = df[0, obs_cols]

        # Mean distance from each sample to the observation.
        dist_to_obs = np.mean(pairwise_distances(samples, observed))
        # Mean over the full sample-by-sample distance matrix (zero diagonal
        # included), i.e. the double sum divided by the squared sample count.
        dist_between_samples = np.mean(pairwise_distances(samples))
        score = dist_to_obs - 0.5 * dist_between_samples

        return df[0, key_cols].with_columns(energy_score = pl.lit(float(score)))

    scores_by_unit = (
        model_out_wide
        .join(obs_data_wide, on = key_cols)
        # maintain_order makes the row order of the result deterministic;
        # without it polars gives no guarantee about group order.
        .group_by(*key_cols, maintain_order = True)
        .map_groups(energy_score_one_unit)
    )

    if not reduce_mean:
        return scores_by_unit

    return scores_by_unit["energy_score"].mean()
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
84 changes: 84 additions & 0 deletions tests/postpredict/metrics/test_energy_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Tests for postpredict.metrics.energy_score

from datetime import datetime

import numpy as np
import polars as pl
import pytest
from polars.testing import assert_frame_equal
from postpredict.metrics import energy_score


def test_energy_score():
    """
    Check energy_score against reference values from R's scoringRules package,
    both per observational unit and averaged across units.
    """
    model_out_wide = pl.concat([
        pl.DataFrame({
            "location": "a",
            "date": datetime.strptime("2024-10-01", "%Y-%m-%d"),
            "output_type": "sample",
            "output_type_id": [1.0, 2.0, 3.0, 4.0],
            "horizon1": [5.0, 7.7, 18.0, 10.0],
            "horizon2": [3.0, 4.0, 10.0, 6.0],
            "horizon3": [4.4, 1.0, 12.0, 9.0]
        }),
        pl.DataFrame({
            "location": "b",
            "date": datetime.strptime("2024-10-08", "%Y-%m-%d"),
            "output_type": "sample",
            "output_type_id": [5.0, 6.0, 7.0, 8.0],
            "horizon1": [6.0, 4.0, 5.0, 2.0],
            "horizon2": [12.0, 0.0, 15.0, 6.0],
            "horizon3": [16.6, 21.0, 32.0, -1.0]
        })
    ])
    # Includes (location, date) pairs with no model output; only the pairs that
    # also appear in model_out_wide have expected scores below.
    obs_data_wide = pl.DataFrame({
        "location": ["a", "a", "b", "b"],
        "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d"),
                 datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d")],
        "value": [3.0, 4.0, 0.0, 7.2],
        "value_lead1": [4.0, 10.0, 7.2, 9.6],
        "value_lead2": [10.0, 5.0, 9.6, 10.0],
        "value_lead3": [5.0, 2.0, 10.0, 14.1]
    })

    # expected scores calculated in R using the scoringRules package
    # (each matrix has one row per horizon and one column per sample):
    # library(scoringRules)
    # X <- matrix(
    #     data = c(5.0, 7.7, 18.0, 10.0, 3.0, 4.0, 10.0, 6.0, 4.4, 1.0, 12.0, 9.0),
    #     nrow = 3, ncol = 4,
    #     byrow = TRUE
    # )
    # y <- c(4.0, 10.0, 5.0)
    # print(es_sample(y, X), digits = 20)
    # X <- matrix(
    #     data = c(6.0, 4.0, 5.0, 2.0, 12.0, 0.0, 15.0, 6.0, 16.6, 21.0, 32.0, -1.0),
    #     nrow = 3, ncol = 4,
    #     byrow = TRUE
    # )
    # y <- c(9.6, 10.0, 14.1)
    # print(es_sample(y, X), digits = 20)
    expected_scores = [5.8560677725938221627, 5.9574451598773787708]
    expected_scores_df = pl.DataFrame({
        "location": ["a", "b"],
        "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d")],
        "energy_score": expected_scores
    })

    # Arguments shared by both calls; only reduce_mean differs.
    common_args = {
        "model_out_wide": model_out_wide,
        "obs_data_wide": obs_data_wide,
        "key_cols": ["location", "date"],
        "pred_cols": ["horizon1", "horizon2", "horizon3"],
        "obs_cols": ["value_lead1", "value_lead2", "value_lead3"],
    }

    actual_scores_df = energy_score(**common_args, reduce_mean = False)

    # The per-unit scores come out of a group_by, so do not depend on the order
    # in which the observational units are returned.
    assert_frame_equal(actual_scores_df, expected_scores_df,
                       check_row_order = False, atol = 1e-19)

    expected_mean_score = np.mean(expected_scores)
    actual_mean_score = energy_score(**common_args, reduce_mean = True)
    assert actual_mean_score == pytest.approx(expected_mean_score)

0 comments on commit 9da50db

Please sign in to comment.