From 8115735e9ef4911bd0befb262459f82bca51caad Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Mon, 21 Oct 2024 17:24:47 -0400 Subject: [PATCH 1/5] marginal_pit metric --- src/postpredict/metrics.py | 58 ++++++++++++++++- .../postpredict/metrics/test_marginal_pit.py | 65 +++++++++++++++++++ 2 files changed, 121 insertions(+), 2 deletions(-) create mode 100644 tests/postpredict/metrics/test_marginal_pit.py diff --git a/src/postpredict/metrics.py b/src/postpredict/metrics.py index fa16043..4f92f6d 100644 --- a/src/postpredict/metrics.py +++ b/src/postpredict/metrics.py @@ -51,8 +51,6 @@ def energy_score_one_unit(df: pl.DataFrame): Compute energy score for one observational unit based on a collection of samples. Note, we define this function here so that key_cols, pred_cols and obs_cols are in scope. - - See """ if df[pred_cols + obs_cols].null_count().to_numpy().sum() > 0: # Return np.nan rather than None here to avoid a rare schema @@ -76,3 +74,59 @@ def energy_score_one_unit(df: pl.DataFrame): # replace NaN with None to average only across non-missing values return scores_by_unit["energy_score"].fill_nan(None).mean() + + +def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, + key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str], + reduce_mean: bool = True) -> float | pl.DataFrame: + """ + Compute the probability integral transform (PIT) value for each of a + collection of marginal predictive distributions represented by a set of + samples. + + Parameters + ---------- + model_out_wide: pl.DataFrame + DataFrame of model outputs where each row corresponds to one + (multivariate) sample from a multivariate distribution for one + observational unit. + obs_data_wide: pl.DataFrame + DataFrame of observed values where each row corresponds to one + (multivariate) observed outcome for one observational unit. + key_cols: list[str] + Columns that appear in both `model_out_wide` and `obs_data_wide` that + identify observational units. + pred_cols: list[str] + Columns that appear in `model_out_wide` and identify predicted (sampled) + values. The order of these should match the order of `obs_cols`. + obs_cols: list[str] + Columns that appear in `obs_data_wide` and identify observed values. The + order of these should match the order of `pred_cols`. + reduce_mean: bool = True + Indicator of whether to return a numeric mean energy score (default) or + a pl.DataFrame with one row per observational unit. + + Returns + ------- + A pl.DataFrame with one row per observational unit and PIT values stored in + columns named according to `[f"pit_{c}" for c in pred_cols]`. + + Notes + ----- + Here, the PIT value is calculated as the proportion of samples that are less + than or equal to the observed value. + """ + scores_by_unit = ( + model_out_wide + .join(obs_data_wide, on = key_cols) + .group_by(key_cols) + .agg( + [ + (pl.col(pred_c) <= pl.col(obs_c)).mean().alias(f"pit_{pred_c}") \ + for pred_c, obs_c in zip(pred_cols, obs_cols) + ] + ) + .select(key_cols + [f"pit_{pred_c}" for pred_c in pred_cols]) + ) + + return scores_by_unit diff --git a/tests/postpredict/metrics/test_marginal_pit.py b/tests/postpredict/metrics/test_marginal_pit.py new file mode 100644 index 0000000..a7c22a6 --- /dev/null +++ b/tests/postpredict/metrics/test_marginal_pit.py @@ -0,0 +1,65 @@ +# Tests for postpredict.metrics.marginal_pit + +from datetime import datetime + +import numpy as np +import polars as pl +import pytest +from polars.testing import assert_frame_equal +from postpredict.metrics import marginal_pit + + +def test_marginal_pit(): + rng = np.random.default_rng(seed=123) + model_out_wide = pl.concat([ + pl.DataFrame({ + "location": "a", + "date": datetime.strptime("2024-10-01", "%Y-%m-%d"), + "output_type": "sample", + "output_type_id": np.linspace(0, 99, 100), + "horizon1": rng.permutation(np.linspace(0, 9, 100)), + "horizon2": rng.permutation(np.linspace(8, 17, 100)), + "horizon3": rng.permutation(np.linspace(5.1, 16.1, 100)) + }), + pl.DataFrame({ + "location": "b", + "date": datetime.strptime("2024-10-08", "%Y-%m-%d"), + "output_type": "sample", + "output_type_id": np.linspace(100, 199, 100), + "horizon1": rng.permutation(np.linspace(10.0, 19.0, 100)), + "horizon2": rng.permutation(np.linspace(-3.0, 6.0, 100)), + "horizon3": rng.permutation(np.linspace(10.99, 19.99, 100)) + }) + ]) + obs_data_wide = pl.DataFrame({ + "location": ["a", "a", "b", "b"], + "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d"), + datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d")], + "value": [3.0, 4.0, 0.0, 7.2], + "value_lead1": [4.0, 10.0, 7.2, 9.6], + "value_lead2": [10.0, 5.0, 9.6, 10.0], + "value_lead3": [5.0, 2.0, 10.0, 14.1] + }) + + # expected scores calculated in R using the scoringRules package: + expected_scores_df = pl.DataFrame({ + "location": ["a", "b"], + "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d")], + "pit_horizon1": [0.45, 0.0], + "pit_horizon2": [0.23, 1.0], + "pit_horizon3": [0.0, 0.35] + }) + + actual_scores_df = marginal_pit(model_out_wide=model_out_wide, + obs_data_wide=obs_data_wide, + key_cols=["location", "date"], + pred_cols=["horizon1", "horizon2", "horizon3"], + obs_cols=["value_lead1", "value_lead2", "value_lead3"], + reduce_mean=False) + + print(actual_scores_df) + + assert_frame_equal(actual_scores_df, expected_scores_df, check_row_order=False, atol=1e-19) From 0c7eedd665ed224f05bd9a7d3f9343d03b969036 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Mon, 21 Oct 2024 17:28:20 -0400 Subject: [PATCH 2/5] remove unused import, correct comments --- tests/postpredict/metrics/test_marginal_pit.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/postpredict/metrics/test_marginal_pit.py b/tests/postpredict/metrics/test_marginal_pit.py index a7c22a6..b24dd98 100644 --- a/tests/postpredict/metrics/test_marginal_pit.py +++ b/tests/postpredict/metrics/test_marginal_pit.py @@ -4,7 +4,6 @@ import numpy as np import polars as pl -import pytest from polars.testing import assert_frame_equal from postpredict.metrics import marginal_pit @@ -43,7 +42,8 @@ def test_marginal_pit(): "value_lead3": [5.0, 2.0, 10.0, 14.1] }) - # expected scores calculated in R using the scoringRules package: + # expected PIT values: the number of samples less than or equal to + # corresponding observed values expected_scores_df = pl.DataFrame({ "location": ["a", "b"], "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"), @@ -59,7 +59,5 @@ def test_marginal_pit(): pred_cols=["horizon1", "horizon2", "horizon3"], obs_cols=["value_lead1", "value_lead2", "value_lead3"], reduce_mean=False) - - print(actual_scores_df) assert_frame_equal(actual_scores_df, expected_scores_df, check_row_order=False, atol=1e-19) From 336ae901d50864cd0810e835e9976d05f3961220 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 22 Oct 2024 10:36:25 -0400 Subject: [PATCH 3/5] remove unused argument to marginal_pit --- src/postpredict/metrics.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/postpredict/metrics.py b/src/postpredict/metrics.py index 4f92f6d..4a612ee 100644 --- a/src/postpredict/metrics.py +++ b/src/postpredict/metrics.py @@ -77,8 +77,8 @@ def energy_score_one_unit(df: pl.DataFrame): def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, - key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str], - reduce_mean: bool = True) -> float | pl.DataFrame: + key_cols: list[str] | None, pred_cols: list[str], + obs_cols: list[str]) -> pl.DataFrame: """ Compute the probability integral transform (PIT) value for each of a collection of marginal predictive distributions represented by a set of @@ -102,9 +102,6 @@ def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, obs_cols: list[str] Columns that appear in `obs_data_wide` and identify observed values. The order of these should match the order of `pred_cols`. - reduce_mean: bool = True - Indicator of whether to return a numeric mean energy score (default) or - a pl.DataFrame with one row per observational unit. Returns ------- From 095a2b24682f039da4d82b839c41ad03d59a7485 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 22 Oct 2024 10:42:42 -0400 Subject: [PATCH 4/5] remove unused argument from marginal_pit test --- tests/postpredict/metrics/test_marginal_pit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/postpredict/metrics/test_marginal_pit.py b/tests/postpredict/metrics/test_marginal_pit.py index b24dd98..1045924 100644 --- a/tests/postpredict/metrics/test_marginal_pit.py +++ b/tests/postpredict/metrics/test_marginal_pit.py @@ -57,7 +57,6 @@ def test_marginal_pit(): obs_data_wide=obs_data_wide, key_cols=["location", "date"], pred_cols=["horizon1", "horizon2", "horizon3"], - obs_cols=["value_lead1", "value_lead2", "value_lead3"], - reduce_mean=False) + obs_cols=["value_lead1", "value_lead2", "value_lead3"]) assert_frame_equal(actual_scores_df, expected_scores_df, check_row_order=False, atol=1e-19) From 7d7914968c95c1cd51a0ebe8f2dd01163fd7b84d Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 22 Oct 2024 11:20:26 -0400 Subject: [PATCH 5/5] in metrics functions, rename key_cols to index_cols --- src/postpredict/metrics.py | 28 +++++++++++-------- .../postpredict/metrics/test_energy_score.py | 4 +-- .../postpredict/metrics/test_marginal_pit.py | 2 +- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/postpredict/metrics.py b/src/postpredict/metrics.py index 4a612ee..73117cb 100644 --- a/src/postpredict/metrics.py +++ b/src/postpredict/metrics.py @@ -4,7 +4,7 @@ def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, - key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str], + index_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str], reduce_mean: bool = True) -> float | pl.DataFrame: """ Compute the energy score for a collection of predictive samples. @@ -18,9 +18,11 @@ def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, obs_data_wide: pl.DataFrame DataFrame of observed values where each row corresponds to one (multivariate) observed outcome for one observational unit. - key_cols: list[str] + index_cols: list[str] Columns that appear in both `model_out_wide` and `obs_data_wide` that - identify observational units. + identify the unit of a multivariate prediction (e.g., including the + location, age_group, and reference_time of a prediction). These columns + will be included in the returned dataframe. pred_cols: list[str] Columns that appear in `model_out_wide` and identify predicted (sampled) values. The order of these should match the order of `obs_cols`. @@ -60,12 +62,12 @@ def energy_score_one_unit(df: pl.DataFrame): score = np.mean(pairwise_distances(df[pred_cols], df[0, obs_cols])) \ - 0.5 * np.mean(pairwise_distances(df[pred_cols])) - return df[0, key_cols].with_columns(energy_score = pl.lit(score)) + return df[0, index_cols].with_columns(energy_score = pl.lit(score)) scores_by_unit = ( model_out_wide - .join(obs_data_wide, on = key_cols) - .group_by(*key_cols) + .join(obs_data_wide, on = index_cols) + .group_by(*index_cols) .map_groups(energy_score_one_unit) ) @@ -77,7 +79,7 @@ def energy_score_one_unit(df: pl.DataFrame): def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, - key_cols: list[str] | None, pred_cols: list[str], + index_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str]) -> pl.DataFrame: """ Compute the probability integral transform (PIT) value for each of a @@ -93,9 +95,11 @@ def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, obs_data_wide: pl.DataFrame DataFrame of observed values where each row corresponds to one (multivariate) observed outcome for one observational unit. - key_cols: list[str] + index_cols: list[str] Columns that appear in both `model_out_wide` and `obs_data_wide` that - identify observational units. + identify the unit of a multivariate prediction (e.g., including the + location, age_group, and reference_time of a prediction). These columns + will be included in the returned dataframe. pred_cols: list[str] Columns that appear in `model_out_wide` and identify predicted (sampled) values. The order of these should match the order of `obs_cols`. @@ -115,15 +119,15 @@ def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, """ scores_by_unit = ( model_out_wide - .join(obs_data_wide, on = key_cols) - .group_by(key_cols) + .join(obs_data_wide, on = index_cols) + .group_by(index_cols) .agg( [ (pl.col(pred_c) <= pl.col(obs_c)).mean().alias(f"pit_{pred_c}") \ for pred_c, obs_c in zip(pred_cols, obs_cols) ] ) - .select(key_cols + [f"pit_{pred_c}" for pred_c in pred_cols]) + .select(index_cols + [f"pit_{pred_c}" for pred_c in pred_cols]) ) return scores_by_unit diff --git a/tests/postpredict/metrics/test_energy_score.py b/tests/postpredict/metrics/test_energy_score.py index 45dc63a..1d6467d 100644 --- a/tests/postpredict/metrics/test_energy_score.py +++ b/tests/postpredict/metrics/test_energy_score.py @@ -67,7 +67,7 @@ def test_energy_score(): actual_scores_df = energy_score(model_out_wide=model_out_wide, obs_data_wide=obs_data_wide, - key_cols=["location", "date"], + index_cols=["location", "date"], pred_cols=["horizon1", "horizon2", "horizon3"], obs_cols=["value_lead1", "value_lead2", "value_lead3"], reduce_mean=False) @@ -77,7 +77,7 @@ def test_energy_score(): expected_mean_score = np.mean([5.8560677725938221627, 5.9574451598773787708]) actual_mean_score = energy_score(model_out_wide=model_out_wide, obs_data_wide=obs_data_wide, - key_cols=["location", "date"], + index_cols=["location", "date"], pred_cols=["horizon1", "horizon2", "horizon3"], obs_cols=["value_lead1", "value_lead2", "value_lead3"], reduce_mean=True) diff --git a/tests/postpredict/metrics/test_marginal_pit.py b/tests/postpredict/metrics/test_marginal_pit.py index 1045924..94fccd0 100644 --- a/tests/postpredict/metrics/test_marginal_pit.py +++ b/tests/postpredict/metrics/test_marginal_pit.py @@ -55,7 +55,7 @@ def test_marginal_pit(): actual_scores_df = marginal_pit(model_out_wide=model_out_wide, obs_data_wide=obs_data_wide, - key_cols=["location", "date"], + index_cols=["location", "date"], pred_cols=["horizon1", "horizon2", "horizon3"], obs_cols=["value_lead1", "value_lead2", "value_lead3"])