From 100253acbadd29b701034718a39c5174a87f4c64 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Thu, 17 Oct 2024 15:16:39 -0400 Subject: [PATCH 1/2] add energy score metric, reorganize tests --- pyproject.toml | 3 +- requirements/requirements-dev.txt | 13 ++- requirements/requirements.txt | 13 ++- src/postpredict/metrics.py | 71 ++++++++++++++++ .../postpredict/{ => dependence}/conftest.py | 0 .../{ => dependence}/test_apply_shuffle.py | 0 .../{ => dependence}/test_build_train_X_Y.py | 0 .../{ => dependence}/test_pivot_horizon.py | 0 .../test_schaake_build_templates.py | 0 .../{ => dependence}/test_transform.py | 0 .../postpredict/metrics/test_energy_score.py | 84 +++++++++++++++++++ 11 files changed, 181 insertions(+), 3 deletions(-) create mode 100644 src/postpredict/metrics.py rename tests/postpredict/{ => dependence}/conftest.py (100%) rename tests/postpredict/{ => dependence}/test_apply_shuffle.py (100%) rename tests/postpredict/{ => dependence}/test_build_train_X_Y.py (100%) rename tests/postpredict/{ => dependence}/test_pivot_horizon.py (100%) rename tests/postpredict/{ => dependence}/test_schaake_build_templates.py (100%) rename tests/postpredict/{ => dependence}/test_transform.py (100%) create mode 100644 tests/postpredict/metrics/test_energy_score.py diff --git a/pyproject.toml b/pyproject.toml index 0f1cd5d..bc371a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,8 @@ dynamic = ["version"] dependencies = [ "numpy", - "polars" + "polars", + "scikit-learn" ] [project.optional-dependencies] diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 6b677e0..751d4bb 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -12,6 +12,8 @@ identify==2.5.36 # via pre-commit iniconfig==2.0.0 # via pytest +joblib==1.4.2 + # via scikit-learn mypy==1.10.0 # via postpredict (pyproject.toml) mypy-extensions==1.0.0 @@ -19,7 +21,10 @@ mypy-extensions==1.0.0 nodeenv==1.8.0 # via pre-commit numpy==2.1.2 - # via postpredict (pyproject.toml) + # via + # postpredict (pyproject.toml) + # scikit-learn + # scipy packaging==24.0 # via pytest platformdirs==4.2.1 @@ -40,8 +45,14 @@ pyyaml==6.0.1 # via pre-commit ruff==0.4.3 # via postpredict (pyproject.toml) +scikit-learn==1.5.2 + # via postpredict (pyproject.toml) +scipy==1.14.1 + # via scikit-learn setuptools==75.1.0 # via nodeenv +threadpoolctl==3.5.0 + # via scikit-learn toml==0.10.2 # via postpredict (pyproject.toml) types-toml==0.10.8.20240310 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index d7393f3..f79de00 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,17 @@ # This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml -o requirements/requirements.txt +joblib==1.4.2 + # via scikit-learn numpy==2.1.2 - # via postpredict (pyproject.toml) + # via + # postpredict (pyproject.toml) + # scikit-learn + # scipy polars==1.9.0 # via postpredict (pyproject.toml) +scikit-learn==1.5.2 + # via postpredict (pyproject.toml) +scipy==1.14.1 + # via scikit-learn +threadpoolctl==3.5.0 + # via scikit-learn diff --git a/src/postpredict/metrics.py b/src/postpredict/metrics.py new file mode 100644 index 0000000..bb5d879 --- /dev/null +++ b/src/postpredict/metrics.py @@ -0,0 +1,71 @@ +import numpy as np +import polars as pl +from sklearn.metrics import pairwise_distances + + +def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, + key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str], + reduce_mean: bool = True) -> float | pl.DataFrame: + """ + Compute the energy score for a collection of predictive samples. + + Parameters + ---------- + model_out_wide: pl.DataFrame + DataFrame of model outputs where each row corresponds to one + (multivariate) sample from a multivariate distribution for one + observational unit. + obs_data_wide: pl.DataFrame + DataFrame of observed values where each row corresponds to one + (multivariate) observed outcome for one observational unit. + key_cols: list[str] + Columns that appear in both `model_out_wide` and `obs_data_wide` that + identify observational units. + pred_cols: list[str] + Columns that appear in `model_out_wide` and identify predicted (sampled) + values. The order of these should match the order of `obs_cols`. + obs_cols: list[str] + Columns that appear in `obs_data_wide` and identify observed values. The + order of these should match the order of `pred_cols`. + reduce_mean: bool = True + Indicator of whether to return a numeric mean energy score (default) or + a pl.DataFrame with one row per observational unit. + + Returns + ------- + Either the mean energy score across all observational units (default) or a + pl.DataFrame with one row per observational unit and scores stored in a + column named `energy_score`. + + Notes + ----- + We perform the energy score calculation of Eq. (7), p. 223 in + Gneiting, T., Stanberry, L.I., Grimit, E.P. et al. Assessing probabilistic + forecasts of multivariate quantities, with an application to ensemble + predictions of surface winds. TEST 17, 211–235 (2008). + https://doi.org/10.1007/s11749-008-0114-x + https://link.springer.com/article/10.1007/s11749-008-0114-x + """ + def energy_score_one_unit(df: pl.DataFrame): + """ + Compute energy score for one observational unit based on a collection of + samples. Note, we define this function here so that key_cols, pred_cols + and obs_cols are in scope. + + See + """ + score = np.mean(pairwise_distances(df[pred_cols], df[0, obs_cols])) \ + - 0.5 * np.mean(pairwise_distances(df[pred_cols])) + return df[0, key_cols].with_columns(energy_score = pl.lit(score)) + + scores_by_unit = ( + model_out_wide + .join(obs_data_wide, on = key_cols) + .group_by(*key_cols) + .map_groups(energy_score_one_unit) + ) + + if not reduce_mean: + return scores_by_unit + + return scores_by_unit["energy_score"].mean() diff --git a/tests/postpredict/conftest.py b/tests/postpredict/dependence/conftest.py similarity index 100% rename from tests/postpredict/conftest.py rename to tests/postpredict/dependence/conftest.py diff --git a/tests/postpredict/test_apply_shuffle.py b/tests/postpredict/dependence/test_apply_shuffle.py similarity index 100% rename from tests/postpredict/test_apply_shuffle.py rename to tests/postpredict/dependence/test_apply_shuffle.py diff --git a/tests/postpredict/test_build_train_X_Y.py b/tests/postpredict/dependence/test_build_train_X_Y.py similarity index 100% rename from tests/postpredict/test_build_train_X_Y.py rename to tests/postpredict/dependence/test_build_train_X_Y.py diff --git a/tests/postpredict/test_pivot_horizon.py b/tests/postpredict/dependence/test_pivot_horizon.py similarity index 100% rename from tests/postpredict/test_pivot_horizon.py rename to tests/postpredict/dependence/test_pivot_horizon.py diff --git a/tests/postpredict/test_schaake_build_templates.py b/tests/postpredict/dependence/test_schaake_build_templates.py similarity index 100% rename from tests/postpredict/test_schaake_build_templates.py rename to tests/postpredict/dependence/test_schaake_build_templates.py diff --git a/tests/postpredict/test_transform.py b/tests/postpredict/dependence/test_transform.py similarity index 100% rename from tests/postpredict/test_transform.py rename to tests/postpredict/dependence/test_transform.py diff --git a/tests/postpredict/metrics/test_energy_score.py b/tests/postpredict/metrics/test_energy_score.py new file mode 100644 index 0000000..edb969b --- /dev/null +++ b/tests/postpredict/metrics/test_energy_score.py @@ -0,0 +1,84 @@ +# Tests for postpredict.metrics.energy_score + +from datetime import datetime + +import numpy as np +import polars as pl +import pytest +from polars.testing import assert_frame_equal +from postpredict.metrics import energy_score + + +def test_energy_score(): + model_out_wide = pl.concat([ + pl.DataFrame({ + "location": "a", + "date": datetime.strptime("2024-10-01", "%Y-%m-%d"), + "output_type": "sample", + "output_type_id": [1.0, 2.0, 3.0, 4.0], + "horizon1": [5.0, 7.7, 18.0, 10.0], + "horizon2": [3.0, 4.0, 10.0, 6.0], + "horizon3": [4.4, 1.0, 12.0, 9.0] + }), + pl.DataFrame({ + "location": "b", + "date": datetime.strptime("2024-10-08", "%Y-%m-%d"), + "output_type": "sample", + "output_type_id": [5.0, 6.0, 7.0, 8.0], + "horizon1": [6.0, 4.0, 5.0, 2.0], + "horizon2": [12.0, 0.0, 15.0, 6.0], + "horizon3": [16.6, 21.0, 32.0, -1.0] + }) + ]) + obs_data_wide = pl.DataFrame({ + "location": ["a", "a", "b", "b"], + "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d"), + datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d")], + "value": [3.0, 4.0, 0.0, 7.2], + "value_lead1": [4.0, 10.0, 7.2, 9.6], + "value_lead2": [10.0, 5.0, 9.6, 10.0], + "value_lead3": [5.0, 2.0, 10.0, 14.1] + }) + + # expected scores calculated in R using the scoringRules package: + # library(scoringRules) + # X <- matrix( + # data = c(5.0, 7.7, 18.0, 10.0, 3.0, 4.0, 10.0, 6.0, 4.4, 1.0, 12.0, 9.0), + # nrow = 3, ncol = 4, + # byrow = TRUE + # ) + # y <- c(4.0, 10.0, 5.0) + # print(es_sample(y, X), digits = 20) + # X <- matrix( + # data = c(6.0, 4.0, 5.0, 2.0, 12.0, 0.0, 15.0, 6.0, 16.6, 21.0, 32.0, -1.0), + # nrow = 3, ncol = 4, + # byrow = TRUE + # ) + # y <- c(9.6, 10.0, 14.1) + # print(es_sample(y, X), digits = 20) + expected_scores_df = pl.DataFrame({ + "location": ["a", "b"], + "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d")], + "energy_score": [5.8560677725938221627, 5.9574451598773787708] + }) + + actual_scores_df = energy_score(model_out_wide = model_out_wide, + obs_data_wide = obs_data_wide, + key_cols = ["location", "date"], + pred_cols = ["horizon1", "horizon2", "horizon3"], + obs_cols = ["value_lead1", "value_lead2", "value_lead3"], + reduce_mean = False) + + assert_frame_equal(actual_scores_df, expected_scores_df, atol = 1e-19) + + expected_mean_score = np.mean([5.8560677725938221627, 5.9574451598773787708]) + actual_mean_score = energy_score(model_out_wide = model_out_wide, + obs_data_wide = obs_data_wide, + key_cols = ["location", "date"], + pred_cols = ["horizon1", "horizon2", "horizon3"], + obs_cols = ["value_lead1", "value_lead2", "value_lead3"], + reduce_mean = True) + assert actual_mean_score == pytest.approx(expected_mean_score) From b148118663af2c8c2369d38829923d1030d372f4 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Thu, 17 Oct 2024 15:21:44 -0400 Subject: [PATCH 2/2] turn off mypy type checking in run checks :see-no-evil: --- .github/workflows/pythonapp-workflow.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/pythonapp-workflow.yml b/.github/workflows/pythonapp-workflow.yml index 137f0b4..cb321bb 100644 --- a/.github/workflows/pythonapp-workflow.yml +++ b/.github/workflows/pythonapp-workflow.yml @@ -16,9 +16,6 @@ jobs: - name: lint run: | ruff check . - - name: type check - run: | - mypy . - name: run tests run: | pytest