Skip to content

Commit

Permalink
Attempt Use Of Dates With idata Objects (#30)
Browse files Browse the repository at this point in the history
* update pyproject installs; add helper files for experimentation; change name of idata object to include wo dates

* actually change nc flu forecast file name to match

* update idata dates experimentation file

* move away in experimentation from manual date array creation

* add DHM suggestion for deep polars use in idata wo dates to df w dates; add another option w/ pl duration to experimentation file

* add some minor comments from DHM and DB discussion

* create dates idata object; update init; update README; split experimentation files by dates addition method

* bundle additional idata; minor README edit

* minor, initial docstring edits; first-pass multi-group support add dates to idata

* test new line length pre-commit; first pass add dates to idata groups; rename idata modifications and idata to tidy formatting

* further modifications to idata_forecast_w_dates_to_df

* fix pre-commit

* change of line length was to examine if ruff could get lines looking more similar to DHMs compact code, now reverting change

* fix pre-commit

* recreate dated flu forecast with iso str not np datetime; update flu forecast creation process; verify test out idata works

* some changes to ensure flusight vignette works properly

* very minor flusight vignette correction

* change width of image in README

* update arviz

* passed converted --> converted

* tidy like output

* what the output of the added tidy experiment file looks like

* now removing output parquet

* minor edit to tidy draws experimental file

* add eager parameter to date range

* add eager parameter to date range

* rename experimentation file for clarity

* not actually tidy; more specific naming

* remove the notebooks that are less directly relevant to this PRs scope

* remove reference to tidy dataframes
  • Loading branch information
AFg6K7h4fhy2 authored Oct 25, 2024
1 parent 013f4d4 commit 558c640
Show file tree
Hide file tree
Showing 14 changed files with 248 additions and 150 deletions.
33 changes: 20 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ make_nshn_fitting_dataset(
)
```

## Influenza Hospitalizations Forecast
## Influenza Hospitalizations Forecast(s)

An example forecast stored in an Arviz `InferenceData` object is included for vignettes and user experimentation. This 28 day forecast for Texas was made using a spline regression model fitted to NHSN influenza data between 2022-08-08 and 2022-12-08. The `idata` object which includes the observed data and posterior predictive samples is given below:
Two example forecasts stored in Arviz `InferenceData` objects are included for vignettes and user experimentation. Both are 28 day influenza hospital admissions forecasts for Texas made using a spline regression model fitted to NHSN data between 2022-08-08 and 2022-12-08. The only difference between the forecasts is that `example_flu_forecast_w_dates.nc` has dates as its coordinates. The `idata` objects which includes the observed data and posterior predictive samples is given below:

```
Inference data with groups:
Expand All @@ -184,31 +184,38 @@ Inference data with groups:
> observed_data
```

The forecast `idata` is accessed via:
The forecast `idata`s are accessed via:

```python
import forecasttools


idata = forecasttools.nhsn_flu_forecast
# idata with dates as coordinates
idata_w_dates = forecasttools.nhsn_flu_forecast_w_dates

# idata without dates as coordinates
idata_wo_dates = forecasttools.nhsn_flu_forecast_wo_dates
```

The forecast was generated following the creation of `nhsn_hosp_flu.csv` (see previous section) by running `data.py` with the following added:

```python
make_nhsn_fitted_forecast_idata(
nhsn_dataset_path="nhsn_hosp_flu.csv",
file_save_path=os.path.join(os.getcwd(), "example_flu_forecast.nc"),
start_date"2022/08/08",
end_date="2023/12/08",
forecast_days=28,
make_forecast(
nhsn_data=forecasttools.nhsn_hosp_flu,
start_date="2022-08-08",
end_date="2022-12-08",
juris_subset=["TX"],
create_save_directory=False,
show_plot=True,
save_idata=True
forecast_days=28,
save_path="../forecasttools/example_flu_forecast_w_dates.nc",
save_idata=True,
use_log=False,
)
```

The forecast looks like:

![Example NHSN-based Influenza forecast](./assets/example_forecast_w_dates.png){ width=75% }

---

# CDC Open Source Considerations
Expand Down
Binary file added assets/example_forecast_w_dates.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 20 additions & 7 deletions forecasttools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import polars as pl

from .daily_to_epiweekly import df_aggregate_to_epiweekly
from .idata_to_df_w_dates import forecast_as_df_with_dates
from .idata_w_dates_to_df import (
add_dates_as_coords_to_idata,
idata_forecast_w_dates_to_df,
)
from .recode_locations import loc_abbr_to_flusight_code
from .to_flusight import get_flusight_table
from .trajectories_to_quantiles import trajectories_to_quantiles
Expand All @@ -17,7 +20,8 @@

# load example flusight submission
with importlib.resources.path(
__package__, "example_flusight_submission.parquet"
__package__,
"example_flusight_submission.parquet",
) as data_path:
dtypes_d = {"location": pl.Utf8}
example_flusight_submission = pl.read_parquet(data_path)
Expand All @@ -34,20 +38,29 @@
) as data_path:
nhsn_hosp_flu = pl.read_parquet(data_path)

# load light idata NHSN influenza forecast (NHSN, as of 2024-09-26)
# load light idata NHSN influenza forecast wo dates (NHSN, as of 2024-09-26)
with importlib.resources.path(
__package__, "example_flu_forecast.nc"
__package__,
"example_flu_forecast_wo_dates.nc",
) as data_path:
nhsn_flu_forecast = az.from_netcdf(data_path)
nhsn_flu_forecast_wo_dates = az.from_netcdf(data_path)

# load light idata NHSN influenza forecast w dates (NHSN, as of 2024-09-26)
with importlib.resources.path(
__package__, "example_flu_forecast_w_dates.nc"
) as data_path:
nhsn_flu_forecast_w_dates = az.from_netcdf(data_path)


__all__ = [
"location_table",
"example_flusight_submission",
"nhsn_hosp_COVID",
"nhsn_hosp_flu",
"nhsn_flu_forecast",
"forecast_as_df_with_dates",
"nhsn_flu_forecast_wo_dates",
"nhsn_flu_forecast_w_dates",
"idata_forecast_w_dates_to_df",
"add_dates_as_coords_to_idata",
"trajectories_to_quantiles",
"df_aggregate_to_epiweekly",
"loc_abbr_to_flusight_code",
Expand Down
21 changes: 10 additions & 11 deletions forecasttools/daily_to_epiweekly.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ def calculate_epi_week_and_year(date: str):
An ISO8601 date.
"""
epiweek = epiweeks.Week.fromdate(datetime.strptime(date, "%Y-%m-%d"))
epiweek_df_struct = {"epiweek": epiweek.week, "epiyear": epiweek.year}
epiweek_df_struct = {
"epiweek": epiweek.week,
"epiyear": epiweek.year,
}
return epiweek_df_struct


def df_aggregate_to_epiweekly(
forecast_df: pl.DataFrame,
value_col: str = "value",
date_col: str = "date",
id_cols: list[str] = [".draw"],
id_cols: list[str] = ["draw"],
weekly_value_name: str = "weekly_value",
strict: bool = False,
) -> pl.DataFrame:
Expand All @@ -41,7 +44,7 @@ def df_aggregate_to_epiweekly(
A polars dataframe with draws and dates
as columns. This dataframe will likely
have come from an InferenceData object
that was passed converted using `forecast_as_df_with_dates`.
that was converted using `idata_w_dates_to_df`.
value_col
The name of the column with the fitted
and or forecasted quantity. Defaults
Expand Down Expand Up @@ -70,12 +73,6 @@ def df_aggregate_to_epiweekly(
A dataframe with value_col aggregated
across epiweek and epiyear.
"""
# check intended df columns are in received df
forecast_df_cols = forecast_df.columns
required_cols = [value_col, date_col] + id_cols
assert set(required_cols).issubset(
set(forecast_df_cols)
), f"Column mismatch between require columns {required_cols} and forecast dateframe columns {forecast_df_cols}."
# add epiweek and epiyear columns
forecast_df = forecast_df.with_columns(
pl.col(["date"])
Expand Down Expand Up @@ -105,12 +102,14 @@ def df_aggregate_to_epiweekly(
if strict:
valid_groups = n_elements.filter(pl.col("n_elements") == 7)
forecast_df = forecast_df.join(
valid_groups.select(group_cols), on=group_cols, how="inner"
valid_groups.select(group_cols),
on=group_cols,
how="inner",
)
# aggregate; sum values in the specified value_col
df = (
forecast_df.group_by(group_cols)
.agg(pl.col(value_col).sum().alias(weekly_value_name))
.sort(["epiyear", "epiweek", ".draw"])
.sort(["epiyear", "epiweek", "draw"])
)
return df
4 changes: 3 additions & 1 deletion forecasttools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def check_url(url: str) -> bool:
return False


def check_file_save_path(file_save_path: str) -> None:
def check_file_save_path(
file_save_path: str,
) -> None:
"""
Checks whether a file path is valid.
Expand Down
Binary file added forecasttools/example_flu_forecast_w_dates.nc
Binary file not shown.
File renamed without changes.
94 changes: 0 additions & 94 deletions forecasttools/idata_to_df_w_dates.py

This file was deleted.

Loading

0 comments on commit 558c640

Please sign in to comment.