marginal_pit metric #11

Merged 5 commits on Oct 22, 2024. Changes from all commits are shown below.
71 changes: 63 additions & 8 deletions src/postpredict/metrics.py
@@ -4,7 +4,7 @@


def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str],
index_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str],
reduce_mean: bool = True) -> float | pl.DataFrame:
"""
Compute the energy score for a collection of predictive samples.
@@ -18,9 +18,11 @@ def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
obs_data_wide: pl.DataFrame
DataFrame of observed values where each row corresponds to one
(multivariate) observed outcome for one observational unit.
key_cols: list[str]
index_cols: list[str]
Collaborator Author:

I'm trying to stick with a naming convention throughout the package where key_cols identify an observational unit in the sense of a group of people (location, age group), columns like reference_time_col and time_col refer to time points, and here index_cols includes both key_cols and reference_time_col, since this is what's needed to identify a predicted unit. (A usage sketch illustrating this convention follows this file's diff.)

Columns that appear in both `model_out_wide` and `obs_data_wide` that
identify observational units.
identify the unit of a multivariate prediction (e.g., including the
location, age_group, and reference_time of a prediction). These columns
will be included in the returned dataframe.
pred_cols: list[str]
Columns that appear in `model_out_wide` and identify predicted (sampled)
values. The order of these should match the order of `obs_cols`.
@@ -51,8 +53,6 @@ def energy_score_one_unit(df: pl.DataFrame):
Compute energy score for one observational unit based on a collection of
samples. Note, we define this function here so that key_cols, pred_cols
and obs_cols are in scope.

See
"""
if df[pred_cols + obs_cols].null_count().to_numpy().sum() > 0:
# Return np.nan rather than None here to avoid a rare schema
@@ -62,12 +62,12 @@ def energy_score_one_unit(df: pl.DataFrame):
score = np.mean(pairwise_distances(df[pred_cols], df[0, obs_cols])) \
- 0.5 * np.mean(pairwise_distances(df[pred_cols]))

return df[0, key_cols].with_columns(energy_score = pl.lit(score))
return df[0, index_cols].with_columns(energy_score = pl.lit(score))

scores_by_unit = (
model_out_wide
.join(obs_data_wide, on = key_cols)
.group_by(*key_cols)
.join(obs_data_wide, on = index_cols)
.group_by(*index_cols)
.map_groups(energy_score_one_unit)
)

@@ -76,3 +76,58 @@ def energy_score_one_unit(df: pl.DataFrame):

# replace NaN with None to average only across non-missing values
return scores_by_unit["energy_score"].fill_nan(None).mean()


def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
index_cols: list[str] | None, pred_cols: list[str],
obs_cols: list[str]) -> pl.DataFrame:
"""
Compute the probability integral transform (PIT) value for each of a
collection of marginal predictive distributions represented by a set of
samples.

Parameters
----------
model_out_wide: pl.DataFrame
DataFrame of model outputs where each row corresponds to one
(multivariate) sample from a multivariate distribution for one
observational unit.
obs_data_wide: pl.DataFrame
DataFrame of observed values where each row corresponds to one
(multivariate) observed outcome for one observational unit.
index_cols: list[str]
Columns that appear in both `model_out_wide` and `obs_data_wide` that
identify the unit of a multivariate prediction (e.g., including the
location, age_group, and reference_time of a prediction). These columns
will be included in the returned dataframe.
pred_cols: list[str]
Columns that appear in `model_out_wide` and identify predicted (sampled)
values. The order of these should match the order of `obs_cols`.
obs_cols: list[str]
Columns that appear in `obs_data_wide` and identify observed values. The
order of these should match the order of `pred_cols`.

Returns
-------
A pl.DataFrame with one row per observational unit and PIT values stored in
columns named according to `[f"pit_{c}" for c in pred_cols]`.

Notes
-----
Here, the PIT value is calculated as the proportion of samples that are less
than or equal to the observed value.
"""
scores_by_unit = (
model_out_wide
.join(obs_data_wide, on = index_cols)
.group_by(index_cols)
.agg(
[
(pl.col(pred_c) <= pl.col(obs_c)).mean().alias(f"pit_{pred_c}") \
for pred_c, obs_c in zip(pred_cols, obs_cols)
]
)
.select(index_cols + [f"pit_{pred_c}" for pred_c in pred_cols])
)

return scores_by_unit
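
To make the naming convention described in the review comment above and the new metric concrete, here is a minimal usage sketch. It is not part of the diff; the data and column names below are illustrative, and index_cols combines the key columns (location) with the reference time.

import numpy as np
import polars as pl
from postpredict.metrics import marginal_pit

rng = np.random.default_rng(42)
# 100 samples of a 2-horizon forecast for one (location, reference_time) unit
model_out_wide = pl.DataFrame({
    "location": ["a"] * 100,
    "reference_time": ["2024-10-01"] * 100,
    "horizon1": rng.normal(5.0, 1.0, 100),
    "horizon2": rng.normal(6.0, 1.0, 100),
})
# the corresponding observed values for that unit
obs_data_wide = pl.DataFrame({
    "location": ["a"],
    "reference_time": ["2024-10-01"],
    "value_lead1": [5.2],
    "value_lead2": [7.5],
})
pit_df = marginal_pit(model_out_wide=model_out_wide,
                      obs_data_wide=obs_data_wide,
                      index_cols=["location", "reference_time"],
                      pred_cols=["horizon1", "horizon2"],
                      obs_cols=["value_lead1", "value_lead2"])
# pit_df has one row for ("a", "2024-10-01") with columns pit_horizon1 and
# pit_horizon2, each equal to the fraction of the 100 samples that are <= the
# corresponding observed value.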
4 changes: 2 additions & 2 deletions tests/postpredict/metrics/test_energy_score.py
@@ -67,7 +67,7 @@ def test_energy_score():

actual_scores_df = energy_score(model_out_wide=model_out_wide,
obs_data_wide=obs_data_wide,
key_cols=["location", "date"],
index_cols=["location", "date"],
pred_cols=["horizon1", "horizon2", "horizon3"],
obs_cols=["value_lead1", "value_lead2", "value_lead3"],
reduce_mean=False)
@@ -77,7 +77,7 @@
expected_mean_score = np.mean([5.8560677725938221627, 5.9574451598773787708])
actual_mean_score = energy_score(model_out_wide=model_out_wide,
obs_data_wide=obs_data_wide,
key_cols=["location", "date"],
index_cols=["location", "date"],
pred_cols=["horizon1", "horizon2", "horizon3"],
obs_cols=["value_lead1", "value_lead2", "value_lead3"],
reduce_mean=True)
62 changes: 62 additions & 0 deletions tests/postpredict/metrics/test_marginal_pit.py
@@ -0,0 +1,62 @@
# Tests for postpredict.metrics.marginal_pit

from datetime import datetime

import numpy as np
import polars as pl
from polars.testing import assert_frame_equal
from postpredict.metrics import marginal_pit


def test_marginal_pit():
rng = np.random.default_rng(seed=123)
model_out_wide = pl.concat([
pl.DataFrame({
"location": "a",
"date": datetime.strptime("2024-10-01", "%Y-%m-%d"),
"output_type": "sample",
"output_type_id": np.linspace(0, 99, 100),
"horizon1": rng.permutation(np.linspace(0, 9, 100)),
"horizon2": rng.permutation(np.linspace(8, 17, 100)),
"horizon3": rng.permutation(np.linspace(5.1, 16.1, 100))
}),
pl.DataFrame({
"location": "b",
"date": datetime.strptime("2024-10-08", "%Y-%m-%d"),
"output_type": "sample",
"output_type_id": np.linspace(100, 199, 100),
"horizon1": rng.permutation(np.linspace(10.0, 19.0, 100)),
"horizon2": rng.permutation(np.linspace(-3.0, 6.0, 100)),
"horizon3": rng.permutation(np.linspace(10.99, 19.99, 100))
})
])
obs_data_wide = pl.DataFrame({
"location": ["a", "a", "b", "b"],
"date": [datetime.strptime("2024-10-01", "%Y-%m-%d"),
datetime.strptime("2024-10-08", "%Y-%m-%d"),
datetime.strptime("2024-10-01", "%Y-%m-%d"),
datetime.strptime("2024-10-08", "%Y-%m-%d")],
"value": [3.0, 4.0, 0.0, 7.2],
"value_lead1": [4.0, 10.0, 7.2, 9.6],
"value_lead2": [10.0, 5.0, 9.6, 10.0],
"value_lead3": [5.0, 2.0, 10.0, 14.1]
})

# expected PIT values: the number of samples less than or equal to
# corresponding observed values
expected_scores_df = pl.DataFrame({
"location": ["a", "b"],
"date": [datetime.strptime("2024-10-01", "%Y-%m-%d"),
datetime.strptime("2024-10-08", "%Y-%m-%d")],
"pit_horizon1": [0.45, 0.0],
"pit_horizon2": [0.23, 1.0],
"pit_horizon3": [0.0, 0.35]
})

actual_scores_df = marginal_pit(model_out_wide=model_out_wide,
obs_data_wide=obs_data_wide,
index_cols=["location", "date"],
pred_cols=["horizon1", "horizon2", "horizon3"],
obs_cols=["value_lead1", "value_lead2", "value_lead3"])

assert_frame_equal(actual_scores_df, expected_scores_df, check_row_order=False, atol=1e-19)
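
As a quick cross-check of the expected values (standalone arithmetic, not part of the test file): for location "a", the horizon1 samples are the 100 evenly spaced values np.linspace(0, 9, 100) and the observed value_lead1 is 4.0, so 45 of the 100 samples fall at or below the observation, giving a PIT value of 0.45.

import numpy as np

samples = np.linspace(0, 9, 100)      # horizon1 samples for location "a"
observed = 4.0                        # value_lead1 for ("a", 2024-10-01)
print(np.mean(samples <= observed))   # 0.45, matching pit_horizon1 above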