diff --git a/README.md b/README.md index 5254ac0..74dfa47 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,7 @@ ⚠️ This is a work in progress. -_`pyrenew-flu-light` is an instantiation of an [Epidemia](https://imperialcollegelondon.github.io/epidemia/) influenza forecasting model in [PyRenew](https://github.com/CDCgov/PyRenew)_ - - +_`pyrenew-flu-light` is an instantiation of an [Epidemia](https://imperialcollegelondon.github.io/epidemia/) influenza forecasting model in [PyRenew](https://github.com/CDCgov/PyRenew)._ NOTE: Presently, this `pyrenew-flu-light` cannot be installed and used with current NHSN, as its author is validating it on historical influenza data, which is . diff --git a/assets/paste_bin.txt b/assets/paste_bin.txt index 1fbd2a1..50a9c8c 100644 --- a/assets/paste_bin.txt +++ b/assets/paste_bin.txt @@ -2,6 +2,45 @@ NOTES + +REMOVE plot and comparison functions for now + +# ax.set_title("Posterior Predictive Plot") + # ax.set_ylabel("Hospital Admissions") + # ax.set_xlabel("Days") + # plt.show() + + # prior_p_ss_figures_and_descriptions = plot_sample_variables( + # samples=prior_p_ss, + # variables=["Rts", "latent_infections", "negbinom_rv"], + # observations=obs, + # ylabels=[ + # "Basic Reproduction Number", + # "Latent Infections", + # "Hospital Admissions", + # ], + # plot_types=["TRACE", "PPC", "HDI"], + # plot_kwargs={ + # "HDI": {"hdi_prob": 0.95, "plot_kwargs": {"ls": "-."}}, + # "TRACE": {"var_names": ["Rts", "latent_infections"]}, + # "PPC": {"alpha": 0.05, "textsize": 12}, + # }, + # ) + + # print(prior_p_ss_figures_and_descriptions) + + # if args.forecasting: + + # prior_p_ss & post_p_ss get their own pdf (markdown first then subprocess) + # each variable is plotted out, if possible + # arviz diagnostics + + + + + + + seeding ("initialization" in MSR lingo): no renewal process, no need for a defined R(t) diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/inf.py b/src/model/inf.py new file mode 100644 index 0000000..c4a178d --- /dev/null +++ b/src/model/inf.py @@ -0,0 +1,165 @@ +import logging + +import jax.numpy as jnp +import numpy as np +import numpyro +import numpyro.distributions as dist +from jax.typing import ArrayLike +from pyrenew.latent import logistic_susceptibility_adjustment +from pyrenew.metaclass import RandomVariable + + +class CFAEPIM_Infections(RandomVariable): + """ + Class representing the infection process in + the CFAEPIM model. This class handles the sampling of + infection counts over time, considering the + reproduction number, generation interval, and population size, + while accounting for susceptibility depletion. + + Parameters + ---------- + I0 : ArrayLike + Initial infection counts. + susceptibility_prior : numpyro.distributions + Prior distribution for the susceptibility proportion + (S_{v-1} / P). + """ + + def __init__( + self, + I0: ArrayLike, + susceptibility_prior: numpyro.distributions, + ): # numpydoc ignore=GL08 + logging.info("Initializing CFAEPIM_Infections") + + self.I0 = I0 + self.susceptibility_prior = susceptibility_prior + + @staticmethod + def validate(I0: any, susceptibility_prior: any) -> None: + """ + Validate the parameters of the + infection process. Checks that the initial infections + (I0) and susceptibility_prior are + correctly specified. If any parameter is invalid, + an appropriate error is raised. + + Raises + ------ + TypeError + If I0 is not array-like or + susceptibility_prior is not + a numpyro distribution. + """ + logging.info("Validating CFAEPIM_Infections parameters") + if not isinstance(I0, (np.ndarray, jnp.ndarray)): + raise TypeError( + f"Initial infections (I0) must be an array-like structure; was type {type(I0)}" + ) + + if not isinstance(susceptibility_prior, dist.Distribution): + raise TypeError( + f"susceptibility_prior must be a numpyro distribution; was type {type(susceptibility_prior)}" + ) + + def sample( + self, Rt: ArrayLike, gen_int: ArrayLike, P: float, **kwargs + ) -> tuple: + """ + Given an array of reproduction numbers, + a generation interval, and the size of a + jurisdiction's population, + calculate infections under the scheme + of susceptible depletion. + + Parameters + ---------- + Rt : ArrayLike + Reproduction numbers over time; this is an array of + Rt values for each time step. + gen_int : ArrayLike + Generation interval probability mass function. This is + an array of probabilities representing the + distribution of times between successive infections + in a chain of transmission. + P : float + Population size. This is the total population + size used for susceptibility adjustment. + **kwargs : dict, optional + Additional keyword arguments passed through to internal + sample calls, should there be any. + + Returns + ------- + tuple + A tuple containing two arrays: all_I_t, an array of + latent infections at each time step and all_S_t, an + array of susceptible individuals at each time step. + + Raises + ------ + ValueError + If the length of the initial infections + vector (I0) is less than the length of + the generation interval. + """ + + # get initial infections + I0_samples = self.I0.sample() + I0 = I0_samples[0].value + + logging.debug(f"I0 samples: {I0}") + + # reverse generation interval (recency) + gen_int_rev = jnp.flip(gen_int) + + if I0.size < gen_int.size: + raise ValueError( + "Initial infections vector must be at least as long as " + "the generation interval. " + f"Initial infections vector length: {I0.size}, " + f"generation interval length: {gen_int.size}." + ) + recent_I0 = I0[-gen_int_rev.size :] + + # sample the initial susceptible population proportion S_{v-1} / P from prior + init_S_proportion = numpyro.sample( + "S_v_minus_1_over_P", self.susceptibility_prior + ) + logging.debug(f"Initial susceptible proportion: {init_S_proportion}") + + # calculate initial susceptible population S_{v-1} + init_S = init_S_proportion * P + + def update_infections(carry, Rt): # numpydoc ignore=GL08 + S_t, I_recent = carry + + # compute raw infections + i_raw_t = Rt * jnp.dot(I_recent, gen_int_rev) + + # apply the logistic susceptibility adjustment to a potential new incidence + i_t = logistic_susceptibility_adjustment( + I_raw_t=i_raw_t, frac_susceptible=S_t / P, n_population=P + ) + + # update susceptible population + S_t -= i_t + + # update infections + I_recent = jnp.concatenate([I_recent[:-1], jnp.array([i_t])]) + + return (S_t, I_recent), i_t + + # initial carry state + init_carry = (init_S, recent_I0) + + # scan to iterate over time steps and update infections + (all_S_t, _), all_I_t = numpyro.contrib.control_flow.scan( + update_infections, init_carry, Rt + ) + + logging.debug(f"All infections: {all_I_t}") + logging.debug(f"All susceptibles: {all_S_t}") + + return all_I_t, all_S_t diff --git a/src/model/model.py b/src/model/model.py new file mode 100644 index 0000000..13cda24 --- /dev/null +++ b/src/model/model.py @@ -0,0 +1,266 @@ +# import jax +# import jax.numpy as jnp +# import numpyro +# import numpyro.distributions as dist +# from jax.typing import ArrayLike +# from pyrenew.deterministic import DeterministicPMF +# from pyrenew.latent import ( +# InfectionInitializationProcess, +# InitializeInfectionsFromVec, +# ) +# from pyrenew.metaclass import ( +# DistributionalRV, +# Model, +# SampledValue, +# ) + + +# class CFAEPIM_Model_Sample(NamedTuple): # numpydoc ignore=GL08 +# Rts: SampledValue | None = None +# latent_infections: SampledValue | None = None +# susceptibles: SampledValue | None = None +# ascertainment_rates: SampledValue | None = None +# expected_hospitalizations: SampledValue | None = None + +# def __repr__(self): +# return ( +# f"CFAEPIM_Model_Sample(Rts={self.Rts}, " +# f"latent_infections={self.latent_infections}, " +# f"susceptibles={self.susceptibles}, " +# f"ascertainment_rates={self.ascertainment_rates}, " +# f"expected_hospitalizations={self.expected_hospitalizations}" +# ) + + +# class CFAEPIM_Model(Model): +# """ +# CFAEPIM Model class for epidemic inference, +# ported over from `cfaepim`. This class handles the +# initialization and sampling of the CFAEPIM model, +# including the transmission process, infection process, +# and observation process. + +# Parameters +# ---------- +# config : dict[str, any] +# Configuration dictionary containing model parameters. +# population : int +# Total population size. +# week_indices : ArrayLike +# Array of week indices corresponding to the time steps. +# first_week_hosp : int +# Number of hospitalizations in the first week. +# predictors : list[int] +# List of predictors (covariates) for the model. +# data_observed_hosp_admissions : pl.DataFrame +# DataFrame containing observed hospital admissions data. +# """ + +# def __init__( +# self, +# config: dict[str, any], +# population: int, +# week_indices: ArrayLike, +# first_week_hosp: int, +# predictors: list[int], +# ): # numpydoc ignore=GL08 +# self.population = population +# self.week_indices = week_indices +# self.first_week_hosp = first_week_hosp +# self.predictors = predictors + +# self.config = config +# for key, value in config.items(): +# setattr(self, key, value) + +# # transmission: generation time distribution +# self.pmf_array = jnp.array(self.generation_time_dist) +# self.gen_int = DeterministicPMF(name="gen_int", value=self.pmf_array) +# # update: record in sample ought to be False by default + +# # transmission: prior for RW intercept +# self.intercept_RW_prior = dist.Normal( +# self.rt_intercept_prior_mode, self.rt_intercept_prior_scale +# ) + +# # transmission: Rt process +# self.Rt_process = CFAEPIM_Rt( +# intercept_RW_prior=self.intercept_RW_prior, +# max_rt=self.max_rt, +# gamma_RW_prior_scale=self.weekly_rw_prior_scale, +# week_indices=self.week_indices, +# ) + +# # infections: get value rate for infection seeding (initialization) +# self.mean_inf_val = ( +# self.inf_model_prior_infections_per_capita * self.population +# ) + (self.first_week_hosp / (self.ihr_intercept_prior_mode * 7)) + +# # infections: initial infections +# self.I0 = InfectionInitializationProcess( +# name="I0_initialization", +# I_pre_init_rv=DistributionalRV( +# name="I0", +# dist=dist.Exponential(rate=1 / self.mean_inf_val).expand( +# [self.inf_model_seed_days] +# ), +# ), +# infection_init_method=InitializeInfectionsFromVec( +# n_timepoints=self.inf_model_seed_days +# ), +# t_unit=1, +# ) + +# # infections: susceptibility depletion prior +# # update: truncated Normal needed here, done +# # "under the hood" in Epidemia, use Beta for the +# # time being. +# # self.susceptibility_prior = dist.Beta( +# # 1 +# # + ( +# # self.susceptible_fraction_prior_mode +# # / self.susceptible_fraction_prior_scale +# # ), +# # 1 +# # + (1 - self.susceptible_fraction_prior_mode) +# # / self.susceptible_fraction_prior_scale, +# # ) +# # now: +# self.susceptibility_prior = dist.TruncatedNormal( +# self.susceptible_fraction_prior_mode, +# self.susceptible_fraction_prior_scale, +# low=0.0, +# ) + +# # infections component +# self.infections = CFAEPIM_Infections( +# I0=self.I0, susceptibility_prior=self.susceptibility_prior +# ) + +# # observations: negative binomial concentration prior +# self.nb_concentration_prior = dist.Normal( +# self.reciprocal_dispersion_prior_mode, +# self.reciprocal_dispersion_prior_scale, +# ) + +# # observations: instantaneous ascertainment rate prior +# self.alpha_prior_dist = dist.Normal( +# self.ihr_intercept_prior_mode, self.ihr_intercept_prior_scale +# ) + +# # observations: prior on covariate coefficients +# self.coefficient_priors = dist.Normal( +# loc=jnp.array( +# self.day_of_week_effect_prior_modes +# + [ +# self.holiday_eff_prior_mode, +# self.post_holiday_eff_prior_mode, +# self.non_obs_effect_prior_mode, +# ] +# ), +# scale=jnp.array( +# self.day_of_week_effect_prior_scales +# + [ +# self.holiday_eff_prior_scale, +# self.post_holiday_eff_prior_scale, +# self.non_obs_effect_prior_scale, +# ] +# ), +# ) + +# # observations component +# self.obs_process = CFAEPIM_Observation( +# predictors=self.predictors, +# alpha_prior_dist=self.alpha_prior_dist, +# coefficient_priors=self.coefficient_priors, +# nb_concentration_prior=self.nb_concentration_prior, +# ) + +# @staticmethod +# def validate( +# population: any, +# week_indices: any, +# first_week_hosp: any, +# predictors: any, +# ) -> None: +# """ +# Validate the parameters of the CFAEPIM model. + +# This method checks that all necessary parameters and priors are correctly specified. +# If any parameter is invalid, an appropriate error is raised. + +# Raises +# ------ +# ValueError +# If any parameter is missing or invalid. +# """ +# if not isinstance(population, int) or population <= 0: +# raise ValueError("Population must be a positive integer.") +# if not isinstance(week_indices, jax.ndarray): +# raise ValueError("Week indices must be an array-like structure.") +# if not isinstance(first_week_hosp, int) or first_week_hosp < 0: +# raise ValueError( +# "First week hospitalizations must be a non-negative integer." +# ) +# if not isinstance(predictors, jnp.ndarray): +# raise ValueError("Predictors must be a list of integers.") + +# def sample( +# self, +# n_steps: int, +# data_observed_hosp_admissions: ArrayLike = None, +# **kwargs, +# ) -> tuple: +# # shift towards "reduced statefulness", include here week indices & +# # predictors which might change; for the same model and different +# # models. +# """ +# Samples the reproduction numbers, generation interval, +# infections, and hospitalizations from the CFAEPIM model. + +# Parameters +# ---------- +# n_steps : int +# Number of time steps to sample. +# data_observed_hosp_admissions : ArrayLike, optional +# Observation hospital admissions. +# Defaults to None. +# **kwargs : dict, optional +# Additional keyword arguments passed through to +# internal sample calls, should there be any. + +# Returns +# ------- +# CFAEPIM_Model_Sample +# A named tuple containing sampled values for reproduction numbers, +# latent infections, susceptibles, ascertainment rates, expected +# hospitalizations, and observed hospital admissions. +# """ +# sampled_Rts = self.Rt_process.sample(n_steps=n_steps) +# sampled_gen_int = self.gen_int.sample(record=False) +# all_I_t, all_S_t = self.infections.sample( +# Rt=sampled_Rts, +# gen_int=sampled_gen_int[0].value, +# P=self.population, +# ) +# sampled_alphas, expected_hosps = self.obs_process.sample( +# infections=all_I_t, +# inf_to_hosp_dist=jnp.array(self.inf_to_hosp_dist), +# ) +# # observed_hosp_admissions = self.obs_process.nb_observation.sample( +# # mu=expected_hosps, +# # obs=data_observed_hosp_admissions, +# # **kwargs, +# # ) +# numpyro.deterministic("Rts", sampled_Rts) +# numpyro.deterministic("latent_infections", all_I_t) +# numpyro.deterministic("susceptibles", all_S_t) +# numpyro.deterministic("alphas", sampled_alphas) +# numpyro.deterministic("expected_hospitalizations", expected_hosps) +# return CFAEPIM_Model_Sample( +# Rts=sampled_Rts, +# latent_infections=all_I_t, +# susceptibles=all_S_t, +# ascertainment_rates=sampled_alphas, +# expected_hospitalizations=expected_hosps, +# ) diff --git a/src/model/obs.py b/src/model/obs.py new file mode 100644 index 0000000..257cc39 --- /dev/null +++ b/src/model/obs.py @@ -0,0 +1,163 @@ +import logging + +import jax.numpy as jnp +import numpy as np +import numpyro.distributions as dist +import pyrenew.transformation as t +from jax.typing import ArrayLike +from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.observation import NegativeBinomialObservation +from pyrenew.regression import GLMPrediction + + +class CFAEPIM_Observation(RandomVariable): + """ + Class representing the observation process + in the CFAEPIM model. This class handles the generation + of the alpha (instantaneous ascertaintment rate) process + and the negative binomial observation process for + modeling hospitalizations from latent infections. + + Parameters + ---------- + predictors : ArrayLike + Array of predictor (covariates) values for the alpha process. + alpha_prior_dist : numpyro.distributions + Prior distribution for the intercept in the alpha process. + coefficient_priors : numpyro.distributions + Prior distributions for the coefficients in the alpha process. + nb_concentration_prior : numpyro.distributions + Prior distribution for the concentration parameter of + the negative binomial distribution. + """ + + def __init__( + self, + predictors, + alpha_prior_dist, + coefficient_priors, + nb_concentration_prior, + ): # numpydoc ignore=GL08 + logging.info("Initializing CFAEPIM_Observation") + + CFAEPIM_Observation.validate( + predictors, + alpha_prior_dist, + coefficient_priors, + nb_concentration_prior, + ) + + self.predictors = predictors + self.alpha_prior_dist = alpha_prior_dist + self.coefficient_priors = coefficient_priors + self.nb_concentration_prior = nb_concentration_prior + + self._init_alpha_t() + self._init_negative_binomial() + + def _init_alpha_t(self): + """ + Initialize the alpha process using a generalized + linear model (GLM) (transformed linear predictor). + The transform is set to the inverse of the sigmoid + transformation. + """ + logging.info("Initializing alpha process") + self.alpha_process = GLMPrediction( + name="alpha_t", + fixed_predictor_values=self.predictors, + intercept_prior=self.alpha_prior_dist, + coefficient_priors=self.coefficient_priors, + transform=t.SigmoidTransform().inv, + ) + + def _init_negative_binomial(self): + """ + Sets up the negative binomial + distribution for modeling hospitalizations + with a prior on the concentration parameter. + """ + logging.info("Initializing negative binomial process") + self.nb_observation = NegativeBinomialObservation( + name="negbinom_rv", + concentration_rv=DistributionalRV( + name="nb_concentration", + dist=self.nb_concentration_prior, + ), + ) + + @staticmethod + def validate( + predictors: any, + alpha_prior_dist: any, + coefficient_priors: any, + nb_concentration_prior: any, + ) -> None: + """ + Validate the parameters of the CFAEPIM observation process. Checks that + the predictors, alpha prior distribution, coefficient priors, and negative + binomial concentration prior are correctly specified. If any parameter + is invalid, an appropriate error is raised. + """ + logging.info("Validating CFAEPIM_Observation parameters") + if not isinstance(predictors, (np.ndarray, jnp.ndarray)): + raise TypeError( + f"Predictors must be an array-like structure; was type {type(predictors)}" + ) + if not isinstance(alpha_prior_dist, dist.Distribution): + raise TypeError( + f"alpha_prior_dist must be a numpyro distribution; was type {type(alpha_prior_dist)}" + ) + if not isinstance(coefficient_priors, dist.Distribution): + raise TypeError( + f"coefficient_priors must be a numpyro distribution; was type {type(coefficient_priors)}" + ) + if not isinstance(nb_concentration_prior, dist.Distribution): + raise TypeError( + f"nb_concentration_prior must be a numpyro distribution; was type {type(nb_concentration_prior)}" + ) + + def sample( + self, + infections: ArrayLike, + inf_to_hosp_dist: ArrayLike, + **kwargs, + ) -> tuple: + """ + Sample from the observation process. Generates samples + from the alpha process and calculates the expected number + of hospitalizations by convolving the infections with + the infection-to-hospitalization (delay distribution) + distribution. It then samples from the negative binomial + distribution to model the observed + hospitalizations. + + Parameters + ---------- + infections : ArrayLike + Array of infection counts over time. + inf_to_hosp_dist : ArrayLike + Array representing the distribution of times + from infection to hospitalization. + **kwargs : dict, optional + Additional keyword arguments passed through + to internal sample calls, should there be any. + + Returns + ------- + tuple + A tuple containing the sampled instantaneous + ascertainment values and the expected + hospitalizations. + """ + alpha_samples = self.alpha_process.sample()["prediction"] + alpha_samples = alpha_samples[: infections.shape[0]] + expected_hosp = ( + alpha_samples + * jnp.convolve(infections, inf_to_hosp_dist, mode="full")[ + : infections.shape[0] + ] + ) + logging.debug(f"Alpha samples: {alpha_samples}") + logging.debug(f"Expected hospitalizations: {expected_hosp}") + return alpha_samples, expected_hosp diff --git a/src/model/rt.py b/src/model/rt.py new file mode 100644 index 0000000..90d5fec --- /dev/null +++ b/src/model/rt.py @@ -0,0 +1,130 @@ +import logging + +import jax.numpy as jnp +import numpy as np +import numpyro +import numpyro.distributions as dist +import pyrenew.transformation as t +from jax.typing import ArrayLike +from numpyro.infer.reparam import LocScaleReparam +from pyrenew.metaclass import ( + DistributionalRV, + RandomVariable, + TransformedRandomVariable, +) +from pyrenew.process import SimpleRandomWalkProcess + + +class CFAEPIM_Rt(RandomVariable): # numpydoc ignore=GL08 + def __init__( + self, + intercept_RW_prior: numpyro.distributions, + max_rt: float, + gamma_RW_prior_scale: float, + week_indices: ArrayLike, + ): # numpydoc ignore=GL08 + """ + Initialize the CFAEPIM_Rt class. + + Parameters + ---------- + intercept_RW_prior : numpyro.distributions.Distribution + Prior distribution for the random walk intercept. + max_rt : float + Maximum value of the reproduction number. Used as + the scale in the `ScaledLogitTransform()`. + gamma_RW_prior_scale : float + Scale parameter for the HalfNormal distribution + used for random walk standard deviation. + week_indices : ArrayLike + Array of week indices used for broadcasting + the Rt values. + """ + logging.info("Initializing CFAEPIM_Rt") + self.intercept_RW_prior = intercept_RW_prior + self.max_rt = max_rt + self.gamma_RW_prior_scale = gamma_RW_prior_scale + self.week_indices = week_indices + + @staticmethod + def validate( + intercept_RW_prior: any, + max_rt: any, + gamma_RW_prior_scale: any, + week_indices: any, + ) -> None: # numpydoc ignore=GL08 + """ + Validate the parameters of the CFAEPIM_Rt class. + + Raises + ------ + ValueError + If any of the parameters are not valid. + """ + logging.info("Validating CFAEPIM_Rt parameters") + if not isinstance(intercept_RW_prior, dist.Distribution): + raise ValueError( + f"intercept_RW_prior must be a numpyro distribution; was type {type(intercept_RW_prior)}" + ) + if not isinstance(max_rt, (float, int)) or max_rt <= 0: + raise ValueError( + f"max_rt must be a positive number; was type {type(max_rt)}" + ) + if ( + not isinstance(gamma_RW_prior_scale, (float, int)) + or gamma_RW_prior_scale <= 0 + ): + raise ValueError( + f"gamma_RW_prior_scale must be a positive number; was type {type(gamma_RW_prior_scale)}" + ) + if not isinstance(week_indices, (np.ndarray, jnp.ndarray)): + raise ValueError( + f"week_indices must be an array-like structure; was type {type(week_indices)}" + ) + + def sample(self, n_steps: int, **kwargs) -> tuple: # numpydoc ignore=GL08 + """ + Sample the Rt values using a random walk process + and broadcast them to daily values. + + Parameters + ---------- + n_steps : int + Number of time steps to sample. + **kwargs : dict, optional + Additional keyword arguments passed through to internal sample calls. + + Returns + ------- + ArrayLike + An array containing the broadcasted Rt values. + """ + # sample the standard deviation for the random walk process + sd_wt = numpyro.sample( + "Wt_rw_sd", dist.HalfNormal(self.gamma_RW_prior_scale) + ) + # Rt random walk process + wt_rv = SimpleRandomWalkProcess( + name="Wt", + step_rv=DistributionalRV( + name="rw_step_rv", + dist=dist.Normal(0, sd_wt), + reparam=LocScaleReparam(0), + ), + init_rv=DistributionalRV( + name="init_Wt_rv", + dist=self.intercept_RW_prior, + ), + ) + # transform Rt random walk w/ scaled logit + transformed_rt_samples = TransformedRandomVariable( + name="transformed_rt_rw", + base_rv=wt_rv, + transforms=t.ScaledLogitTransform(x_max=self.max_rt).inv, + ).sample(n_steps=n_steps, **kwargs) + # broadcast the Rt samples to daily values + broadcasted_rt_samples = transformed_rt_samples[0].value[ + self.week_indices + ] + logging.debug(f"Broadcasted Rt samples: {broadcasted_rt_samples}") + return broadcasted_rt_samples diff --git a/src/model/tut_epim_port_msr.py b/src/model/tut_epim_port_msr.py index 0794690..0747eee 100644 --- a/src/model/tut_epim_port_msr.py +++ b/src/model/tut_epim_port_msr.py @@ -298,7 +298,7 @@ def ensure_output_directory(args: dict[str, any]): # numpydoc ignore=GL08 def assert_historical_data_files_exist( reporting_date: str, ): # numpydoc ignore=GL08 - data_directory = f"./data/{reporting_date}/" + data_directory = f"../../data/{reporting_date}/" assert os.path.exists( data_directory ), f"Data directory {data_directory} does not exist." @@ -1048,22 +1048,22 @@ def __init__( # update: truncated Normal needed here, done # "under the hood" in Epidemia, use Beta for the # time being. - # self.susceptibility_prior = dist.Beta( - # 1 - # + ( - # self.susceptible_fraction_prior_mode - # / self.susceptible_fraction_prior_scale - # ), - # 1 - # + (1 - self.susceptible_fraction_prior_mode) - # / self.susceptible_fraction_prior_scale, - # ) - # now: - self.susceptibility_prior = dist.TruncatedNormal( - self.susceptible_fraction_prior_mode, - self.susceptible_fraction_prior_scale, - low=0.0, + self.susceptibility_prior = dist.Beta( + 1 + + ( + self.susceptible_fraction_prior_mode + / self.susceptible_fraction_prior_scale + ), + 1 + + (1 - self.susceptible_fraction_prior_mode) + / self.susceptible_fraction_prior_scale, ) + # now: + # self.susceptibility_prior = dist.TruncatedNormal( + # self.susceptible_fraction_prior_mode, + # self.susceptible_fraction_prior_scale, + # low=0.0, + # ) # infections component self.infections = CFAEPIM_Infections( @@ -1698,46 +1698,6 @@ def plot_hdi_arviz_for(idata, forecast_days): # numpydoc ignore=GL08 plt.show() -# def quantilize_forecasts( -# samples_dict, -# state_abbr, -# start_date, -# end_date, -# fitting_data, -# output_path, -# reference_date, -# ): # numpydoc ignore=GL08 -# pandas2ri.activate() -# forecasttools = importr("forecasttools") -# # dplyr = importr("dplyr") -# # tidyr = importr("tidyr") -# # cli = importr("cli") - -# posterior_samples = pl.DataFrame(samples_dict) -# posterior_samples_pd = posterior_samples.to_pandas() -# r_posterior_samples = pandas2ri.py2rpy(posterior_samples_pd) - -# fitting_data_pd = fitting_data.to_pandas() -# r_fitting_data = pandas2ri.py2rpy(fitting_data_pd) - -# results_list = ro.ListVector({state_abbr: r_posterior_samples}) - -# horizons = ro.IntVector([-1, 0, 1, 2, 3]) - -# forecast_output = forecasttools.forecast_and_output_flusight( -# data=r_fitting_data, -# results=results_list, -# output_path=output_path, -# reference_date=reference_date, -# horizons=horizons, -# seed=62352, -# ) - -# forecast_output_pd = pandas2ri.rpy2py(forecast_output) -# forecast_output_pl = pl.from_pandas(forecast_output_pd) -# print(forecast_output_pl) - - def main(args): # numpydoc ignore=GL08 """ The `cfaepim` model required a configuration @@ -1784,7 +1744,7 @@ def main(args): # numpydoc ignore=GL08 # load historical configuration file (modified from cfaepim) config = load_config( - config_path=f"./config/params_{args.reporting_date}_historical.toml" + config_path=f"../../config/params_{args.reporting_date}_historical.toml" ) logging.info("Configuration (historical) loaded.") @@ -1857,7 +1817,7 @@ def main(args): # numpydoc ignore=GL08 constant_data={"obs": obs}, ) print(dir(idata)) - # plot_lm_arviz_fit(idata) + plot_lm_arviz_fit(idata) plot_hdi_arviz_for(idata, forecast_days) # save to folder for jurisdiction, @@ -1869,36 +1829,6 @@ def main(args): # numpydoc ignore=GL08 # ) # print(diagnostic_stats_summary.loc["negbinom_rv"]) - # ax.set_title("Posterior Predictive Plot") - # ax.set_ylabel("Hospital Admissions") - # ax.set_xlabel("Days") - # plt.show() - - # prior_p_ss_figures_and_descriptions = plot_sample_variables( - # samples=prior_p_ss, - # variables=["Rts", "latent_infections", "negbinom_rv"], - # observations=obs, - # ylabels=[ - # "Basic Reproduction Number", - # "Latent Infections", - # "Hospital Admissions", - # ], - # plot_types=["TRACE", "PPC", "HDI"], - # plot_kwargs={ - # "HDI": {"hdi_prob": 0.95, "plot_kwargs": {"ls": "-."}}, - # "TRACE": {"var_names": ["Rts", "latent_infections"]}, - # "PPC": {"alpha": 0.05, "textsize": 12}, - # }, - # ) - - # print(prior_p_ss_figures_and_descriptions) - - # if args.forecasting: - - # prior_p_ss & post_p_ss get their own pdf (markdown first then subprocess) - # each variable is plotted out, if possible - # arviz diagnostics - if __name__ == "__main__": # argparse settings @@ -1941,28 +1871,3 @@ def main(args): # numpydoc ignore=GL08 ) args = parser.parse_args() main(args) - -# TODO -# argparse -# turn off reports -# report(s) generation -# plotting generation -# generalized plotting -# forecasttools formatting -# MCMC utils (numpyro) -# issues x 3 (plotting, inv(), infections docs.) -# tests -# Save MCMC + Samples -# Save to Image as Metadata -# forecast scoring & interpretation -# in reports -# probabilistic statements on what to expect -# relative to historical (several weeks prior + last year) -# evaluate configuration file -# tutorial on usage -# writing again -# notes about what each function must know -# plot objects include: latent infections, observed hospital -# admissions, Rt; plot types includes: dist., density, posterior, -# density comparison, pair plot, posterior predictive check plot, -# HDI plot, diff --git a/src/run.py b/src/run.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/checks.py b/src/utils/checks.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/compare.py b/src/utils/compare.py new file mode 100644 index 0000000..2de8cd3 --- /dev/null +++ b/src/utils/compare.py @@ -0,0 +1,44 @@ +import polars as pl +import rpy2.robjects as ro +from rpy2.robjects import pandas2ri +from rpy2.robjects.packages import importr + + +def quantilize_forecasts( + samples_dict, + state_abbr, + start_date, + end_date, + fitting_data, + output_path, + reference_date, +): # numpydoc ignore=GL08 + pandas2ri.activate() + forecasttools = importr("forecasttools") + # dplyr = importr("dplyr") + # tidyr = importr("tidyr") + # cli = importr("cli") + + posterior_samples = pl.DataFrame(samples_dict) + posterior_samples_pd = posterior_samples.to_pandas() + r_posterior_samples = pandas2ri.py2rpy(posterior_samples_pd) + + fitting_data_pd = fitting_data.to_pandas() + r_fitting_data = pandas2ri.py2rpy(fitting_data_pd) + + results_list = ro.ListVector({state_abbr: r_posterior_samples}) + + horizons = ro.IntVector([-1, 0, 1, 2, 3]) + + forecast_output = forecasttools.forecast_and_output_flusight( + data=r_fitting_data, + results=results_list, + output_path=output_path, + reference_date=reference_date, + horizons=horizons, + seed=62352, + ) + + forecast_output_pd = pandas2ri.rpy2py(forecast_output) + forecast_output_pl = pl.from_pandas(forecast_output_pd) + print(forecast_output_pl) diff --git a/src/utils/pad.py b/src/utils/pad.py new file mode 100644 index 0000000..158c502 --- /dev/null +++ b/src/utils/pad.py @@ -0,0 +1,165 @@ +from datetime import datetime, timedelta + +import polars as pl + +HOLIDAYS = ["2023-11-23", "2023-12-25", "2023-12-31", "2024-01-01"] + + +def add_post_observation_period( + dataset: pl.DataFrame, n_post_observation_days: int +) -> pl.DataFrame: # numpydoc ignore=RT01 + """ + Receives a dataframe that is filtered down to a + particular jurisdiction, that has pre-observation + data, and adds new rows to the end of the dataframe + for the post-observation (forecasting) period. + """ + + # calculate the dates from the latest date in the dataframe + max_date = dataset["date"].max() + post_observation_dates = [ + (max_date + timedelta(days=i)) + for i in range(1, n_post_observation_days + 1) + ] + + # get the days of the week (e.g. Fri) from the calculated dates + day_of_weeks = ( + pl.Series(post_observation_dates) + .dt.strftime("%a") + .alias("day_of_week") + ) + weekends = day_of_weeks.is_in(["Sat", "Sun"]) + + # calculate the epiweeks and epiyears, which might not evenly mod 7 + last_epiweek = dataset["epiweek"][-1] + epiweek_counts = dataset.filter(pl.col("epiweek") == last_epiweek).shape[0] + epiweeks = [last_epiweek] * (7 - epiweek_counts) + [ + (last_epiweek + 1 + (i // 7)) + for i in range(n_post_observation_days - (7 - epiweek_counts)) + ] + last_epiyear = dataset["epiyear"][-1] + epiyears = [ + last_epiyear if epiweek <= 52 else last_epiyear + 1 + for epiweek in epiweeks + ] + epiweeks = [ + epiweek if epiweek <= 52 else epiweek - 52 for epiweek in epiweeks + ] + + # calculate week values + last_week = dataset["week"][-1] + week_counts = dataset.filter(pl.col("week") == last_week).shape[0] + weeks = [last_week] * (7 - week_counts) + [ + (last_week + 1 + (i // 7)) + for i in range(n_post_observation_days - (7 - week_counts)) + ] + weeks = [week if week <= 52 else week - 52 for week in weeks] + + # calculate holiday series + holidays = [datetime.strptime(elt, "%Y-%m-%d") for elt in HOLIDAYS] + holidays_values = [date in holidays for date in post_observation_dates] + post_holidays = [holiday + timedelta(days=1) for holiday in holidays] + post_holiday_values = [ + date in post_holidays for date in post_observation_dates + ] + + # fill in post-observation data entries, zero hospitalizations + post_observation_data = pl.DataFrame( + { + "location": [dataset["location"][0]] * n_post_observation_days, + "date": post_observation_dates, + "hosp": [-9999] * n_post_observation_days, # possible + "epiweek": epiweeks, + "epiyear": epiyears, + "day_of_week": day_of_weeks, + "is_weekend": weekends, + "is_holiday": holidays_values, + "is_post_holiday": post_holiday_values, + "recency": [0] * n_post_observation_days, + "week": weeks, + "location_code": [dataset["location_code"][0]] + * n_post_observation_days, + "population": [dataset["population"][0]] * n_post_observation_days, + "first_week_hosp": [dataset["first_week_hosp"][0]] + * n_post_observation_days, + "nonobservation_period": [False] * n_post_observation_days, + } + ) + + # stack post_observation_data ONTO dataset + merged_data = dataset.vstack(post_observation_data) + return merged_data + + +def add_pre_observation_period( + dataset: pl.DataFrame, n_pre_observation_days: int +) -> pl.DataFrame: # numpydoc ignore=RT01 + """ + Receives a dataframe that is filtered down to a + particular jurisdiction and adds new rows to the + beginning of the dataframe for the non-observation + period. + """ + + # create new nonobs column, set to False by default + dataset = dataset.with_columns( + pl.lit(False).alias("nonobservation_period") + ) + + # backcalculate the dates from the earliest date in the dataframe + min_date = dataset["date"].min() + pre_observation_dates = [ + (min_date - timedelta(days=i)) + for i in range(1, n_pre_observation_days + 1) + ] + pre_observation_dates.reverse() + + # get the days of the week (e.g. Fri) from the backcalculated dates + day_of_weeks = ( + pl.Series(pre_observation_dates).dt.strftime("%a").alias("day_of_week") + ) + weekends = day_of_weeks.is_in(["Sat", "Sun"]) + + # backculate the epiweeks, which might not evenly mod 7 + first_epiweek = dataset["epiweek"][0] + counts = dataset.filter(pl.col("epiweek") == first_epiweek).shape[0] + epiweeks = [first_epiweek] * (7 - counts) + [ + (first_epiweek - 1 - (i // 7)) + for i in range(n_pre_observation_days - (7 - counts)) + ] + epiweeks.reverse() + + # calculate holiday series + holidays = [datetime.strptime(elt, "%Y-%m-%d") for elt in HOLIDAYS] + holidays_values = [date in holidays for date in pre_observation_dates] + post_holidays = [holiday + timedelta(days=1) for holiday in holidays] + post_holiday_values = [ + date in post_holidays for date in pre_observation_dates + ] + + # fill in pre-observation data entries, zero hospitalizations + pre_observation_data = pl.DataFrame( + { + "location": [dataset["location"][0]] * n_pre_observation_days, + "date": pre_observation_dates, + "hosp": [0] * n_pre_observation_days, + "epiweek": epiweeks, + "epiyear": [dataset["epiyear"][0]] * n_pre_observation_days, + "day_of_week": day_of_weeks, + "is_weekend": weekends, + "is_holiday": holidays_values, + "is_post_holiday": post_holiday_values, + "recency": [0] * n_pre_observation_days, + "week": [dataset["week"][0]] * n_pre_observation_days, + "location_code": [dataset["location_code"][0]] + * n_pre_observation_days, + "population": [dataset["population"][0]] * n_pre_observation_days, + "first_week_hosp": [dataset["first_week_hosp"][0]] + * n_pre_observation_days, + "nonobservation_period": [True] * n_pre_observation_days, + } + ) + + # stack dataset ONTO pre_observation_data + merged_data = pre_observation_data.vstack(dataset) + return merged_data diff --git a/src/utils/plot_data.py b/src/utils/plot_data.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/plot_for.py b/src/utils/plot_for.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/plot_mcmc.py b/src/utils/plot_mcmc.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/reports.py b/src/utils/reports.py new file mode 100644 index 0000000..e69de29