Skip to content

Commit

Permalink
Merge pull request #10 from reichlab/optimize
Browse files Browse the repository at this point in the history
basic setup for parameter optimization
  • Loading branch information
elray1 authored Oct 22, 2024
2 parents 76da491 + af7c916 commit 8492424
Show file tree
Hide file tree
Showing 10 changed files with 1,373 additions and 25 deletions.
1,120 changes: 1,120 additions & 0 deletions docs/fit_schaake.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dynamic = ["version"]

dependencies = [
"numpy",
"optuna",
"polars",
"scikit-learn",
"scipy"
Expand All @@ -20,7 +21,9 @@ dependencies = [
[project.optional-dependencies]
dev = [
"coverage",
"matplotlib",
"mypy",
"pandas",
"pre-commit",
"pytest",
"pytest-mock",
Expand Down
66 changes: 62 additions & 4 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml --extra dev -o requirements/requirements-dev.txt
alembic==1.13.3
# via optuna
cfgv==3.4.0
# via pre-commit
colorlog==6.8.2
# via optuna
contourpy==1.3.0
# via matplotlib
coverage==7.5.1
# via postpredict (pyproject.toml)
cycler==0.12.1
# via matplotlib
distlib==0.3.8
# via virtualenv
filelock==3.14.0
# via virtualenv
fonttools==4.54.1
# via matplotlib
greenlet==3.1.1
# via sqlalchemy
identify==2.5.36
# via pre-commit
iniconfig==2.0.0
# via pytest
joblib==1.4.2
# via scikit-learn
kiwisolver==1.4.7
# via matplotlib
mako==1.3.5
# via alembic
markupsafe==3.0.2
# via mako
matplotlib==3.9.2
# via postpredict (pyproject.toml)
mypy==1.10.0
# via postpredict (pyproject.toml)
mypy-extensions==1.0.0
Expand All @@ -23,10 +43,23 @@ nodeenv==1.8.0
numpy==2.1.2
# via
# postpredict (pyproject.toml)
# contourpy
# matplotlib
# optuna
# pandas
# scikit-learn
# scipy
optuna==4.0.0
# via postpredict (pyproject.toml)
packaging==24.0
# via pytest
# via
# matplotlib
# optuna
# pytest
pandas==2.2.3
# via postpredict (pyproject.toml)
pillow==11.0.0
# via matplotlib
platformdirs==4.2.1
# via virtualenv
pluggy==1.5.0
Expand All @@ -35,29 +68,54 @@ polars==1.9.0
# via postpredict (pyproject.toml)
pre-commit==3.7.0
# via postpredict (pyproject.toml)
pyparsing==3.2.0
# via matplotlib
pytest==8.2.0
# via
# postpredict (pyproject.toml)
# pytest-mock
pytest-mock==3.14.0
# via postpredict (pyproject.toml)
python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
pytz==2024.2
# via pandas
pyyaml==6.0.1
# via pre-commit
# via
# optuna
# pre-commit
ruff==0.4.3
# via postpredict (pyproject.toml)
scikit-learn==1.5.2
# via postpredict (pyproject.toml)
scipy==1.14.1
# via scikit-learn
# via
# postpredict (pyproject.toml)
# scikit-learn
setuptools==75.1.0
# via nodeenv
six==1.16.0
# via python-dateutil
sqlalchemy==2.0.36
# via
# alembic
# optuna
threadpoolctl==3.5.0
# via scikit-learn
toml==0.10.2
# via postpredict (pyproject.toml)
tqdm==4.66.5
# via optuna
types-toml==0.10.8.20240310
# via postpredict (pyproject.toml)
typing-extensions==4.11.0
# via mypy
# via
# alembic
# mypy
# sqlalchemy
tzdata==2024.2
# via pandas
virtualenv==20.26.1
# via pre-commit
31 changes: 30 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,46 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements/requirements.txt
alembic==1.13.3
# via optuna
colorlog==6.8.2
# via optuna
greenlet==3.1.1
# via sqlalchemy
joblib==1.4.2
# via scikit-learn
mako==1.3.5
# via alembic
markupsafe==3.0.2
# via mako
numpy==2.1.2
# via
# postpredict (pyproject.toml)
# optuna
# scikit-learn
# scipy
optuna==4.0.0
# via postpredict (pyproject.toml)
packaging==24.1
# via optuna
polars==1.9.0
# via postpredict (pyproject.toml)
pyyaml==6.0.2
# via optuna
scikit-learn==1.5.2
# via postpredict (pyproject.toml)
scipy==1.14.1
# via scikit-learn
# via
# postpredict (pyproject.toml)
# scikit-learn
sqlalchemy==2.0.36
# via
# alembic
# optuna
threadpoolctl==3.5.0
# via scikit-learn
tqdm==4.66.5
# via optuna
typing-extensions==4.12.2
# via
# alembic
# sqlalchemy
16 changes: 12 additions & 4 deletions src/postpredict/dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _build_templates(self, wide_model_out):
Returns
-------
templates: np.ndarray
templates: pl.DataFrame
Dependence templates of shape (wide_model_out.shape[0], self.train_Y.shape[1])
"""

Expand All @@ -45,7 +45,8 @@ def transform(self, model_out: pl.DataFrame,
reference_time_col: str = "reference_date",
horizon_col: str = "horizon", pred_col: str = "value",
idx_col: str = "output_type_id",
obs_mask: np.ndarray | None = None):
obs_mask: np.ndarray | None = None,
return_long_format: bool = True):
"""
Apply a postprocessing transformation to sample predictions to induce
dependence across time in the predictive samples.
Expand All @@ -71,6 +72,9 @@ def transform(self, model_out: pl.DataFrame,
array of shape (self.df.shape[0], ). Rows of self.df where obs_mask
is True will be used, while rows of self.df where obs_mask is False
will not be used.
return_long_format: bool
If True, return long format. If False, return wide format with
horizon pivoted into columns.
Returns
-------
Expand All @@ -93,6 +97,9 @@ def transform(self, model_out: pl.DataFrame,
.map_groups(self._transform_one_group)
)

if not return_long_format:
return transformed_wide_model_out

# unpivot back to long format
pivot_index = [c for c in model_out.columns if c not in [horizon_col, pred_col]]
transformed_model_out = (
Expand All @@ -111,12 +118,12 @@ def transform(self, model_out: pl.DataFrame,
.cast(model_out[horizon_col].dtype)
)
)

return transformed_model_out


def _transform_one_group(self, wide_model_out):
templates = self._build_templates(wide_model_out)
templates = self._build_templates(wide_model_out).to_numpy()
transformed_model_out = self._apply_shuffle(
wide_model_out = wide_model_out,
value_cols = self.wide_horizon_cols,
Expand Down Expand Up @@ -362,4 +369,5 @@ def _build_templates(self, wide_model_out):

# get the templates
templates = self.train_Y[selected_inds, :]

return templates
84 changes: 73 additions & 11 deletions src/postpredict/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
key_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str],
index_cols: list[str] | None, pred_cols: list[str], obs_cols: list[str],
reduce_mean: bool = True) -> float | pl.DataFrame:
"""
Compute the energy score for a collection of predictive samples.
Expand All @@ -18,9 +18,11 @@ def energy_score(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
obs_data_wide: pl.DataFrame
DataFrame of observed values where each row corresponds to one
(multivariate) observed outcome for one observational unit.
key_cols: list[str]
index_cols: list[str]
Columns that appear in both `model_out_wide` and `obs_data_wide` that
identify observational units.
identify the unit of a multivariate prediction (e.g., including the
location, age_group, and reference_time of a prediction). These columns
will be included in the returned dataframe.
pred_cols: list[str]
Columns that appear in `model_out_wide` and identify predicted (sampled)
values. The order of these should match the order of `obs_cols`.
Expand Down Expand Up @@ -51,21 +53,81 @@ def energy_score_one_unit(df: pl.DataFrame):
Compute energy score for one observational unit based on a collection of
samples. Note, we define this function here so that key_cols, pred_cols
and obs_cols are in scope.
See
"""
score = np.mean(pairwise_distances(df[pred_cols], df[0, obs_cols])) \
- 0.5 * np.mean(pairwise_distances(df[pred_cols]))
return df[0, key_cols].with_columns(energy_score = pl.lit(score))
if df[pred_cols + obs_cols].null_count().to_numpy().sum() > 0:
# Return np.nan rather than None here to avoid a rare schema
# error when the first processed group would yield None.
score = np.nan
else:
score = np.mean(pairwise_distances(df[pred_cols], df[0, obs_cols])) \
- 0.5 * np.mean(pairwise_distances(df[pred_cols]))

return df[0, index_cols].with_columns(energy_score = pl.lit(score))

scores_by_unit = (
model_out_wide
.join(obs_data_wide, on = key_cols)
.group_by(*key_cols)
.join(obs_data_wide, on = index_cols)
.group_by(*index_cols)
.map_groups(energy_score_one_unit)
)

if not reduce_mean:
return scores_by_unit

return scores_by_unit["energy_score"].mean()
# replace NaN with None to average only across non-missing values
return scores_by_unit["energy_score"].fill_nan(None).mean()


def marginal_pit(model_out_wide: pl.DataFrame, obs_data_wide: pl.DataFrame,
index_cols: list[str] | None, pred_cols: list[str],
obs_cols: list[str]) -> pl.DataFrame:
"""
Compute the probability integral transform (PIT) value for each of a
collection of marginal predictive distributions represented by a set of
samples.
Parameters
----------
model_out_wide: pl.DataFrame
DataFrame of model outputs where each row corresponds to one
(multivariate) sample from a multivariate distribution for one
observational unit.
obs_data_wide: pl.DataFrame
DataFrame of observed values where each row corresponds to one
(multivariate) observed outcome for one observational unit.
index_cols: list[str]
Columns that appear in both `model_out_wide` and `obs_data_wide` that
identify the unit of a multivariate prediction (e.g., including the
location, age_group, and reference_time of a prediction). These columns
will be included in the returned dataframe.
pred_cols: list[str]
Columns that appear in `model_out_wide` and identify predicted (sampled)
values. The order of these should match the order of `obs_cols`.
obs_cols: list[str]
Columns that appear in `obs_data_wide` and identify observed values. The
order of these should match the order of `pred_cols`.
Returns
-------
A pl.DataFrame with one row per observational unit and PIT values stored in
columns named according to `[f"pit_{c}" for c in pred_cols]`.
Notes
-----
Here, the PIT value is calculated as the proportion of samples that are less
than or equal to the observed value.
"""
scores_by_unit = (
model_out_wide
.join(obs_data_wide, on = index_cols)
.group_by(index_cols)
.agg(
[
(pl.col(pred_c) <= pl.col(obs_c)).mean().alias(f"pit_{pred_c}") \
for pred_c, obs_c in zip(pred_cols, obs_cols)
]
)
.select(index_cols + [f"pit_{pred_c}" for pred_c in pred_cols])
)

return scores_by_unit
10 changes: 8 additions & 2 deletions src/postpredict/weighters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections

import numpy as np
import polars as pl


class Parameter(collections.UserDict):
Expand Down Expand Up @@ -65,6 +66,11 @@ def get_weights(self, train_X, test_X):
numpy array of shape (n_test, n_train) with weights for each training set
instance, where weights sum to 1 within each row.
"""
n_train = train_X.shape[0]
prop_weights = np.exp(-0.5 / self.parameters["h"].value * (test_X - train_X.reshape(1, n_train))**2)
if isinstance(train_X, pl.DataFrame):
train_X = train_X.to_numpy()

if isinstance(test_X, pl.DataFrame):
test_X = test_X.to_numpy()

prop_weights = np.exp(-0.5 / self.parameters["h"].value * (test_X - train_X.transpose())**2)
return prop_weights / np.sum(prop_weights, axis = 1, keepdims = True)
2 changes: 1 addition & 1 deletion tests/postpredict/dependence/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def fit(self, df, key_cols=None, time_col="date", obs_col="value", feat_cols=["d


def _build_templates(self, wide_model_out):
return templates
return pl.DataFrame(templates)

tdp = TestPostprocessor(rng = np.random.default_rng(42))
tdp.df = obs_data
Expand Down
Loading

0 comments on commit 8492424

Please sign in to comment.