Skip to content

Commit

Permalink
Demonstration Primarily For NNH At PyRenew Meeting (#40)
Browse files Browse the repository at this point in the history
* start demo notebook

* upload community demo for 2024-11-19

* the code works with time updates
  • Loading branch information
AFg6K7h4fhy2 authored Nov 25, 2024
1 parent 1b4f413 commit 4a89c9f
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 1 deletion.
12 changes: 11 additions & 1 deletion forecasttools/idata_w_dates_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
scoringutils ready dataframes
"""

import itertools
from datetime import date, datetime, timedelta

import arviz as az
Expand Down Expand Up @@ -242,8 +243,17 @@ def add_time_coords_to_idata_dimensions(
forecasttools.validate_iter_has_expected_types(
dimensions, str, "dimensions"
)
# create tuples, the groups should have
# every combination of variables and
# dimensions
var_dim_combinations = list(itertools.product(variables, dimensions))
gvd_tuples = [
(group, var, dim)
for group in groups
for var, dim in var_dim_combinations
]
# iterate over (group, variable, dimension) triples
for group, variable, dimension in zip(groups, variables, dimensions):
for group, variable, dimension in gvd_tuples:
try:
idata = add_time_coords_to_idata_dimension(
idata=idata,
Expand Down
273 changes: 273 additions & 0 deletions notebooks/forecasttools_community_demo_2024-11-19.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
---
title: "Community Meeting Utilities Demonstration"
date: "2024-11-19"
format: "html"
engine: "jupyter"
execute:
freeze: true
warnings: false
---

This discussion covers:

* Brief background information on use of `numpyro` and `arviz`.
* Brief introduction to `forecasttools-py` (see [here](https://github.com/CDCgov/forecasttools-py))
* Walkthrough for making ready influenza forecasts for FluSight hub.
* Walkthrough for connecting `idata` objects to time.

# Tools In The STF Workflow

[NumPyro](https://num.pyro.ai/en/stable/getting_started.html)

* Lightweight PPL built on [JAX](https://jax.readthedocs.io/en/latest/quickstart.html)
* JAX has [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html)
* Other benefit: GPU/TPU acceleration
* Other benefit: [just-in-time-compilation](https://jax.readthedocs.io/en/latest/control-flow.html)
* `sample()` calls used for
* (latent variables, observations, or intermediates)
* Efficient sampling (packaged [MCMC](https://num.pyro.ai/en/stable/mcmc.html) & NUTS)
* Forecast output from `numpyro.infer.Predictive()`

[Arviz](https://www.arviz.org/en/latest/)

* Compatible with NumPyro models and output.
* [InferenceData objects](https://python.arviz.org/en/stable/api/generated/arviz.InferenceData.html)
* Diagnostics: ESS, R-hat, etc...
* Visualization: trace plots, PPC, pair plots
* Storage: standardized for inference results

At present, the following are built using NumPyro

* [PyRenew](https://github.com/CDCgov/PyRenew)
* [pyrenew-hew](https://github.com/CDCgov/pyrenew-hew) (makes use of Arviz)

Expectedly, more STF projects will use these tools.


# Briefly, On `forecasttools-py`

* 2023-24 influenza forecasting produced need for pre- and post-processing tools
* [`forecasttools`](https://github.com/CDCgov/forecasttools) was on `cdcent`, served
* `cfa-forecast-renewal-epidemia` (see [here](https://github.com/cdcent/cfa-forecast-renewal-epidemia))
* `cfa-flu-eval` (see [here](https://github.com/cdcent/cfa-flu-eval))
* `CFA-Mechanistic` (see [here](https://github.com/cdcent/CFA-Mechanistic))
* Tools needed for newer, Pythonic STF workflows
* Some forecasting processes by `forecasttools`
* Others supported by `forecasttools-py`
* [ScoringUtils](https://github.com/epiforecasts/scoringutils) necessitates R usage
* Right now `forecasttools` handles this better

# Making A FluSight (Hubverse) Submission

_A vignette covering this functionality was featured in the original `forecasttools`._

Columns of a hubverse submission:

* `reference_date`
* `target`
* `horizon`
* `target_end_date`
* `location`
* `output_type`
* `output_type_id`
* `value`

__Task__: Go from the output of a NumPyro model to this format.

[Example submission](https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/main/model-output/cfa-flumech/2023-10-14-cfa-flumech.csv):

```{python}
import polars as pl
import forecasttools
import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
from datetime import timedelta, datetime, date
# load example FluSight submission
submission = forecasttools.example_flusight_submission
# display structure of submission
submission
```

Example forecast in `forecasttools-py`:

* Description in the [README](https://github.com/CDCgov/forecasttools-py)
* Made using [spline regression model](https://github.com/cdcent/upx3-sandbox/blob/main/guides/unsorted/create_idata_flu_forecast_w_dates.py)
* Made for TX influenza hospitalizations 2022-23 season

![](../assets/example_forecast_w_dates.png)


__What Does An InferenceData Object (Idata) Look Like?__

```{python}
xr.set_options(display_expand_data=False, display_expand_attrs=False)
# load example forecast(s)
idata = forecasttools.nhsn_flu_forecast_w_dates
idata
```

__Examination Of Idata Observed Data Group__

```{python}
# examine the observed data (dimensions, coordinates, data variables)
idata.observed_data
```

__Take Idata With Dates To Polars Dataframe__

```{python}
forecast_df = forecasttools.idata_forecast_w_dates_to_df(
idata_w_dates=idata,
location="TX",
postp_val_name="obs",
postp_dim_name="obs_dim_0",
timepoint_col_name="date",
value_col_name="hosp",
)
forecast_df
```

__Aggregate Count By Which Epiweek They Are Members__

```{python}
forecast_df = forecasttools.df_aggregate_to_epiweekly(
forecast_df=forecast_df,
value_col = "hosp",
date_col = "date",
id_cols = ["draw", "location"],
weekly_value_name = "weekly_hosp"
)
forecast_df
```

__Quantilize The Existing Forecast Dataframe__

```{python}
forecast_df = forecasttools.trajectories_to_quantiles(
forecast_df,
timepoint_cols = ["epiweek", "epiyear"],
id_cols = ["location"],
value_col_name = "weekly_hosp"
)
forecast_df
```

__Modify The Location Column From Abbreviation To Codes__

```{python}
forecast_df_recoded = forecasttools.loc_abbr_to_hubverse_code(
df=forecast_df, location_col="location")
forecast_df_recoded
```

__Finalize The Dataframe With The Remaining Columns__

```{python}
flusight_output = forecasttools.get_hubverse_table(
quantile_forecasts=forecast_df_recoded,
quantile_value_col="quantile_value",
quantile_level_col="quantile_level",
reference_date="2022-12-15",
location_col="location",
epiweek_col="epiweek",
epiyear_col="epiyear",
)
flusight_output
```

# Representations Of Time In Idata

Each `az.InferenceData` has

* Groups (e.g. `posterior_predictive`, `observed_data`)
* Variables (e.g. `y_hat`, `obs`)
* Dimensions (e.g. `chain`, `draw`, `obs_dim_0`)

The non-chain, non-draw dimensions are usually indices.

We might want to `idata`s contain dates or times.

This is useful for

* Hubverse submissions
* Plotting from `idata` groups
* Different intervaled quantities (weekly, daily, etc...)


__Default Idata View Without Dates Or Times As Coordinates__

```{python}
idata_wo_dates = forecasttools.nhsn_flu_forecast_wo_dates
idata_wo_dates
idata_wo_dates["posterior_predictive"]["obs"]["obs_dim_0"]
```

__Intermediary Functions For Adding Dates Or Time__

```{python}
start_time_as_dt = date(2022, 8, 1)
variable_data = idata_wo_dates["posterior_predictive"]["obs"]
as_dates_idata = forecasttools.generate_time_range_for_dim(
start_time_as_dt=start_time_as_dt,
variable_data=variable_data,
dimension="obs_dim_0",
time_step=timedelta(days=1),
)
print(as_dates_idata[:5], type(as_dates_idata[0]))
as_time_idata = forecasttools.generate_time_range_for_dim(
start_time_as_dt=start_time_as_dt,
variable_data=variable_data,
dimension="obs_dim_0",
time_step=timedelta(days=1.5),
)
print(as_time_idata[:5], type(as_time_idata[0]))
```

__Adding Dates To Single Versus Multiple Groups__

```{python}
idata_w_dates_single = forecasttools.add_time_coords_to_idata_dimension(
idata=idata_wo_dates,
group="posterior_predictive",
variable="obs",
dimension="obs_dim_0",
start_date_iso=start_time_as_dt,
time_step=timedelta(weeks=1), # notice weeks
)
idata_w_dates_single.posterior_predictive
```

```{python}
idata_w_dates_single["posterior_predictive"]["obs"]["obs_dim_0"].values[-15:]
```

```{python}
idata_w_dates_multiple = forecasttools.add_time_coords_to_idata_dimensions(
idata=idata_wo_dates,
groups=["posterior_predictive", "observed_data"],
variables="obs",
dimensions="obs_dim_0",
start_date_iso=start_time_as_dt,
time_step=timedelta(days=1),
)
idata_w_dates_multiple.observed_data
```

```{python}
idata_w_dates_multiple["observed_data"]["obs"]["obs_dim_0"].values[-15:]
```

```{python}
idata_w_dates_multiple["posterior_predictive"]["obs"]["obs_dim_0"].values[-15:]
```

0 comments on commit 4a89c9f

Please sign in to comment.