marginal_pit metric
elray1 committed Oct 21, 2024
1 parent 018e545 commit 8115735
Showing 2 changed files with 121 additions and 2 deletions.
58 changes: 56 additions & 2 deletions src/postpredict/metrics.py
@@ -51,8 +51,6 @@ def energy_score_one_unit(df: pl.DataFrame):
        Compute energy score for one observational unit based on a collection of
        samples. Note, we define this function here so that key_cols, pred_cols
        and obs_cols are in scope.
-        See
        """
        if df[pred_cols + obs_cols].null_count().to_numpy().sum() > 0:
            # Return np.nan rather than None here to avoid a rare schema
@@ -76,3 +74,59 @@ def energy_score_one_unit(df: pl.DataFrame):

    # replace NaN with None to average only across non-missing values
    return scores_by_unit["energy_score"].fill_nan(None).mean()


def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
                 key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str],
                 reduce_mean: bool = True) -> float | pl.DataFrame:
    """
    Compute the probability integral transform (PIT) value for each of a
    collection of marginal predictive distributions represented by a set of
    samples.

    Parameters
    ----------
    model_out_wide: pl.DataFrame
        DataFrame of model outputs where each row corresponds to one
        (multivariate) sample from a multivariate distribution for one
        observational unit.
    obs_data_wide: pl.DataFrame
        DataFrame of observed values where each row corresponds to one
        (multivariate) observed outcome for one observational unit.
    key_cols: list[str]
        Columns that appear in both `model_out_wide` and `obs_data_wide` and
        identify observational units.
    pred_cols: list[str]
        Columns that appear in `model_out_wide` and identify predicted (sampled)
        values. The order of these should match the order of `obs_cols`.
    obs_cols: list[str]
        Columns that appear in `obs_data_wide` and identify observed values. The
        order of these should match the order of `pred_cols`.
    reduce_mean: bool = True
        Indicator of whether to return a numeric mean PIT value (default) or a
        pl.DataFrame with one row per observational unit.

    Returns
    -------
    If `reduce_mean` is True, a float with the mean of the PIT values across all
    observational units and prediction targets. Otherwise, a pl.DataFrame with
    one row per observational unit and PIT values stored in columns named
    according to `[f"pit_{c}" for c in pred_cols]`.

    Notes
    -----
    Here, the PIT value is calculated as the proportion of samples that are less
    than or equal to the observed value.
    """
    scores_by_unit = (
        model_out_wide
        .join(obs_data_wide, on = key_cols)
        .group_by(key_cols)
        .agg(
            [
                (pl.col(pred_c) <= pl.col(obs_c)).mean().alias(f"pit_{pred_c}")
                for pred_c, obs_c in zip(pred_cols, obs_cols)
            ]
        )
        .select(key_cols + [f"pit_{pred_c}" for pred_c in pred_cols])
    )

    if reduce_mean:
        # collapse to a single mean PIT value across units and prediction targets
        return float(scores_by_unit[[f"pit_{pred_c}" for pred_c in pred_cols]].to_numpy().mean())

    return scores_by_unit
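
For illustration, a minimal usage sketch of `marginal_pit` on a single observational unit (the column names `location`, `horizon1`, and `value_lead1` mirror the test below; this is an assumed example, not part of the commit):

import polars as pl
from postpredict.metrics import marginal_pit

# ten samples 0.0, 1.0, ..., 9.0 for one unit; observed value 4.0
model_out = pl.DataFrame({
    "location": ["a"] * 10,
    "horizon1": [float(x) for x in range(10)],
})
obs = pl.DataFrame({"location": ["a"], "value_lead1": [4.0]})

scores = marginal_pit(model_out_wide=model_out, obs_data_wide=obs,
                      key_cols=["location"], pred_cols=["horizon1"],
                      obs_cols=["value_lead1"], reduce_mean=False)
# five of the ten samples (0.0 through 4.0) are <= 4.0, so pit_horizon1 == 0.5
print(scores)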
65 changes: 65 additions & 0 deletions tests/postpredict/metrics/test_marginal_pit.py
@@ -0,0 +1,65 @@
# Tests for postpredict.metrics.marginal_pit

from datetime import datetime

import numpy as np
import polars as pl
import pytest
from polars.testing import assert_frame_equal
from postpredict.metrics import marginal_pit


def test_marginal_pit():
    rng = np.random.default_rng(seed=123)
    model_out_wide = pl.concat([
        pl.DataFrame({
            "location": "a",
            "date": datetime.strptime("2024-10-01", "%Y-%m-%d"),
            "output_type": "sample",
            "output_type_id": np.linspace(0, 99, 100),
            "horizon1": rng.permutation(np.linspace(0, 9, 100)),
            "horizon2": rng.permutation(np.linspace(8, 17, 100)),
            "horizon3": rng.permutation(np.linspace(5.1, 16.1, 100))
        }),
        pl.DataFrame({
            "location": "b",
            "date": datetime.strptime("2024-10-08", "%Y-%m-%d"),
            "output_type": "sample",
            "output_type_id": np.linspace(100, 199, 100),
            "horizon1": rng.permutation(np.linspace(10.0, 19.0, 100)),
            "horizon2": rng.permutation(np.linspace(-3.0, 6.0, 100)),
            "horizon3": rng.permutation(np.linspace(10.99, 19.99, 100))
        })
    ])
    obs_data_wide = pl.DataFrame({
        "location": ["a", "a", "b", "b"],
        "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d"),
                 datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d")],
        "value": [3.0, 4.0, 0.0, 7.2],
        "value_lead1": [4.0, 10.0, 7.2, 9.6],
        "value_lead2": [10.0, 5.0, 9.6, 10.0],
        "value_lead3": [5.0, 2.0, 10.0, 14.1]
    })

    # expected PIT values, computed by hand as the proportion of samples that
    # fall at or below each observed value:
    expected_scores_df = pl.DataFrame({
        "location": ["a", "b"],
        "date": [datetime.strptime("2024-10-01", "%Y-%m-%d"),
                 datetime.strptime("2024-10-08", "%Y-%m-%d")],
        "pit_horizon1": [0.45, 0.0],
        "pit_horizon2": [0.23, 1.0],
        "pit_horizon3": [0.0, 0.35]
    })

    actual_scores_df = marginal_pit(model_out_wide=model_out_wide,
                                    obs_data_wide=obs_data_wide,
                                    key_cols=["location", "date"],
                                    pred_cols=["horizon1", "horizon2", "horizon3"],
                                    obs_cols=["value_lead1", "value_lead2", "value_lead3"],
                                    reduce_mean=False)

    print(actual_scores_df)

    assert_frame_equal(actual_scores_df, expected_scores_df, check_row_order=False, atol=1e-19)
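
Because each sample set is an evenly spaced grid, and `rng.permutation` does not change which values fall at or below the observed value, the expected PIT values can be verified directly; e.g. for location "a" at horizon 1 (a standalone sanity check, not part of the test file):

import numpy as np

samples = np.linspace(0, 9, 100)  # the horizon1 samples for location "a"
print(np.mean(samples <= 4.0))    # 0.45, matching expected pit_horizon1 for "a"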
