Skip to content

Commit

Permalink
unpolished model output formatting using forecasttools; template for …
Browse files Browse the repository at this point in the history
…comparison code
  • Loading branch information
AFg6K7h4fhy2 committed Sep 4, 2024
1 parent 1361a5f commit 4b3480d
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 63 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ repos:
hooks:
- id: actionlint
################################################################################
# - repo: https://github.com/lorenzwalthert/precommit
# rev: v0.4.3
# hooks:
# - id: styler-r
# - id: lintr-r
################################################################################
- repo: https://github.com/crate-ci/typos
rev: v1.21.0
hooks:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
92 changes: 92 additions & 0 deletions model_comparison/output/compare.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
library(forecasttools)
library(tibble)
library(dplyr)
library(readr)
library(lubridate)

# function for taking
run_forecast_with_csvs <- function(
spread_draws_csv,
fitting_data_csv,
output_path,
reference_date,
horizons = -1:3,
seed = NULL
) {
set.seed(seed)
# read data from pyrenew-flu-light run
spread_draws <- read_csv(spread_draws_csv)
fitting_data <- read_csv(fitting_data_csv)
# retrieve locations from fitting data
locations <- unique(fitting_data$location)
# parse over and collect forecasts and fitting data
state_daily_forecast_list <- list()
for (loc in locations) {
state_fitting_data <- fitting_data %>% filter(location == loc)
state_forecast <- spread_draws %>% filter(negbinom_rv_dim_0_index == loc)
# forecasttools convert spread draws tiddy
state_forecast_long <- state_forecast %>%
dplyr::mutate(.draw = draw) %>%
dplyr::select(.draw, date = negbinom_rv_dim_0_index, hosp = negbinom_rv)
# go to epiweekly from daily
state_weekly_forecasts <- forecasttools::daily_to_epiweekly(
tidy_daily_trajectories = state_forecast_long,
value_col = "hosp",
date_col = "date",
id_cols = ".draw"
)
state_daily_forecast_list[[loc]] <- state_weekly_forecasts
}
cli::cli_inform("Formatting output for FluSight...")
# flusight formatting
state_flusight_tables <- list()
full_table <- tibble::tibble()
for (state in names(state_daily_forecast_list)) {
state_flusight_table <- forecasttools::trajectories_to_quantiles(
state_daily_forecast_list[[state]],
timepoint_cols = c("epiweek", "epiyear"),
value_col = "weekly_hosp"
) %>%
dplyr::mutate(
location = forecasttools::loc_abbr_to_flusight_code(state)
) %>%
forecasttools:::get_flusight_table(
reference_date,
horizons = horizons
)

full_table <- dplyr::bind_rows(
full_table,
state_flusight_table
)
}
full_table <- full_table %>%
dplyr::arrange(
location,
reference_date,
horizon,
output_type,
output_type_id
)
# save
readr::write_csv(
full_table,
output_path
)
return(full_table)
}

spread_draws_csv <- "AL_2024-03-30_28_NegBinRv.csv"
fitting_data_csv <- "filtered_data_AL.csv"
output_path <- "flusight_output_AL.csv"
reference_date <- "2024-03-30"
horizons <- -1:3

run_forecast_with_csvs(
spread_draws_csv,
fitting_data_csv,
output_path,
reference_date,
horizons,
seed = 62352
)
2 changes: 1 addition & 1 deletion pyrenew_flu_light/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def ensure_output_directory(args: dict[str, any]): # numpydoc ignore=GL08
def assert_historical_data_files_exist(
reporting_date: str,
): # numpydoc ignore=GL08
data_directory = f"../data/{reporting_date}/"
data_directory = f"../model_comparison/data/{reporting_date}/"
assert os.path.exists(
data_directory
), f"Data directory {data_directory} does not exist."
Expand Down
40 changes: 40 additions & 0 deletions pyrenew_flu_light/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,43 @@
Save information regarding model performance
and nature of model input and output.
"""


def generate_flusight_formatted_output_forecasttools(input_csv):
pass


def read_to_df_flusight_formatted_output(input_csv_path: str):
pass


def compare_flusight_formatted_forecast_kstest():
pass


# def convert_quantiles_to_draws(
# input_csv_path: str, states: list[str], output_csv_path: str
# ):
# df = pl.read_csv(input_csv_path, infer_schema_length=55000)
# df_filtered = df.filter(pl.col("location").is_in(states))
# num_draws = 2000
# result_rows = []
# for _, group in df_filtered.group_by(
# ["reference_date", "target", "target_end_date", "location"]
# ):
# values = group["value"].to_numpy()
# quantiles = group["output_type_id"].to_numpy()
#         # assumption: quantiles come from a normal distribution (placeholder)
# samples = norm.ppf(quantiles, loc=values.mean(), scale=values.std())
# for draw_index in range(num_draws):
# for i, sample in enumerate(samples):
# result_rows.append([draw_index, i, sample])
# df_result = pl.DataFrame(result_rows, schema=["draw", "index", "value"])
# df_result.write_csv(output_csv_path)


# convert_quantiles_to_draws(
# input_csv_path="../model_comparison/data/2024-03-30/2024-03-30-cfarenewal-cfaepimlight.csv",
# states=["02"],
# output_csv_path="test.csv",
# )
161 changes: 99 additions & 62 deletions pyrenew_flu_light/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import logging
import os

import arviz as az
import jax
import numpy as np
import numpyro
import polars as pl

Expand Down Expand Up @@ -83,6 +83,10 @@ def run_single_jurisdiction(
)
logging.info(f"{jurisdiction}: Dataset w/ post-observation ready.")

data_save = f"filtered_data_{jurisdiction}"
if not os.path.exists(data_save):
filtered_data.write_csv(data_save)

# extract jurisdiction population
population = (
filtered_data.select(pl.col("population"))
Expand Down Expand Up @@ -230,30 +234,82 @@ def run_single_jurisdiction(
)


# added in from pyrenew
def spread_draws(
posteriors: dict,
variables_names: list[str] | list[tuple],
) -> pl.DataFrame:
"""
Get nicely shaped draws from the posterior
Given a dictionary of posteriors, return a long-form polars dataframe
indexed by draw, with variable values (equivalent of tidybayes
spread_draws() function).
Parameters
----------
posteriors: dict
A dictionary of posteriors with variable names as keys and numpy
ndarrays as values (with the first axis corresponding to the posterior
draw number.
variables_names: list[str] | list[tuple]
list of strings or of tuples identifying which variables to retrieve.
Returns
-------
pl.DataFrame
A dataframe of draw-indexed
"""

for i_var, v in enumerate(variables_names):
if isinstance(v, str):
v_dims = None
else:
v_dims = v[1:]
v = v[0]

post = posteriors.get(v)
long_post = post.flatten()[..., np.newaxis]

indices = np.array(list(np.ndindex(post.shape)))
n_dims = indices.shape[1] - 1
if v_dims is None:
dim_names = [
("{}_dim_{}_index".format(v, k), pl.Int64)
for k in range(n_dims)
]
elif len(v_dims) != n_dims:
raise ValueError(
"incorrect number of "
"dimension names "
"provided for variable "
"{}".format(v)
)
else:
dim_names = [(v_dim, pl.Int64) for v_dim in v_dims]

p_df = pl.DataFrame(
np.concatenate([indices, long_post], axis=1),
schema=([("draw", pl.Int64)] + dim_names + [(v, pl.Float64)]),
)

if i_var == 0:
df = p_df
else:
df = df.join(
p_df, on=[col for col in df.columns if col in p_df.columns]
)
pass

return df


def main(args): # numpydoc ignore=GL08
"""
The `cfaepim` model required a configuration
file and a dataset. The configuration file must
follow some detailed specifications, as must the
dataset. Once these are in place, the model is
used in the following manner for each state:
(1) extract the population, the indices of the weeks,
the hospitalizations during the first week, & the
covariates, (2) the configuration file and the
previous content then will be used to produce
an Rt, infections, and observation process by
passing them to the `cfaepim` model, (3) the user
can use argparse to test or compare the forecasts.
The `cfaepim` tool is used for runs on hospitalization
data retrieved from an API or stored historically.
Notes
-----
Testing in `cfaepim` includes ensuring the dataset
and configuration have the correct variables and
values in a proper range. Testing also ensures that
each part of the `cfaepim` model works as desired.
python3 tut_epim_port_msr.py --reporting_date 2024-01-20 --regions NY --historical --forecast
pyrenew-flu-light; to run:
python3 tut_epim_port_msr.py --reporting_date 2024-01-20
--regions NY --historical --forecast
python3 tut_epim_port_msr.py --reporting_date 2024-03-30 --regions AL --historical --forecast
"""
logging.info("Starting CFAEPIM")
Expand All @@ -264,10 +320,9 @@ def main(args): # numpydoc ignore=GL08
numpyro.set_host_device_count(num_cores - (num_cores - 3))
logging.info("Number of cores set.")

# check that output directory exists, if not create
output_directory = pyrenew_flu_light.ensure_output_directory(args)
print(output_directory)
logging.info("Output directory ensured working.")
# # check that output directory exists, if not create
# output_directory = pyrenew_flu_light.ensure_output_directory(args)
# logging.info("Output directory ensured working.")

if args.historical_data:
# check that historical cfaepim data exists for given reporting date
Expand All @@ -282,7 +337,7 @@ def main(args): # numpydoc ignore=GL08
config = pyrenew_flu_light.load_config(config_path=args.use_c)
else:
config = pyrenew_flu_light.load_config(
config_path=f"../config/params_{args.reporting_date}_historical.toml"
config_path=f"../model_comparison/config/params_{args.reporting_date}.toml"
)
logging.info("Configuration (historical) loaded.")

Expand All @@ -302,31 +357,6 @@ def main(args): # numpydoc ignore=GL08
pl.col("date").str.strptime(pl.Date, "%Y-%m-%d")
)

# save plots of the raw hospitalization data,
# for all jurisdictions
if args.data_info_save:
# save pdf of 2, 2x2 (log-scale plots)
# total hospitalizations (full season) & last 4 weeks
# log scale, log scale
# growth rate, moving average
# log-scale, log-scale
# check if this already exist + do for all juris.
pass

if args.model_info_save:
# save model diagram
# save plots for priors
# check if this already exists, do for each config file
# save_numpyro_model(
# save_path=output_directory + "cfaepim_diagram.pdf",
# jurisdiction="NY",
# dataset=influenza_hosp_data,
# config=config,
# forecasting=args.forecast,
# n_post_observation_days=28,
# )
pass

# parallel run over jurisdictions
# results = dict([(elt, {}) for elt in args.regions])
forecast_days = 28
Expand All @@ -348,23 +378,30 @@ def main(args): # numpydoc ignore=GL08
forecasting=args.forecast,
n_post_observation_days=forecast_days,
)

idata = az.from_numpyro(
posterior=model.mcmc,
prior=prior_p_ss,
posterior_predictive=post_p_fs,
constant_data={"obs": obs},
print(prior_p_ss["negbinom_rv"])
save_path_samples = f"{jurisdiction}_{args.reporting_date}_{forecast_days}_NegBinRv.csv"
df = spread_draws(
posteriors=post_p_fs, variables_names=["negbinom_rv"]
)
if not os.path.exists(save_path_samples):
df.write_csv(save_path_samples)

# idata = az.from_numpyro(
# posterior=model.mcmc,
# prior=prior_p_ss,
# posterior_predictive=post_p_fs,
# constant_data={"obs": obs},
# )
save_path = f"{jurisdiction}_{args.reporting_date}_{forecast_days}_Ahead.csv"
if not os.path.exists(save_path):
df = pl.DataFrame(
{k: v.__array__() for k, v in post_p_fs.items()}
)
df.write_csv(save_path)

if not args.forecast:
pyrenew_flu_light.plot_lm_arviz_fit(idata)
pyrenew_flu_light.plot_hdi_arviz_for(idata, forecast_days)
# if not args.forecast:
# pyrenew_flu_light.plot_lm_arviz_fit(idata)
# pyrenew_flu_light.plot_hdi_arviz_for(idata, forecast_days)


if __name__ == "__main__":
Expand Down

0 comments on commit 4b3480d

Please sign in to comment.