Skip to content

Commit

Permalink
Merge pull request #9 from reichlab/observation_masking
Browse files Browse the repository at this point in the history
allow for observation masking
  • Loading branch information
elray1 authored Oct 22, 2024
2 parents 4ea7793 + 8492424 commit 3ee78d5
Show file tree
Hide file tree
Showing 11 changed files with 1,432 additions and 30 deletions.
1,120 changes: 1,120 additions & 0 deletions docs/fit_schaake.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dynamic = ["version"]

dependencies = [
"numpy",
"optuna",
"polars",
"scikit-learn",
"scipy"
Expand All @@ -20,7 +21,9 @@ dependencies = [
[project.optional-dependencies]
dev = [
"coverage",
"matplotlib",
"mypy",
"pandas",
"pre-commit",
"pytest",
"pytest-mock",
Expand Down
66 changes: 62 additions & 4 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml --extra dev -o requirements/requirements-dev.txt
alembic==1.13.3
# via optuna
cfgv==3.4.0
# via pre-commit
colorlog==6.8.2
# via optuna
contourpy==1.3.0
# via matplotlib
coverage==7.5.1
# via postpredict (pyproject.toml)
cycler==0.12.1
# via matplotlib
distlib==0.3.8
# via virtualenv
filelock==3.14.0
# via virtualenv
fonttools==4.54.1
# via matplotlib
greenlet==3.1.1
# via sqlalchemy
identify==2.5.36
# via pre-commit
iniconfig==2.0.0
# via pytest
joblib==1.4.2
# via scikit-learn
kiwisolver==1.4.7
# via matplotlib
mako==1.3.5
# via alembic
markupsafe==3.0.2
# via mako
matplotlib==3.9.2
# via postpredict (pyproject.toml)
mypy==1.10.0
# via postpredict (pyproject.toml)
mypy-extensions==1.0.0
Expand All @@ -23,10 +43,23 @@ nodeenv==1.8.0
numpy==2.1.2
# via
# postpredict (pyproject.toml)
# contourpy
# matplotlib
# optuna
# pandas
# scikit-learn
# scipy
optuna==4.0.0
# via postpredict (pyproject.toml)
packaging==24.0
# via pytest
# via
# matplotlib
# optuna
# pytest
pandas==2.2.3
# via postpredict (pyproject.toml)
pillow==11.0.0
# via matplotlib
platformdirs==4.2.1
# via virtualenv
pluggy==1.5.0
Expand All @@ -35,29 +68,54 @@ polars==1.9.0
# via postpredict (pyproject.toml)
pre-commit==3.7.0
# via postpredict (pyproject.toml)
pyparsing==3.2.0
# via matplotlib
pytest==8.2.0
# via
# postpredict (pyproject.toml)
# pytest-mock
pytest-mock==3.14.0
# via postpredict (pyproject.toml)
python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
pytz==2024.2
# via pandas
pyyaml==6.0.1
# via pre-commit
# via
# optuna
# pre-commit
ruff==0.4.3
# via postpredict (pyproject.toml)
scikit-learn==1.5.2
# via postpredict (pyproject.toml)
scipy==1.14.1
# via scikit-learn
# via
# postpredict (pyproject.toml)
# scikit-learn
setuptools==75.1.0
# via nodeenv
six==1.16.0
# via python-dateutil
sqlalchemy==2.0.36
# via
# alembic
# optuna
threadpoolctl==3.5.0
# via scikit-learn
toml==0.10.2
# via postpredict (pyproject.toml)
tqdm==4.66.5
# via optuna
types-toml==0.10.8.20240310
# via postpredict (pyproject.toml)
typing-extensions==4.11.0
# via mypy
# via
# alembic
# mypy
# sqlalchemy
tzdata==2024.2
# via pandas
virtualenv==20.26.1
# via pre-commit
31 changes: 30 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,46 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements/requirements.txt
alembic==1.13.3
# via optuna
colorlog==6.8.2
# via optuna
greenlet==3.1.1
# via sqlalchemy
joblib==1.4.2
# via scikit-learn
mako==1.3.5
# via alembic
markupsafe==3.0.2
# via mako
numpy==2.1.2
# via
# postpredict (pyproject.toml)
# optuna
# scikit-learn
# scipy
optuna==4.0.0
# via postpredict (pyproject.toml)
packaging==24.1
# via optuna
polars==1.9.0
# via postpredict (pyproject.toml)
pyyaml==6.0.2
# via optuna
scikit-learn==1.5.2
# via postpredict (pyproject.toml)
scipy==1.14.1
# via scikit-learn
# via
# postpredict (pyproject.toml)
# scikit-learn
sqlalchemy==2.0.36
# via
# alembic
# optuna
threadpoolctl==3.5.0
# via scikit-learn
tqdm==4.66.5
# via optuna
typing-extensions==4.12.2
# via
# alembic
# sqlalchemy
44 changes: 35 additions & 9 deletions src/postpredict/dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@ def _build_templates(self, wide_model_out):
Returns
-------
templates: np.ndarray
templates: pl.DataFrame
Dependence templates of shape (wide_model_out.shape[0], self.train_Y.shape[1])
"""


def transform(self, model_out: pl.DataFrame,
reference_time_col: str = "reference_date",
horizon_col: str = "horizon", pred_col: str = "value",
idx_col: str = "output_type_id"):
idx_col: str = "output_type_id",
obs_mask: np.ndarray | None = None,
return_long_format: bool = True):
"""
Apply a postprocessing transformation to sample predictions to induce
dependence across time in the predictive samples.
Expand All @@ -63,6 +65,16 @@ def transform(self, model_out: pl.DataFrame,
name of column in model_out with predicted values (samples)
idx_col: str
name of column in model_out with sample indices
obs_mask: np.ndarray | None
mask to use for observed data. The primary use case is to support
cross-validation. If None, all observed data are used to form
dependence templates. Otherwise, `obs_mask` should be a boolean
array of shape (self.df.shape[0], ). Rows of self.df where obs_mask
is True will be used, while rows of self.df where obs_mask is False
will not be used.
return_long_format: bool
If True, return long format. If False, return wide format with
horizon pivoted into columns.
Returns
-------
Expand All @@ -76,7 +88,7 @@ def transform(self, model_out: pl.DataFrame,
max_horizon = model_out[horizon_col].max()

# extract train_X and train_Y from observed data (self.df)
self._build_train_X_Y(min_horizon, max_horizon)
self._build_train_X_Y(min_horizon, max_horizon, obs_mask)

# perform the transformation, one group at a time
transformed_wide_model_out = (
Expand All @@ -85,6 +97,9 @@ def transform(self, model_out: pl.DataFrame,
.map_groups(self._transform_one_group)
)

if not return_long_format:
return transformed_wide_model_out

# unpivot back to long format
pivot_index = [c for c in model_out.columns if c not in [horizon_col, pred_col]]
transformed_model_out = (
Expand All @@ -103,12 +118,12 @@ def transform(self, model_out: pl.DataFrame,
.cast(model_out[horizon_col].dtype)
)
)

return transformed_model_out


def _transform_one_group(self, wide_model_out):
templates = self._build_templates(wide_model_out)
templates = self._build_templates(wide_model_out).to_numpy()
transformed_model_out = self._apply_shuffle(
wide_model_out = wide_model_out,
value_cols = self.wide_horizon_cols,
Expand Down Expand Up @@ -164,7 +179,8 @@ def _apply_shuffle(self,
return shuffled_wmo


def _build_train_X_Y(self, min_horizon, max_horizon):
def _build_train_X_Y(self, min_horizon, max_horizon,
obs_mask: np.ndarray | None = None):
"""
Build training set data frames self.train_X with features and
self.train_Y with observed values in windows from min_horizon to
Expand All @@ -176,6 +192,13 @@ def _build_train_X_Y(self, min_horizon, max_horizon):
minimum prediction horizon
max_horizon: int
maximum prediction horizon
obs_mask: np.ndarray | None
mask to use for observed data. The primary use case is to support
cross-validation. If None, all observed data are used to form
dependence templates. Otherwise, `obs_mask` should be a boolean
array of shape (self.df.shape[0], ). Rows of self.df where obs_mask
is True will be used, while rows of self.df where obs_mask is False
will not be used.
Returns
-------
Expand Down Expand Up @@ -205,9 +228,11 @@ def _build_train_X_Y(self, min_horizon, max_horizon):
.alias(shift_varname)
)

df_dropnull = self.df.drop_nulls()
self.train_X = df_dropnull[self.feat_cols]
self.train_Y = df_dropnull[self.shift_varnames]
if obs_mask is None:
obs_mask = True
df_mask_and_dropnull = self.df.filter(obs_mask).drop_nulls()
self.train_X = df_mask_and_dropnull[self.feat_cols]
self.train_Y = df_mask_and_dropnull[self.shift_varnames]


def _pivot_horizon(self, model_out, reference_time_col, horizon_col,
Expand Down Expand Up @@ -344,4 +369,5 @@ def _build_templates(self, wide_model_out):

# get the templates
templates = self.train_Y[selected_inds, :]

return templates
Loading

0 comments on commit 3ee78d5

Please sign in to comment.