From 7614aa3ea5561e55c7eb55190ceb6f318ce31df1 Mon Sep 17 00:00:00 2001 From: AFg6K7h4fhy2 <127630341+AFg6K7h4fhy2@users.noreply.github.com> Date: Thu, 29 Aug 2024 12:24:14 -0400 Subject: [PATCH] packaged structure from single file; begin simple eval. script for R --- notebooks/evaluate.R | 14 + pyrenew_flu_light/__init__.py | 83 +++- pyrenew_flu_light/checks.py | 115 +---- pyrenew_flu_light/comp_obs.py | 29 +- pyrenew_flu_light/comp_tran.py | 27 +- pyrenew_flu_light/model.py | 41 +- pyrenew_flu_light/pre.py | 99 ++++ pyrenew_flu_light/run.py | 799 +++++++++++++++++---------------- 8 files changed, 673 insertions(+), 534 deletions(-) create mode 100644 notebooks/evaluate.R diff --git a/notebooks/evaluate.R b/notebooks/evaluate.R new file mode 100644 index 0000000..d626f6e --- /dev/null +++ b/notebooks/evaluate.R @@ -0,0 +1,14 @@ +# Start of simple eval. script. + +if (!requireNamespace("remotes", quietly = TRUE)) { + install.packages("remotes") +} +remotes::install_local("../../cfa-forecasttools") + +library(cfaforecasttools) +library(tibble) +library(dplyr) +library(ggplot2) + + +forecast_data <- read.csv("samples.csv") diff --git a/pyrenew_flu_light/__init__.py b/pyrenew_flu_light/__init__.py index b4e92dd..52a708b 100644 --- a/pyrenew_flu_light/__init__.py +++ b/pyrenew_flu_light/__init__.py @@ -1,12 +1,73 @@ -from pyrenew_flu_light.comp_inf import CFAEPIM_Infections -from pyrenew_flu_light.comp_obs import CFAEPIM_Observation -from pyrenew_flu_light.comp_tran import CFAEPIM_Rt -from pyrenew_flu_light.model import CFAEPIM_Model -from pyrenew_flu_light.pad import ( - add_post_observation_period, - add_pre_observation_period, +from checks import ( + assert_historical_data_files_exist, + check_file_path_valid, + ensure_output_directory, + load_config, ) -from pyrenew_flu_light.plot import plot_hdi_arviz_for, plot_lm_arviz_fit +from comp_inf import CFAEPIM_Infections +from comp_obs import CFAEPIM_Observation +from comp_tran import CFAEPIM_Rt +from model import 
CFAEPIM_Model +from pad import add_post_observation_period, add_pre_observation_period +from plot import plot_hdi_arviz_for, plot_lm_arviz_fit +from pre import load_data + +JURISDICTIONS = [ + "AK", + "AL", + "AR", + "AZ", + "CA", + "CO", + "CT", + "DC", + "DE", + "FL", + "GA", + "HI", + "IA", + "ID", + "IL", + "IN", + "KS", + "KY", + "LA", + "MA", + "MD", + "ME", + "MI", + "MN", + "MO", + "MS", + "MT", + "NC", + "ND", + "NE", + "NH", + "NJ", + "NM", + "NV", + "NY", + "OH", + "OK", + "OR", + "PA", + "PR", + "RI", + "SC", + "SD", + "TN", + "TX", + "US", + "UT", + "VA", + "VI", + "VT", + "WA", + "WI", + "WV", + "WY", +] __all__ = [ "CFAEPIM_Infections", @@ -17,4 +78,10 @@ "add_pre_observation_period", "plot_hdi_arviz_for", "plot_lm_arviz_fit", + "JURISDICTIONS", + "assert_historical_data_files_exist", + "check_file_path_valid", + "ensure_output_directory", + "load_config", + "load_data", ] diff --git a/pyrenew_flu_light/checks.py b/pyrenew_flu_light/checks.py index 40f812d..bcbc7e1 100644 --- a/pyrenew_flu_light/checks.py +++ b/pyrenew_flu_light/checks.py @@ -5,64 +5,9 @@ import os -import polar as pl import toml -def display_data( - data: pl.DataFrame, - n_row_count: int = 15, - n_col_count: int = 5, - first_only: bool = False, - last_only: bool = False, -) -> None: - """ - Display the columns and rows of - a polars dataframe. - - Parameters - ---------- - data : pl.DataFrame - A polars dataframe. - n_row_count : int, optional - How many rows to print. - Defaults to 15. - n_col_count : int, optional - How many columns to print. - Defaults to 15. - first_only : bool, optional - If True, only display the first `n_row_count` rows. Defaults to False. - last_only : bool, optional - If True, only display the last `n_row_count` rows. Defaults to False. - - Returns - ------- - None - Displays data. 
- """ - rows, cols = data.shape - assert ( - 1 <= n_col_count <= cols - ), f"Must have reasonable column count; was type {n_col_count}" - assert ( - 1 <= n_row_count <= rows - ), f"Must have reasonable row count; was type {n_row_count}" - assert ( - first_only + last_only - ) != 2, "Can only do one of last or first only." - if first_only: - data_to_display = data.head(n_row_count) - elif last_only: - data_to_display = data.tail(n_row_count) - else: - data_to_display = data.head(n_row_count) - pl.Config.set_tbl_hide_dataframe_shape(True) - pl.Config.set_tbl_formatting("ASCII_MARKDOWN") - pl.Config.set_tbl_hide_column_data_types(True) - with pl.Config(tbl_rows=n_row_count, tbl_cols=n_col_count): - print(f"Dataset In Use For `cfaepim`:\n{data_to_display}\n") - - def check_file_path_valid(file_path: str) -> None: """ Checks if a file path is valid. Used to check @@ -85,47 +30,6 @@ def check_file_path_valid(file_path: str) -> None: return None -def load_data( - data_path: str, - sep: str = "\t", - schema_length: int = 10000, -) -> pl.DataFrame: - """ - Loads historical (i.e., `.tsv` data generated - `cfaepim` for a weekly run) data. - - Parameters - ---------- - data_path : str - The path to the tsv file to be read. - sep : str, optional - The separator between values in the - data file. Defaults to tab-separated. - schema_length : int, optional - An approximation of the expected - maximum number of rows. Defaults - to 10000. - - Returns - ------- - pl.DataFrame - An unvetted polars dataframe of NHSN - hospitalization data. - """ - check_file_path_valid(file_path=data_path) - assert sep in [ - "\t", - ",", - ], f"Separator must be tabs or commas; was type {sep}" - assert ( - 7500 <= schema_length <= 25000 - ), f"Schema length must be reasonable; was type {schema_length}" - data = pl.read_csv( - data_path, separator=sep, infer_schema_length=schema_length - ) - return data - - def load_config(config_path: str) -> dict[str, any]: """ Attempts to load config toml file. 
@@ -169,3 +73,22 @@ def ensure_output_directory(args: dict[str, any]): # numpydoc ignore=GL08 if not os.path.exists(output_directory): os.makedirs(output_directory) return output_directory + + +def assert_historical_data_files_exist( + reporting_date: str, +): # numpydoc ignore=GL08 + data_directory = f"../data/{reporting_date}/" + assert os.path.exists( + data_directory + ), f"Data directory {data_directory} does not exist." + required_files = [ + f"{reporting_date}_clean_data.tsv", + f"{reporting_date}_config.toml", + f"{reporting_date}-cfarenewal-cfaepimlight.csv", + ] + for file in required_files: + assert os.path.exists( + os.path.join(data_directory, file) + ), f"Required file {file} does not exist in {data_directory}." + return data_directory diff --git a/pyrenew_flu_light/comp_obs.py b/pyrenew_flu_light/comp_obs.py index 6b18210..c9836c0 100644 --- a/pyrenew_flu_light/comp_obs.py +++ b/pyrenew_flu_light/comp_obs.py @@ -7,11 +7,12 @@ import jax.numpy as jnp import numpy as np import numpyro.distributions as dist +import pyrenew.regression as r import pyrenew.transformation as t from jax.typing import ArrayLike -from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.metaclass import RandomVariable from pyrenew.observation import NegativeBinomialObservation -from pyrenew.regression import GLMPrediction +from pyrenew.randomvariable import DistributionalVariable class CFAEPIM_Observation(RandomVariable): @@ -44,12 +45,12 @@ def __init__( ): # numpydoc ignore=GL08 logging.info("Initializing CFAEPIM_Observation") - CFAEPIM_Observation.validate( - predictors, - alpha_prior_dist, - coefficient_priors, - nb_concentration_prior, - ) + # CFAEPIM_Observation.validate( + # predictors, + # alpha_prior_dist, + # coefficient_priors, + # nb_concentration_prior, + # ) self.predictors = predictors self.alpha_prior_dist = alpha_prior_dist @@ -67,9 +68,8 @@ def _init_alpha_t(self): transformation. 
""" logging.info("Initializing alpha process") - self.alpha_process = GLMPrediction( + self.alpha_process = r.GLMPrediction( name="alpha_t", - fixed_predictor_values=self.predictors, intercept_prior=self.alpha_prior_dist, coefficient_priors=self.coefficient_priors, transform=t.SigmoidTransform().inv, @@ -84,9 +84,9 @@ def _init_negative_binomial(self): logging.info("Initializing negative binomial process") self.nb_observation = NegativeBinomialObservation( name="negbinom_rv", - concentration_rv=DistributionalRV( + concentration_rv=DistributionalVariable( name="nb_concentration", - dist=self.nb_concentration_prior, + distribution=self.nb_concentration_prior, ), ) @@ -154,8 +154,9 @@ def sample( ascertainment values and the expected hospitalizations. """ - alpha_samples = self.alpha_process.sample()["prediction"] - alpha_samples = alpha_samples[: infections.shape[0]] + + alpha_samples = self.alpha_process.sample(self.predictors) + alpha_samples = alpha_samples[0].value[: infections.shape[0]] expected_hosp = ( alpha_samples * jnp.convolve(infections, inf_to_hosp_dist, mode="full")[ diff --git a/pyrenew_flu_light/comp_tran.py b/pyrenew_flu_light/comp_tran.py index d16d0f8..5f0afe1 100644 --- a/pyrenew_flu_light/comp_tran.py +++ b/pyrenew_flu_light/comp_tran.py @@ -11,12 +11,9 @@ import pyrenew.transformation as t from jax.typing import ArrayLike from numpyro.infer.reparam import LocScaleReparam -from pyrenew.metaclass import ( - DistributionalRV, - RandomVariable, - TransformedRandomVariable, -) -from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import RandomVariable +from pyrenew.process import RandomWalk +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable class CFAEPIM_Rt(RandomVariable): # numpydoc ignore=GL08 @@ -108,24 +105,24 @@ def sample(self, n_steps: int, **kwargs) -> tuple: # numpydoc ignore=GL08 "Wt_rw_sd", dist.HalfNormal(self.gamma_RW_prior_scale) ) # Rt random walk process - wt_rv = 
SimpleRandomWalkProcess( + init_rv = DistributionalVariable( + name="init_Wt_rv", + distribution=self.intercept_RW_prior, + ) + wt_rv = RandomWalk( name="Wt", - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_rv", - dist=dist.Normal(0, sd_wt), + distribution=dist.Normal(0, sd_wt), reparam=LocScaleReparam(0), ), - init_rv=DistributionalRV( - name="init_Wt_rv", - dist=self.intercept_RW_prior, - ), ) # transform Rt random walk w/ scaled logit - transformed_rt_samples = TransformedRandomVariable( + transformed_rt_samples = TransformedVariable( name="transformed_rt_rw", base_rv=wt_rv, transforms=t.ScaledLogitTransform(x_max=self.max_rt).inv, - ).sample(n_steps=n_steps, **kwargs) + ).sample(n=n_steps, init_vals=init_rv()[0].value) # broadcast the Rt samples to daily values broadcasted_rt_samples = transformed_rt_samples[0].value[ self.week_indices diff --git a/pyrenew_flu_light/model.py b/pyrenew_flu_light/model.py index 91b169f..5a09ddb 100644 --- a/pyrenew_flu_light/model.py +++ b/pyrenew_flu_light/model.py @@ -16,7 +16,8 @@ InfectionInitializationProcess, InitializeInfectionsFromVec, ) -from pyrenew.metaclass import DistributionalRV, Model, SampledValue +from pyrenew.metaclass import Model, SampledValue +from pyrenew.randomvariable import DistributionalVariable from pyrenew_flu_light import ( CFAEPIM_Infections, @@ -109,11 +110,11 @@ def __init__( # infections: initial infections self.I0 = InfectionInitializationProcess( name="I0_initialization", - I_pre_init_rv=DistributionalRV( + I_pre_init_rv=DistributionalVariable( name="I0", - dist=dist.Exponential(rate=1 / self.mean_inf_val).expand( - [self.inf_model_seed_days] - ), + distribution=dist.Exponential( + rate=1 / self.mean_inf_val + ).expand([self.inf_model_seed_days]), ), infection_init_method=InitializeInfectionsFromVec( n_timepoints=self.inf_model_seed_days @@ -125,22 +126,22 @@ def __init__( # update: truncated Normal needed here, done # "under the hood" in Epidemia, use Beta for the # 
time being. - # self.susceptibility_prior = dist.Beta( - # 1 - # + ( - # self.susceptible_fraction_prior_mode - # / self.susceptible_fraction_prior_scale - # ), - # 1 - # + (1 - self.susceptible_fraction_prior_mode) - # / self.susceptible_fraction_prior_scale, - # ) - # now: - self.susceptibility_prior = dist.TruncatedNormal( - self.susceptible_fraction_prior_mode, - self.susceptible_fraction_prior_scale, - low=0.0, + self.susceptibility_prior = dist.Beta( + 1 + + ( + self.susceptible_fraction_prior_mode + / self.susceptible_fraction_prior_scale + ), + 1 + + (1 - self.susceptible_fraction_prior_mode) + / self.susceptible_fraction_prior_scale, ) + # now: + # self.susceptibility_prior = dist.TruncatedNormal( + # self.susceptible_fraction_prior_mode, + # self.susceptible_fraction_prior_scale, + # low=0.0, + # ) # infections component self.infections = CFAEPIM_Infections( diff --git a/pyrenew_flu_light/pre.py b/pyrenew_flu_light/pre.py index b3b0cfd..443e9b8 100644 --- a/pyrenew_flu_light/pre.py +++ b/pyrenew_flu_light/pre.py @@ -1,3 +1,102 @@ """ ETL system for pyrenew-flu-light. """ + +import polars as pl + +from pyrenew_flu_light import check_file_path_valid + + +def display_data( + data: pl.DataFrame, + n_row_count: int = 15, + n_col_count: int = 5, + first_only: bool = False, + last_only: bool = False, +) -> None: + """ + Display the columns and rows of + a polars dataframe. + + Parameters + ---------- + data : pl.DataFrame + A polars dataframe. + n_row_count : int, optional + How many rows to print. + Defaults to 15. + n_col_count : int, optional + How many columns to print. + Defaults to 15. + first_only : bool, optional + If True, only display the first `n_row_count` rows. Defaults to False. + last_only : bool, optional + If True, only display the last `n_row_count` rows. Defaults to False. + + Returns + ------- + None + Displays data. 
+ """ + rows, cols = data.shape + assert ( + 1 <= n_col_count <= cols + ), f"Must have reasonable column count; was type {n_col_count}" + assert ( + 1 <= n_row_count <= rows + ), f"Must have reasonable row count; was type {n_row_count}" + assert ( + first_only + last_only + ) != 2, "Can only do one of last or first only." + if first_only: + data_to_display = data.head(n_row_count) + elif last_only: + data_to_display = data.tail(n_row_count) + else: + data_to_display = data.head(n_row_count) + pl.Config.set_tbl_hide_dataframe_shape(True) + pl.Config.set_tbl_formatting("ASCII_MARKDOWN") + pl.Config.set_tbl_hide_column_data_types(True) + with pl.Config(tbl_rows=n_row_count, tbl_cols=n_col_count): + print(f"Dataset In Use For `cfaepim`:\n{data_to_display}\n") + + +def load_data( + data_path: str, + sep: str = "\t", + schema_length: int = 10000, +) -> pl.DataFrame: + """ + Loads historical (i.e., `.tsv` data generated + `cfaepim` for a weekly run) data. + + Parameters + ---------- + data_path : str + The path to the tsv file to be read. + sep : str, optional + The separator between values in the + data file. Defaults to tab-separated. + schema_length : int, optional + An approximation of the expected + maximum number of rows. Defaults + to 10000. + + Returns + ------- + pl.DataFrame + An unvetted polars dataframe of NHSN + hospitalization data. 
+ """ + check_file_path_valid(file_path=data_path) + assert sep in [ + "\t", + ",", + ], f"Separator must be tabs or commas; was type {sep}" + assert ( + 7500 <= schema_length <= 25000 + ), f"Schema length must be reasonable; was type {schema_length}" + data = pl.read_csv( + data_path, separator=sep, infer_schema_length=schema_length + ) + return data diff --git a/pyrenew_flu_light/run.py b/pyrenew_flu_light/run.py index fa01b29..72f71d5 100644 --- a/pyrenew_flu_light/run.py +++ b/pyrenew_flu_light/run.py @@ -1,381 +1,418 @@ -# def run_single_jurisdiction( -# jurisdiction: str, -# dataset: pl.DataFrame, -# config: dict[str, any], -# forecasting: bool = False, -# n_post_observation_days: int = 0, -# ): -# """ -# Runs the ported `cfaepim` model on a single -# jurisdiction. Pre- and post-observation data -# for the Rt burn in and for forecasting, -# respectively, is done before the prior predictive, -# posterior, and posterior predictive samples -# are returned. - -# Parameters -# ---------- -# jurisdiction : str -# The jurisdiction. -# dataset : pl.DataFrame -# The incidence data of interest. -# config : dict[str, any] -# A configuration file for the model. -# forecasting : bool, optional -# Whether or not forecasts are being made. -# Defaults to True. -# n_post_observation_days : int, optional -# The number of days to look ahead. Defaults -# to 0 if not forecasting. - -# Returns -# ------- -# tuple -# A tuple of prior predictive, posterior, and -# posterior predictive samples. 
-# """ -# # filter data to be the jurisdiction alone -# filtered_data_jurisdiction = dataset.filter( -# pl.col("location") == jurisdiction -# ) - -# # add the pre-observation period to the dataset -# filtered_data = add_pre_observation_period( -# dataset=filtered_data_jurisdiction, -# n_pre_observation_days=config["n_pre_observation_days"], -# ) - -# logging.info(f"{jurisdiction}: Dataset w/ pre-observation ready.") - -# if forecasting: -# # add the post-observation period if forecasting -# filtered_data = add_post_observation_period( -# dataset=filtered_data, -# n_post_observation_days=n_post_observation_days, -# ) -# logging.info(f"{jurisdiction}: Dataset w/ post-observation ready.") - -# # extract jurisdiction population -# population = ( -# filtered_data.select(pl.col("population")) -# .unique() -# .to_numpy() -# .flatten() -# )[0] - -# # extract indices for weeks for Rt broadcasting (weekly to daily) -# week_indices = filtered_data.select(pl.col("week")).to_numpy().flatten() - -# # extract first week hospitalizations for infections seeding -# first_week_hosp = ( -# filtered_data.select(pl.col("first_week_hosp")) -# .unique() -# .to_numpy() -# .flatten() -# )[0] - -# # extract covariates (typically weekday, holidays, nonobs period) -# day_of_week_covariate = ( -# filtered_data.select(pl.col("day_of_week")) -# .to_dummies() -# .select(pl.exclude("day_of_week_Thu")) -# ) -# remaining_covariates = filtered_data.select( -# ["is_holiday", "is_post_holiday", "nonobservation_period"] -# ) -# covariates = pl.concat( -# [day_of_week_covariate, remaining_covariates], how="horizontal" -# ) -# predictors = covariates.to_numpy() - -# # extract observation hospital admissions -# # NOTE: from filtered_data_jurisdiction, not filtered_data, which has null hosp -# observed_hosp_admissions = ( -# filtered_data.select(pl.col("hosp")).to_numpy().flatten() -# ) - -# logging.info(f"{jurisdiction}: Variables extracted from dataset.") - -# # instantiate CFAEPIM model (for fitting) -# 
total_steps = week_indices.size -# steps_excluding_forecast = total_steps - n_post_observation_days -# cfaepim_MSR_fit = CFAEPIM_Model( -# config=config, -# population=population, -# week_indices=week_indices[:steps_excluding_forecast], -# first_week_hosp=first_week_hosp, -# predictors=predictors[:steps_excluding_forecast], -# ) - -# logging.info(f"{jurisdiction}: CFAEPIM model instantiated (fitting)!") - -# # run the CFAEPIM model -# cfaepim_MSR_fit.run( -# rng_key=jax.random.key(config["seed"]), -# n_steps=steps_excluding_forecast, -# data_observed_hosp_admissions=observed_hosp_admissions[ -# :steps_excluding_forecast -# ], -# num_warmup=config["n_warmup"], -# num_samples=config["n_iter"], -# nuts_args={ -# "target_accept_prob": config["adapt_delta"], -# "max_tree_depth": config["max_treedepth"], -# "init_strategy": numpyro.infer.init_to_sample, -# "find_heuristic_step_size": True, -# }, -# mcmc_args={ -# "num_chains": config["n_chains"], -# "progress_bar": True, -# }, # progress_bar False if use vmap -# ) - -# logging.info(f"{jurisdiction}: CFAEPIM model (fitting) ran!") - -# cfaepim_MSR_fit.print_summary() - -# # prior predictive simulation samples -# prior_predictive_sim_samples = cfaepim_MSR_fit.prior_predictive( -# n_steps=steps_excluding_forecast, -# numpyro_predictive_args={"num_samples": config["n_iter"]}, -# rng_key=jax.random.key(config["seed"]), -# ) - -# logging.info(f"{jurisdiction}: Prior predictive simulation complete.") - -# # posterior predictive simulation samples -# posterior_predictive_sim_samples = cfaepim_MSR_fit.posterior_predictive( -# n_steps=steps_excluding_forecast, -# numpyro_predictive_args={"num_samples": config["n_iter"]}, -# rng_key=jax.random.key(config["seed"]), -# data_observed_hosp_admissions=None, -# ) - -# logging.info(f"{jurisdiction}: Posterior predictive simulation complete.") - -# # posterior predictive forecasting samples -# if forecasting: -# cfaepim_MSR_for = CFAEPIM_Model( -# config=config, -# population=population, 
-# week_indices=week_indices, -# first_week_hosp=first_week_hosp, -# predictors=predictors, -# ) - -# # run the CFAEPIM model (forecasting, required to do so -# # single `posterior_predictive` gets sames (need self.mcmc) -# # from passed model); -# # ISSUE: inv() -# # PR: sample() + OOP behavior & statefulness -# cfaepim_MSR_for.mcmc = cfaepim_MSR_fit.mcmc - -# posterior_predictive_for_samples = ( -# cfaepim_MSR_for.posterior_predictive( -# n_steps=total_steps, -# numpyro_predictive_args={"num_samples": config["n_iter"]}, -# rng_key=jax.random.key(config["seed"]), -# data_observed_hosp_admissions=None, -# ) -# ) - -# logging.info( -# f"{jurisdiction}: Posterior predictive forecasts complete." -# ) - -# return ( -# cfaepim_MSR_for, -# observed_hosp_admissions, -# prior_predictive_sim_samples, -# posterior_predictive_sim_samples, -# posterior_predictive_for_samples, -# ) -# else: -# posterior_predictive_for_samples = None - -# return ( -# cfaepim_MSR_fit, -# observed_hosp_admissions, -# prior_predictive_sim_samples, -# posterior_predictive_sim_samples, -# posterior_predictive_for_samples, -# ) - - -# def main(args): # numpydoc ignore=GL08 -# """ -# The `cfaepim` model required a configuration -# file and a dataset. The configuration file must -# follow some detailed specifications, as must the -# dataset. Once these are in place, the model is -# used in the following manner for each state: -# (1) extract the population, the indices of the weeks, -# the hospitalizations during the first week, & the -# covariates, (2) the configuration file and the -# previous content then will be used to produce -# an Rt, infections, and observation process by -# passing them to the `cfaepim` model, (3) the user -# can use argparse to test or compare the forecasts. -# The `cfaepim` tool is used for runs on hospitalization -# data retrieved from an API or stored historically. 
- -# Notes -# ----- -# Testing in `cfaepim` includes ensuring the dataset -# and configuration have the correct variables and -# values in a proper range. Testing also ensures that -# each part of the `cfaepim` model works as desired. -# python3 tut_epim_port_msr.py --reporting_date 2024-01-20 --regions NY --historical --forecast -# python3 tut_epim_port_msr.py --reporting_date 2024-03-30 --regions AL --historical --forecast -# """ -# logging.info("Starting CFAEPIM") - -# # determine number of CPU cores -# numpyro.set_platform("cpu") -# num_cores = os.cpu_count() -# numpyro.set_host_device_count(num_cores - (num_cores - 3)) -# logging.info("Number of cores set.") - -# # check that output directory exists, if not create -# output_directory = ensure_output_directory(args) -# print(output_directory) -# logging.info("Output directory ensured working.") - -# if args.historical_data: -# # check that historical cfaepim data exists for given reporting date -# historical_data_directory = assert_historical_data_files_exist( -# args.reporting_date -# ) - -# # load historical configuration file (modified from cfaepim) -# if args.use_c != "": -# config = load_config(config_path=args.use_c) -# else: -# config = load_config( -# config_path=f"../config/params_{args.reporting_date}_historical.toml" -# ) -# logging.info("Configuration (historical) loaded.") - -# # load the historical hospitalization data -# data_path = os.path.join( -# historical_data_directory, f"{args.reporting_date}_clean_data.tsv" -# ) -# influenza_hosp_data = load_data(data_path=data_path) -# logging.info("Incidence data (historical) loaded.") -# _, cols = influenza_hosp_data.shape -# # display_data( -# # data=influenza_hosp_data, n_row_count=10, n_col_count=cols -# # ) - -# # modify date column from str to datetime -# influenza_hosp_data = influenza_hosp_data.with_columns( -# pl.col("date").str.strptime(pl.Date, "%Y-%m-%d") -# ) - -# # save plots of the raw hospitalization data, -# # for all jurisdictions -# 
if args.data_info_save: -# # save pdf of 2, 2x2 (log-scale plots) -# # total hospitalizations (full season) & last 4 weeks -# # log scale, log scale -# # growth rate, moving average -# # log-scale, log-scale -# # check if this already exist + do for all juris. -# pass - -# if args.model_info_save: -# # save model diagram -# # save plots for priors -# # check if this already exists, do for each config file -# save_numpyro_model( -# save_path=output_directory + "cfaepim_diagram.pdf", -# jurisdiction="NY", -# dataset=influenza_hosp_data, -# config=config, -# forecasting=args.forecast, -# n_post_observation_days=28, -# ) - -# # parallel run over jurisdictions -# # results = dict([(elt, {}) for elt in args.regions]) -# forecast_days = 28 -# for jurisdiction in args.regions: -# # check if a folder for the samples exists -# # check if a folder for the jurisdiction exists - -# # assumptions, fit, and forecast for each jurisdiction -# ( -# model, -# obs, -# prior_p_ss, -# post_p_ss, -# post_p_fs, -# ) = run_single_jurisdiction( -# jurisdiction=jurisdiction, -# dataset=influenza_hosp_data, -# config=config, -# forecasting=args.forecast, -# n_post_observation_days=forecast_days, -# ) - -# idata = az.from_numpyro( -# posterior=model.mcmc, -# prior=prior_p_ss, -# posterior_predictive=post_p_fs, -# constant_data={"obs": obs}, -# ) -# if not args.forecast: -# plot_lm_arviz_fit(idata) -# plot_hdi_arviz_for(idata, forecast_days) - - -# if __name__ == "__main__": -# # argparse settings -# # e.g. python3 tut_epim_port_msr.py -# # --reporting_date 2024-01-20 --regions all --historical --forecast -# # python3 tut_epim_port_msr.py -# # --reporting_date 2024-01-20 --regions NY --historical --forecast -# parser = argparse.ArgumentParser( -# description="Forecast, simulate, and analyze the CFAEPIM model." -# ) -# parser.add_argument( -# "--regions", -# type=process_jurisdictions, -# required=True, -# help="Specify jurisdictions as a comma-separated list. 
Use 'all' for all states, or 'not:state1,state2' to exclude specific states.", -# ) -# parser.add_argument( -# "--reporting_date", -# type=str, -# required=True, -# help="The reporting date.", -# ) -# parser.add_argument( -# "--historical_data", -# action="store_true", -# help="Load model weights before training.", -# ) -# parser.add_argument( -# "--forecast", -# action="store_true", -# help="Whether to make a forecast.", -# ) -# parser.add_argument( -# "--data_info_save", -# action="store_true", -# help="Whether to save information about the dataset.", -# ) -# parser.add_argument( -# "--model_info_save", -# action="store_true", -# help="Whether to save information about the model.", -# ) -# parser.add_argument( -# "--use_c", -# type=str, -# required=False, -# default="", -# help="Config path to external config.", -# ) -# args = parser.parse_args() -# main(args) +import argparse + +# import inspect +import logging +import os + +import arviz as az +import jax +import numpyro +import polars as pl + +import pyrenew_flu_light + + +def process_jurisdictions(value): # numpydoc ignore=GL08 + if value.lower() == "all": + return pyrenew_flu_light.JURISDICTIONS + elif value.lower().startswith("not:"): + exclude = value[4:].split(",") + return [ + state + for state in pyrenew_flu_light.JURISDICTIONS + if state not in exclude + ] + else: + return value.split(",") + + +def run_single_jurisdiction( + jurisdiction: str, + dataset: pl.DataFrame, + config: dict[str, any], + forecasting: bool = False, + n_post_observation_days: int = 0, +): + """ + Runs the ported `cfaepim` model on a single + jurisdiction. Pre- and post-observation data + for the Rt burn in and for forecasting, + respectively, is done before the prior predictive, + posterior, and posterior predictive samples + are returned. + + Parameters + ---------- + jurisdiction : str + The jurisdiction. + dataset : pl.DataFrame + The incidence data of interest. + config : dict[str, any] + A configuration file for the model. 
+ forecasting : bool, optional + Whether or not forecasts are being made. + Defaults to True. + n_post_observation_days : int, optional + The number of days to look ahead. Defaults + to 0 if not forecasting. + + Returns + ------- + tuple + A tuple of prior predictive, posterior, and + posterior predictive samples. + """ + # filter data to be the jurisdiction alone + filtered_data_jurisdiction = dataset.filter( + pl.col("location") == jurisdiction + ) + + # add the pre-observation period to the dataset + filtered_data = pyrenew_flu_light.add_pre_observation_period( + dataset=filtered_data_jurisdiction, + n_pre_observation_days=config["n_pre_observation_days"], + ) + + logging.info(f"{jurisdiction}: Dataset w/ pre-observation ready.") + + if forecasting: + # add the post-observation period if forecasting + filtered_data = pyrenew_flu_light.add_post_observation_period( + dataset=filtered_data, + n_post_observation_days=n_post_observation_days, + ) + logging.info(f"{jurisdiction}: Dataset w/ post-observation ready.") + + # extract jurisdiction population + population = ( + filtered_data.select(pl.col("population")) + .unique() + .to_numpy() + .flatten() + )[0] + + # extract indices for weeks for Rt broadcasting (weekly to daily) + week_indices = filtered_data.select(pl.col("week")).to_numpy().flatten() + + # extract first week hospitalizations for infections seeding + first_week_hosp = ( + filtered_data.select(pl.col("first_week_hosp")) + .unique() + .to_numpy() + .flatten() + )[0] + + # extract covariates (typically weekday, holidays, nonobs period) + day_of_week_covariate = ( + filtered_data.select(pl.col("day_of_week")) + .to_dummies() + .select(pl.exclude("day_of_week_Thu")) + ) + remaining_covariates = filtered_data.select( + ["is_holiday", "is_post_holiday", "nonobservation_period"] + ) + covariates = pl.concat( + [day_of_week_covariate, remaining_covariates], how="horizontal" + ) + predictors = covariates.to_numpy() + + # extract observation hospital admissions 
+ # NOTE: from filtered_data_jurisdiction, not filtered_data, which has null hosp + observed_hosp_admissions = ( + filtered_data.select(pl.col("hosp")).to_numpy().flatten() + ) + + logging.info(f"{jurisdiction}: Variables extracted from dataset.") + + # instantiate CFAEPIM model (for fitting) + total_steps = week_indices.size + steps_excluding_forecast = total_steps - n_post_observation_days + cfaepim_MSR_fit = pyrenew_flu_light.CFAEPIM_Model( + config=config, + population=population, + week_indices=week_indices[:steps_excluding_forecast], + first_week_hosp=first_week_hosp, + predictors=predictors[:steps_excluding_forecast], + ) + + logging.info(f"{jurisdiction}: CFAEPIM model instantiated (fitting)!") + + # run the CFAEPIM model + cfaepim_MSR_fit.run( + rng_key=jax.random.key(config["seed"]), + n_steps=steps_excluding_forecast, + data_observed_hosp_admissions=observed_hosp_admissions[ + :steps_excluding_forecast + ], + num_warmup=config["n_warmup"], + num_samples=config["n_iter"], + nuts_args={ + "target_accept_prob": config["adapt_delta"], + "max_tree_depth": config["max_treedepth"], + "init_strategy": numpyro.infer.init_to_sample, + "find_heuristic_step_size": True, + }, + mcmc_args={ + "num_chains": config["n_chains"], + "progress_bar": True, + }, # progress_bar False if use vmap + ) + + logging.info(f"{jurisdiction}: CFAEPIM model (fitting) ran!") + + cfaepim_MSR_fit.print_summary() + + # prior predictive simulation samples + prior_predictive_sim_samples = cfaepim_MSR_fit.prior_predictive( + n_steps=steps_excluding_forecast, + numpyro_predictive_args={"num_samples": config["n_iter"]}, + rng_key=jax.random.key(config["seed"]), + ) + + logging.info(f"{jurisdiction}: Prior predictive simulation complete.") + + # posterior predictive simulation samples + posterior_predictive_sim_samples = cfaepim_MSR_fit.posterior_predictive( + n_steps=steps_excluding_forecast, + numpyro_predictive_args={"num_samples": config["n_iter"]}, + rng_key=jax.random.key(config["seed"]), + 
def main(args):  # numpydoc ignore=GL08
    """
    Fit the `cfaepim` model, and optionally forecast,
    for every requested jurisdiction.

    The `cfaepim` model requires a configuration
    file and a dataset. The configuration file must
    follow some detailed specifications, as must the
    dataset. Once these are in place, the model is
    used in the following manner for each state:
    (1) extract the population, the indices of the weeks,
    the hospitalizations during the first week, & the
    covariates, (2) the configuration file and the
    previous content then will be used to produce
    an Rt, infections, and observation process by
    passing them to the `cfaepim` model, (3) the user
    can use argparse to test or compare the forecasts.
    The `cfaepim` tool is used for runs on hospitalization
    data retrieved from an API or stored historically;
    only the historical path is implemented in this
    function.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. The attributes
        read here are `historical_data`, `reporting_date`,
        `use_c`, `regions`, `forecast`, `data_info_save`,
        and `model_info_save`; the whole namespace is also
        handed to `ensure_output_directory`.

    Returns
    -------
    None
        Forecast samples are written to one CSV file per
        jurisdiction and plots are produced as side effects.

    Notes
    -----
    Testing in `cfaepim` includes ensuring the dataset
    and configuration have the correct variables and
    values in a proper range. Testing also ensures that
    each part of the `cfaepim` model works as desired.
    python3 run.py --reporting_date 2024-01-20 --regions NY --historical_data --forecast
    python3 run.py --reporting_date 2024-03-30 --regions AL --historical_data --forecast
    """
    logging.info("Starting CFAEPIM")

    # Expose a fixed number of host (CPU) devices to JAX. The count is a
    # constant on purpose: `os.cpu_count()` may return None, and the value
    # requested here never depended on the number of cores.
    host_device_count = 3
    numpyro.set_platform("cpu")
    numpyro.set_host_device_count(host_device_count)
    logging.info("Number of cores set.")

    # check that output directory exists, if not create
    output_directory = pyrenew_flu_light.ensure_output_directory(args)
    print(output_directory)
    logging.info("Output directory ensured working.")

    if not args.historical_data:
        # Historical files are the only data source handled below; without
        # them there is neither a configuration nor a dataset to model.
        logging.warning(
            "No data source selected (--historical_data not set); "
            "nothing to run."
        )
        return

    # check that historical cfaepim data exists for given reporting date
    historical_data_directory = (
        pyrenew_flu_light.assert_historical_data_files_exist(
            args.reporting_date
        )
    )

    # load historical configuration file (modified from cfaepim); an
    # explicitly supplied --use_c path wins over the default location
    config_path = (
        args.use_c
        or f"../config/params_{args.reporting_date}_historical.toml"
    )
    config = pyrenew_flu_light.load_config(config_path=config_path)
    logging.info("Configuration (historical) loaded.")

    # load the historical hospitalization data
    data_path = os.path.join(
        historical_data_directory, f"{args.reporting_date}_clean_data.tsv"
    )
    influenza_hosp_data = pyrenew_flu_light.load_data(data_path=data_path)
    logging.info("Incidence data (historical) loaded.")

    # parse the date column from str into a polars Date
    influenza_hosp_data = influenza_hosp_data.with_columns(
        pl.col("date").str.strptime(pl.Date, "%Y-%m-%d")
    )

    if args.data_info_save:
        # TODO: save plots of the raw hospitalization data for all
        # jurisdictions as a pdf of 2, 2x2 (log-scale) plots:
        #   total hospitalizations (full season) & last 4 weeks,
        #   growth rate, moving average.
        # Check whether the output already exists first.
        pass

    if args.model_info_save:
        # TODO: save the model diagram (e.g. `cfaepim_diagram.pdf` in the
        # output directory) and plots for the priors, once per config
        # file, skipping work that already exists.
        pass

    # run over jurisdictions (sequentially)
    forecast_days = 28
    for jurisdiction in args.regions:
        # assumptions, fit, and forecast for each jurisdiction
        (
            model,
            obs,
            prior_p_ss,
            _post_p_ss,
            post_p_fs,
        ) = run_single_jurisdiction(
            jurisdiction=jurisdiction,
            dataset=influenza_hosp_data,
            config=config,
            forecasting=args.forecast,
            n_post_observation_days=forecast_days,
        )

        idata = az.from_numpyro(
            posterior=model.mcmc,
            prior=prior_p_ss,
            posterior_predictive=post_p_fs,
            constant_data={"obs": obs},
        )

        # Forecast samples are None when forecasting was not requested, so
        # a CSV is only written when there is something to write; an
        # existing file is never overwritten.
        # NOTE(review): the file lands in the current working directory,
        # not in `output_directory` -- confirm that this is intended.
        save_path = f"{jurisdiction}_{args.reporting_date}_{forecast_days}_Ahead.csv"
        if post_p_fs is not None and not os.path.exists(save_path):
            df = pl.DataFrame(
                {k: v.__array__() for k, v in post_p_fs.items()}
            )
            df.write_csv(save_path)

        # NOTE(review): both plots are drawn only for fit-only runs;
        # confirm whether the HDI forecast plot is instead meant for
        # --forecast runs.
        if not args.forecast:
            pyrenew_flu_light.plot_lm_arviz_fit(idata)
            pyrenew_flu_light.plot_hdi_arviz_for(idata, forecast_days)


if __name__ == "__main__":
    # argparse settings
    # e.g. python3 run.py
    # --reporting_date 2024-01-20 --regions all --historical_data --forecast
    # python3 run.py --reporting_date 2024-01-20 --regions NY --historical_data --forecast
    parser = argparse.ArgumentParser(
        description="Forecast, simulate, and analyze the CFAEPIM model."
    )
    parser.add_argument(
        "--regions",
        type=process_jurisdictions,
        required=True,
        help="Specify jurisdictions as a comma-separated list. Use 'all' for all states, or 'not:state1,state2' to exclude specific states.",
    )
    parser.add_argument(
        "--reporting_date",
        type=str,
        required=True,
        help="The reporting date.",
    )
    parser.add_argument(
        "--historical_data",
        action="store_true",
        help="Use the historical cfaepim data and configuration for the reporting date.",
    )
    parser.add_argument(
        "--forecast",
        action="store_true",
        help="Whether to make a forecast.",
    )
    parser.add_argument(
        "--data_info_save",
        action="store_true",
        help="Whether to save information about the dataset.",
    )
    parser.add_argument(
        "--model_info_save",
        action="store_true",
        help="Whether to save information about the model.",
    )
    parser.add_argument(
        "--use_c",
        type=str,
        required=False,
        default="",
        help="Config path to external config.",
    )
    args = parser.parse_args()
    main(args)