generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Demonstration Primarily For NNH At PyRenew Meeting (#40)
* start demo notebook * upload community demo for 2024-11-19 * the code works with time updates
- Loading branch information
1 parent
1b4f413
commit 4a89c9f
Showing
2 changed files
with
284 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
--- | ||
title: "Community Meeting Utilities Demonstration" | ||
date: "2024-11-19" | ||
format: "html" | ||
engine: "jupyter" | ||
execute: | ||
freeze: true | ||
warnings: false | ||
--- | ||
|
||
This discussion covers: | ||
|
||
* Brief background information on use of `numpyro` and `arviz`. | ||
* Brief introduction to `forecasttools-py` (see [here](https://github.com/CDCgov/forecasttools-py)) | ||
* Walkthrough for making ready influenza forecasts for FluSight hub. | ||
* Walkthrough for connecting `idata` objects to time. | ||
|
||
# Tools In The STF Workflow | ||
|
||
[NumPyro](https://num.pyro.ai/en/stable/getting_started.html) | ||
|
||
* Lightweight PPL built on [JAX](https://jax.readthedocs.io/en/latest/quickstart.html) | ||
* JAX has [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) | ||
* Other benefit: GPU/TPU acceleration | ||
* Other benefit: [just-in-time-compilation](https://jax.readthedocs.io/en/latest/control-flow.html) | ||
* `sample()` calls used for | ||
* (latent variables, observations, or intermediates) | ||
* Efficient sampling (packaged [MCMC](https://num.pyro.ai/en/stable/mcmc.html) & NUTS) | ||
* Forecast output from `numpyro.infer.Predictive()` | ||
|
||
[Arviz](https://www.arviz.org/en/latest/) | ||
|
||
* Compatible with NumPyro models and output. | ||
* [InferenceData objects](https://python.arviz.org/en/stable/api/generated/arviz.InferenceData.html) | ||
* Diagnostics: ESS, R-hat, etc... | ||
* Visualization: trace plots, PPC, pair plots | ||
* Storage: standardized for inference results | ||
|
||
At present, the following are built using NumPyro | ||
|
||
* [PyRenew](https://github.com/CDCgov/PyRenew) | ||
* [pyrenew-hew](https://github.com/CDCgov/pyrenew-hew) (makes use of Arviz) | ||
|
||
Expectedly, more STF projects will use these tools. | ||
|
||
|
||
# Briefly, On `forecasttools-py` | ||
|
||
* 2023-24 influenza forecasting produced need for pre- and post-processing tools | ||
* [`forecasttools`](https://github.com/CDCgov/forecasttools) was on `cdcent`, served | ||
* `cfa-forecast-renewal-epidemia` (see [here](https://github.com/cdcent/cfa-forecast-renewal-epidemia)) | ||
* `cfa-flu-eval` (see [here](https://github.com/cdcent/cfa-flu-eval)) | ||
* `CFA-Mechanistic` (see [here](https://github.com/cdcent/CFA-Mechanistic)) | ||
* Tools needed for newer, Pythonic STF workflows | ||
* Some forecasting processes by `forecasttools` | ||
* Others supported by `forecasttools-py` | ||
* [ScoringUtils](https://github.com/epiforecasts/scoringutils) necessitates R usage | ||
* Right now `forecasttools` handles this better | ||
|
||
# Making A FluSight (Hubverse) Submission | ||
|
||
_A vignette covering this functionality was featured in the original `forecasttools`._ | ||
|
||
Columns of a hubverse submission: | ||
|
||
* `reference_date` | ||
* `target` | ||
* `horizon` | ||
* `target_end_date` | ||
* `location` | ||
* `output_type` | ||
* `output_type_id` | ||
* `value` | ||
|
||
__Task__: Go from the output of a NumPyro model to this format. | ||
|
||
[Example submission](https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/main/model-output/cfa-flumech/2023-10-14-cfa-flumech.csv): | ||
|
||
```{python} | ||
import polars as pl | ||
import forecasttools | ||
import matplotlib.pyplot as plt | ||
import xarray as xr | ||
import numpy as np | ||
from datetime import timedelta, datetime, date | ||
# load example FluSight submission | ||
submission = forecasttools.example_flusight_submission | ||
# display structure of submission | ||
submission | ||
``` | ||
|
||
Example forecast in `forecasttools-py`: | ||
|
||
* Description in the [README](https://github.com/CDCgov/forecasttools-py) | ||
* Made using [spline regression model](https://github.com/cdcent/upx3-sandbox/blob/main/guides/unsorted/create_idata_flu_forecast_w_dates.py) | ||
* Made for TX influenza hospitalizations 2022-23 season | ||
|
||
 | ||
|
||
|
||
__What Does An InferenceData Object (Idata) Look Like?__ | ||
|
||
```{python} | ||
xr.set_options(display_expand_data=False, display_expand_attrs=False) | ||
# load example forecast(s) | ||
idata = forecasttools.nhsn_flu_forecast_w_dates | ||
idata | ||
``` | ||
|
||
__Examination Of Idata Observed Data Group__ | ||
|
||
```{python} | ||
# examine the observed data (dimensions, coordinates, data variables) | ||
idata.observed_data | ||
``` | ||
|
||
__Take Idata With Dates To Polars Dataframe__ | ||
|
||
```{python} | ||
forecast_df = forecasttools.idata_forecast_w_dates_to_df( | ||
idata_w_dates=idata, | ||
location="TX", | ||
postp_val_name="obs", | ||
postp_dim_name="obs_dim_0", | ||
timepoint_col_name="date", | ||
value_col_name="hosp", | ||
) | ||
forecast_df | ||
``` | ||
|
||
__Aggregate Count By Which Epiweek They Are Members__ | ||
|
||
```{python} | ||
forecast_df = forecasttools.df_aggregate_to_epiweekly( | ||
forecast_df=forecast_df, | ||
value_col = "hosp", | ||
date_col = "date", | ||
id_cols = ["draw", "location"], | ||
weekly_value_name = "weekly_hosp" | ||
) | ||
forecast_df | ||
``` | ||
|
||
__Quantilize The Existing Forecast Dataframe__ | ||
|
||
```{python} | ||
forecast_df = forecasttools.trajectories_to_quantiles( | ||
forecast_df, | ||
timepoint_cols = ["epiweek", "epiyear"], | ||
id_cols = ["location"], | ||
value_col_name = "weekly_hosp" | ||
) | ||
forecast_df | ||
``` | ||
|
||
__Modify The Location Column From Abbreviation To Codes__ | ||
|
||
```{python} | ||
forecast_df_recoded = forecasttools.loc_abbr_to_hubverse_code( | ||
df=forecast_df, location_col="location") | ||
forecast_df_recoded | ||
``` | ||
|
||
__Finalize The Dataframe With The Remaining Columns__ | ||
|
||
```{python} | ||
flusight_output = forecasttools.get_hubverse_table( | ||
quantile_forecasts=forecast_df_recoded, | ||
quantile_value_col="quantile_value", | ||
quantile_level_col="quantile_level", | ||
reference_date="2022-12-15", | ||
location_col="location", | ||
epiweek_col="epiweek", | ||
epiyear_col="epiyear", | ||
) | ||
flusight_output | ||
``` | ||
|
||
# Representations Of Time In Idata | ||
|
||
Each `az.InferenceData` has | ||
|
||
* Groups (e.g. `posterior_predictive`, `observed_data`) | ||
* Variables (e.g. `y_hat`, `obs`) | ||
* Dimensions (e.g. `chain`, `draw`, `obs_dim_0`) | ||
|
||
The non-chain, non-draw dimensions are usually indices. | ||
|
||
We might want to `idata`s contain dates or times. | ||
|
||
This is useful for | ||
|
||
* Hubverse submissions | ||
* Plotting from `idata` groups | ||
* Different intervaled quantities (weekly, daily, etc...) | ||
|
||
|
||
__Default Idata View Without Dates Or Times As Coordinates__ | ||
|
||
```{python} | ||
idata_wo_dates = forecasttools.nhsn_flu_forecast_wo_dates | ||
idata_wo_dates | ||
idata_wo_dates["posterior_predictive"]["obs"]["obs_dim_0"] | ||
``` | ||
|
||
__Intermediary Functions For Adding Dates Or Time__ | ||
|
||
```{python} | ||
start_time_as_dt = date(2022, 8, 1) | ||
variable_data = idata_wo_dates["posterior_predictive"]["obs"] | ||
as_dates_idata = forecasttools.generate_time_range_for_dim( | ||
start_time_as_dt=start_time_as_dt, | ||
variable_data=variable_data, | ||
dimension="obs_dim_0", | ||
time_step=timedelta(days=1), | ||
) | ||
print(as_dates_idata[:5], type(as_dates_idata[0])) | ||
as_time_idata = forecasttools.generate_time_range_for_dim( | ||
start_time_as_dt=start_time_as_dt, | ||
variable_data=variable_data, | ||
dimension="obs_dim_0", | ||
time_step=timedelta(days=1.5), | ||
) | ||
print(as_time_idata[:5], type(as_time_idata[0])) | ||
``` | ||
|
||
__Adding Dates To Single Versus Multiple Groups__ | ||
|
||
```{python} | ||
idata_w_dates_single = forecasttools.add_time_coords_to_idata_dimension( | ||
idata=idata_wo_dates, | ||
group="posterior_predictive", | ||
variable="obs", | ||
dimension="obs_dim_0", | ||
start_date_iso=start_time_as_dt, | ||
time_step=timedelta(weeks=1), # notice weeks | ||
) | ||
idata_w_dates_single.posterior_predictive | ||
``` | ||
|
||
```{python} | ||
idata_w_dates_single["posterior_predictive"]["obs"]["obs_dim_0"].values[-15:] | ||
``` | ||
|
||
```{python} | ||
idata_w_dates_multiple = forecasttools.add_time_coords_to_idata_dimensions( | ||
idata=idata_wo_dates, | ||
groups=["posterior_predictive", "observed_data"], | ||
variables="obs", | ||
dimensions="obs_dim_0", | ||
start_date_iso=start_time_as_dt, | ||
time_step=timedelta(days=1), | ||
) | ||
idata_w_dates_multiple.observed_data | ||
``` | ||
|
||
```{python} | ||
idata_w_dates_multiple["observed_data"]["obs"]["obs_dim_0"].values[-15:] | ||
``` | ||
|
||
```{python} | ||
idata_w_dates_multiple["posterior_predictive"]["obs"]["obs_dim_0"].values[-15:] | ||
``` |