Skip to content

Commit

Permalink
packaged structure from single file; begin simple eval. script for R
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 committed Aug 29, 2024
1 parent 0e883d6 commit 7614aa3
Show file tree
Hide file tree
Showing 8 changed files with 673 additions and 534 deletions.
14 changes: 14 additions & 0 deletions notebooks/evaluate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Start of simple eval. script.

if (!requireNamespace("remotes", quietly = TRUE)) {
install.packages("remotes")
}
remotes::install_local("../../cfa-forecasttools")

library(cfaforecasttools)
library(tibble)
library(dplyr)
library(ggplot2)


forecast_data <- read.csv("samples.csv")
83 changes: 75 additions & 8 deletions pyrenew_flu_light/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,73 @@
from pyrenew_flu_light.comp_inf import CFAEPIM_Infections
from pyrenew_flu_light.comp_obs import CFAEPIM_Observation
from pyrenew_flu_light.comp_tran import CFAEPIM_Rt
from pyrenew_flu_light.model import CFAEPIM_Model
from pyrenew_flu_light.pad import (
add_post_observation_period,
add_pre_observation_period,
from checks import (
assert_historical_data_files_exist,
check_file_path_valid,
ensure_output_directory,
load_config,
)
from pyrenew_flu_light.plot import plot_hdi_arviz_for, plot_lm_arviz_fit
from comp_inf import CFAEPIM_Infections
from comp_obs import CFAEPIM_Observation
from comp_tran import CFAEPIM_Rt
from model import CFAEPIM_Model
from pad import add_post_observation_period, add_pre_observation_period
from plot import plot_hdi_arviz_for, plot_lm_arviz_fit
from pre import load_data

JURISDICTIONS = [
"AK",
"AL",
"AR",
"AZ",
"CA",
"CO",
"CT",
"DC",
"DE",
"FL",
"GA",
"HI",
"IA",
"ID",
"IL",
"IN",
"KS",
"KY",
"LA",
"MA",
"MD",
"ME",
"MI",
"MN",
"MO",
"MS",
"MT",
"NC",
"ND",
"NE",
"NH",
"NJ",
"NM",
"NV",
"NY",
"OH",
"OK",
"OR",
"PA",
"PR",
"RI",
"SC",
"SD",
"TN",
"TX",
"US",
"UT",
"VA",
"VI",
"VT",
"WA",
"WI",
"WV",
"WY",
]

__all__ = [
"CFAEPIM_Infections",
Expand All @@ -17,4 +78,10 @@
"add_pre_observation_period",
"plot_hdi_arviz_for",
"plot_lm_arviz_fit",
"JURISDICTIONS",
"assert_historical_data_files_exist",
"check_file_path_valid",
"ensure_output_directory",
"load_config",
"load_data",
]
115 changes: 19 additions & 96 deletions pyrenew_flu_light/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,64 +5,9 @@

import os

import polar as pl
import toml


def display_data(
data: pl.DataFrame,
n_row_count: int = 15,
n_col_count: int = 5,
first_only: bool = False,
last_only: bool = False,
) -> None:
"""
Display the columns and rows of
a polars dataframe.
Parameters
----------
data : pl.DataFrame
A polars dataframe.
n_row_count : int, optional
How many rows to print.
Defaults to 15.
n_col_count : int, optional
How many columns to print.
Defaults to 15.
first_only : bool, optional
If True, only display the first `n_row_count` rows. Defaults to False.
last_only : bool, optional
If True, only display the last `n_row_count` rows. Defaults to False.
Returns
-------
None
Displays data.
"""
rows, cols = data.shape
assert (
1 <= n_col_count <= cols
), f"Must have reasonable column count; was type {n_col_count}"
assert (
1 <= n_row_count <= rows
), f"Must have reasonable row count; was type {n_row_count}"
assert (
first_only + last_only
) != 2, "Can only do one of last or first only."
if first_only:
data_to_display = data.head(n_row_count)
elif last_only:
data_to_display = data.tail(n_row_count)
else:
data_to_display = data.head(n_row_count)
pl.Config.set_tbl_hide_dataframe_shape(True)
pl.Config.set_tbl_formatting("ASCII_MARKDOWN")
pl.Config.set_tbl_hide_column_data_types(True)
with pl.Config(tbl_rows=n_row_count, tbl_cols=n_col_count):
print(f"Dataset In Use For `cfaepim`:\n{data_to_display}\n")


def check_file_path_valid(file_path: str) -> None:
"""
Checks if a file path is valid. Used to check
Expand All @@ -85,47 +30,6 @@ def check_file_path_valid(file_path: str) -> None:
return None


def load_data(
data_path: str,
sep: str = "\t",
schema_length: int = 10000,
) -> pl.DataFrame:
"""
Loads historical (i.e., `.tsv` data generated
`cfaepim` for a weekly run) data.
Parameters
----------
data_path : str
The path to the tsv file to be read.
sep : str, optional
The separator between values in the
data file. Defaults to tab-separated.
schema_length : int, optional
An approximation of the expected
maximum number of rows. Defaults
to 10000.
Returns
-------
pl.DataFrame
An unvetted polars dataframe of NHSN
hospitalization data.
"""
check_file_path_valid(file_path=data_path)
assert sep in [
"\t",
",",
], f"Separator must be tabs or commas; was type {sep}"
assert (
7500 <= schema_length <= 25000
), f"Schema length must be reasonable; was type {schema_length}"
data = pl.read_csv(
data_path, separator=sep, infer_schema_length=schema_length
)
return data


def load_config(config_path: str) -> dict[str, any]:
"""
Attempts to load config toml file.
Expand Down Expand Up @@ -169,3 +73,22 @@ def ensure_output_directory(args: dict[str, any]): # numpydoc ignore=GL08
if not os.path.exists(output_directory):
os.makedirs(output_directory)
return output_directory


def assert_historical_data_files_exist(
reporting_date: str,
): # numpydoc ignore=GL08
data_directory = f"../data/{reporting_date}/"
assert os.path.exists(
data_directory
), f"Data directory {data_directory} does not exist."
required_files = [
f"{reporting_date}_clean_data.tsv",
f"{reporting_date}_config.toml",
f"{reporting_date}-cfarenewal-cfaepimlight.csv",
]
for file in required_files:
assert os.path.exists(
os.path.join(data_directory, file)
), f"Required file {file} does not exist in {data_directory}."
return data_directory
29 changes: 15 additions & 14 deletions pyrenew_flu_light/comp_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist
import pyrenew.regression as r
import pyrenew.transformation as t
from jax.typing import ArrayLike
from pyrenew.metaclass import DistributionalRV, RandomVariable
from pyrenew.metaclass import RandomVariable
from pyrenew.observation import NegativeBinomialObservation
from pyrenew.regression import GLMPrediction
from pyrenew.randomvariable import DistributionalVariable


class CFAEPIM_Observation(RandomVariable):
Expand Down Expand Up @@ -44,12 +45,12 @@ def __init__(
): # numpydoc ignore=GL08
logging.info("Initializing CFAEPIM_Observation")

CFAEPIM_Observation.validate(
predictors,
alpha_prior_dist,
coefficient_priors,
nb_concentration_prior,
)
# CFAEPIM_Observation.validate(
# predictors,
# alpha_prior_dist,
# coefficient_priors,
# nb_concentration_prior,
# )

self.predictors = predictors
self.alpha_prior_dist = alpha_prior_dist
Expand All @@ -67,9 +68,8 @@ def _init_alpha_t(self):
transformation.
"""
logging.info("Initializing alpha process")
self.alpha_process = GLMPrediction(
self.alpha_process = r.GLMPrediction(
name="alpha_t",
fixed_predictor_values=self.predictors,
intercept_prior=self.alpha_prior_dist,
coefficient_priors=self.coefficient_priors,
transform=t.SigmoidTransform().inv,
Expand All @@ -84,9 +84,9 @@ def _init_negative_binomial(self):
logging.info("Initializing negative binomial process")
self.nb_observation = NegativeBinomialObservation(
name="negbinom_rv",
concentration_rv=DistributionalRV(
concentration_rv=DistributionalVariable(
name="nb_concentration",
dist=self.nb_concentration_prior,
distribution=self.nb_concentration_prior,
),
)

Expand Down Expand Up @@ -154,8 +154,9 @@ def sample(
ascertainment values and the expected
hospitalizations.
"""
alpha_samples = self.alpha_process.sample()["prediction"]
alpha_samples = alpha_samples[: infections.shape[0]]

alpha_samples = self.alpha_process.sample(self.predictors)
alpha_samples = alpha_samples[0].value[: infections.shape[0]]
expected_hosp = (
alpha_samples
* jnp.convolve(infections, inf_to_hosp_dist, mode="full")[
Expand Down
27 changes: 12 additions & 15 deletions pyrenew_flu_light/comp_tran.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@
import pyrenew.transformation as t
from jax.typing import ArrayLike
from numpyro.infer.reparam import LocScaleReparam
from pyrenew.metaclass import (
DistributionalRV,
RandomVariable,
TransformedRandomVariable,
)
from pyrenew.process import SimpleRandomWalkProcess
from pyrenew.metaclass import RandomVariable
from pyrenew.process import RandomWalk
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable


class CFAEPIM_Rt(RandomVariable): # numpydoc ignore=GL08
Expand Down Expand Up @@ -108,24 +105,24 @@ def sample(self, n_steps: int, **kwargs) -> tuple: # numpydoc ignore=GL08
"Wt_rw_sd", dist.HalfNormal(self.gamma_RW_prior_scale)
)
# Rt random walk process
wt_rv = SimpleRandomWalkProcess(
init_rv = DistributionalVariable(
name="init_Wt_rv",
distribution=self.intercept_RW_prior,
)
wt_rv = RandomWalk(
name="Wt",
step_rv=DistributionalRV(
step_rv=DistributionalVariable(
name="rw_step_rv",
dist=dist.Normal(0, sd_wt),
distribution=dist.Normal(0, sd_wt),
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
name="init_Wt_rv",
dist=self.intercept_RW_prior,
),
)
# transform Rt random walk w/ scaled logit
transformed_rt_samples = TransformedRandomVariable(
transformed_rt_samples = TransformedVariable(
name="transformed_rt_rw",
base_rv=wt_rv,
transforms=t.ScaledLogitTransform(x_max=self.max_rt).inv,
).sample(n_steps=n_steps, **kwargs)
).sample(n=n_steps, init_vals=init_rv()[0].value)
# broadcast the Rt samples to daily values
broadcasted_rt_samples = transformed_rt_samples[0].value[
self.week_indices
Expand Down
Loading

0 comments on commit 7614aa3

Please sign in to comment.