diff --git a/src/postpredict/metrics.py b/src/postpredict/metrics.py index fa16043..73117cb 100644 --- a/src/postpredict/metrics.py +++ b/src/postpredict/metrics.py @@ -4,7 +4,7 @@ def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, - key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str], + index_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str], reduce_mean: bool = True) -> float | pl.DataFrame: """ Compute the energy score for a collection of predictive samples. @@ -18,9 +18,11 @@ def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, obs_data_wide: pl.DataFrame DataFrame of observed values where each row corresponds to one (multivariate) observed outcome for one observational unit. - key_cols: list[str] + index_cols: list[str] Columns that appear in both `model_out_wide` and `obs_data_wide` that - identify observational units. + identify the unit of a multivariate prediction (e.g., including the + location, age_group, and reference_time of a prediction). These columns + will be included in the returned dataframe. pred_cols: list[str] Columns that appear in `model_out_wide` and identify predicted (sampled) values. The order of these should match the order of `obs_cols`. @@ -51,8 +53,6 @@ def energy_score_one_unit(df: pl.DataFrame): Compute energy score for one observational unit based on a collection of samples. Note, we define this function here so that key_cols, pred_cols and obs_cols are in scope. 
- - See """ if df[pred_cols + obs_cols].null_count().to_numpy().sum() > 0: # Return np.nan rather than None here to avoid a rare schema @@ -62,12 +62,12 @@ def energy_score_one_unit(df: pl.DataFrame): score = np.mean(pairwise_distances(df[pred_cols], df[0, obs_cols])) \ - 0.5 * np.mean(pairwise_distances(df[pred_cols])) - return df[0, key_cols].with_columns(energy_score = pl.lit(score)) + return df[0, index_cols].with_columns(energy_score = pl.lit(score)) scores_by_unit = ( model_out_wide - .join(obs_data_wide, on = key_cols) - .group_by(*key_cols) + .join(obs_data_wide, on = index_cols) + .group_by(*index_cols) .map_groups(energy_score_one_unit) ) @@ -76,3 +76,58 @@ def energy_score_one_unit(df: pl.DataFrame): # replace NaN with None to average only across non-missing values return scores_by_unit["energy_score"].fill_nan(None).mean() + + +def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame, + index_cols: list[str] | None, pred_cols: list[str], + obs_cols: list[str]) -> pl.DataFrame: + """ + Compute the probability integral transform (PIT) value for each of a + collection of marginal predictive distributions represented by a set of + samples. + + Parameters + ---------- + model_out_wide: pl.DataFrame + DataFrame of model outputs where each row corresponds to one + (multivariate) sample from a multivariate distribution for one + observational unit. + obs_data_wide: pl.DataFrame + DataFrame of observed values where each row corresponds to one + (multivariate) observed outcome for one observational unit. + index_cols: list[str] + Columns that appear in both `model_out_wide` and `obs_data_wide` that + identify the unit of a multivariate prediction (e.g., including the + location, age_group, and reference_time of a prediction). These columns + will be included in the returned dataframe. + pred_cols: list[str] + Columns that appear in `model_out_wide` and identify predicted (sampled) + values. 
The order of these should match the order of `obs_cols`. + obs_cols: list[str] + Columns that appear in `obs_data_wide` and identify observed values. The + order of these should match the order of `pred_cols`. + + Returns + ------- + A pl.DataFrame with one row per observational unit and PIT values stored in + columns named according to `[f"pit_{c}" for c in pred_cols]`. + + Notes + ----- + Here, the PIT value is calculated as the proportion of samples that are less + than or equal to the observed value. + """ + scores_by_unit = ( + model_out_wide + .join(obs_data_wide, on = index_cols) + .group_by(index_cols) + .agg( + [ + (pl.col(pred_c) <= pl.col(obs_c)).mean().alias(f"pit_{pred_c}") \ + for pred_c, obs_c in zip(pred_cols, obs_cols) + ] + ) + .select(index_cols + [f"pit_{pred_c}" for pred_c in pred_cols]) + ) + + return scores_by_unit diff --git a/tests/postpredict/metrics/test_energy_score.py b/tests/postpredict/metrics/test_energy_score.py index 45dc63a..1d6467d 100644 --- a/tests/postpredict/metrics/test_energy_score.py +++ b/tests/postpredict/metrics/test_energy_score.py @@ -67,7 +67,7 @@ def test_energy_score(): actual_scores_df = energy_score(model_out_wide=model_out_wide, obs_data_wide=obs_data_wide, - key_cols=["location", "date"], + index_cols=["location", "date"], pred_cols=["horizon1", "horizon2", "horizon3"], obs_cols=["value_lead1", "value_lead2", "value_lead3"], reduce_mean=False) @@ -77,7 +77,7 @@ def test_energy_score(): expected_mean_score = np.mean([5.8560677725938221627, 5.9574451598773787708]) actual_mean_score = energy_score(model_out_wide=model_out_wide, obs_data_wide=obs_data_wide, - key_cols=["location", "date"], + index_cols=["location", "date"], pred_cols=["horizon1", "horizon2", "horizon3"], obs_cols=["value_lead1", "value_lead2", "value_lead3"], reduce_mean=True) diff --git a/tests/postpredict/metrics/test_marginal_pit.py b/tests/postpredict/metrics/test_marginal_pit.py new file mode 100644 index 0000000..94fccd0 --- /dev/null +++ 
b/tests/postpredict/metrics/test_marginal_pit.py @@ -0,0 +1,62 @@ +# Tests for postpredict.metrics.marginal_pit + +from datetime import datetime + +import numpy as np +import polars as pl +from polars.testing import assert_frame_equal +from postpredict.metrics import marginal_pit + + +def test_marginal_pit(): + rng = np.random.default_rng(seed=123) + model_out_wide = pl.concat([ + pl.DataFrame({ + "location": "a", + "date": datetime.strptime("2024-10-01", "%Y-%m-%d"), + "output_type": "sample", + "output_type_id": np.linspace(0, 99, 100), + "horizon1": rng.permutation(np.linspace(0, 9, 100)), + "horizon2": rng.permutation(np.linspace(8, 17, 100)), + "horizon3": rng.permutation(np.linspace(5.1, 16.1, 100)) + }), + pl.DataFrame({ + "location": "b", + "date": datetime.strptime("2024-10-08", "%Y-%m-%d"), + "output_type": "sample", + "output_type_id": np.linspace(100, 199, 100), + "horizon1": rng.permutation(np.linspace(10.0, 19.0, 100)), + "horizon2": rng.permutation(np.linspace(-3.0, 6.0, 100)), + "horizon3": rng.permutation(np.linspace(10.99, 19.99, 100)) + }) + ]) + obs_data_wide = pl.DataFrame({ + "location": ["a", "a", "b", "b"], + "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d"), + datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d")], + "value": [3.0, 4.0, 0.0, 7.2], + "value_lead1": [4.0, 10.0, 7.2, 9.6], + "value_lead2": [10.0, 5.0, 9.6, 10.0], + "value_lead3": [5.0, 2.0, 10.0, 14.1] + }) + + # expected PIT values: the proportion of samples less than or equal to + # corresponding observed values + expected_scores_df = pl.DataFrame({ + "location": ["a", "b"], + "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"), + datetime.strptime("2024-10-08", "%Y-%m-%d")], + "pit_horizon1": [0.45, 0.0], + "pit_horizon2": [0.23, 1.0], + "pit_horizon3": [0.0, 0.35] + }) + + actual_scores_df = marginal_pit(model_out_wide=model_out_wide, + obs_data_wide=obs_data_wide, + index_cols=["location", 
"date"], + pred_cols=["horizon1", "horizon2", "horizon3"], + obs_cols=["value_lead1", "value_lead2", "value_lead3"]) + + assert_frame_equal(actual_scores_df, expected_scores_df, check_row_order=False, atol=1e-19)