Skip to content

Commit

Permalink
Merge pull request #10 from CDCgov/9-remove-extraneous-code-numpydoc-…
Browse files Browse the repository at this point in the history
…ignore=gl08-utf-8-encoding

Ignore Extraneous Comments In Code
  • Loading branch information
AFg6K7h4fhy2 authored Sep 5, 2024
2 parents 4b3480d + 44f4280 commit 230e0f4
Show file tree
Hide file tree
Showing 15 changed files with 377 additions and 64 deletions.
175 changes: 175 additions & 0 deletions model_comparison/compare.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
library(dplyr)
library(tidyr)
library(readr)
library(lubridate)
library(forecasttools)

read_posterior_samples <- function(file_path) {
posterior_samples <- read_csv(file_path)
return(posterior_samples)
}

pivot_forecast_to_long <- function(
posterior_samples,
signal_name = "hosp",
time_name = "date"
) {
draws_wide <- tibble::tibble(as.data.frame(posterior_samples))
names(draws_wide) <- posterior_samples$time
draws_long <- draws_wide %>%
dplyr::mutate(.draw = dplyr::row_number()) %>%
tidyr::pivot_longer(cols = -.draw, names_to = time_name, values_to = signal_name)
return(draws_long)
}
aggregate_to_epiweekly <- function(tidy_daily_trajectories) {
epiweekly_forecasts <- forecasttools::daily_to_epiweekly(
tidy_daily_trajectories,
value_col = "hosp",
id_cols = c(".draw"),
weekly_value_name = "weekly_hosp"
)
return(epiweekly_forecasts)
}

output_flusight_table <- function(
weekly_forecasts,
reference_date,
horizons,
output_path
) {
formatted_output <- forecasttools::trajectories_to_quantiles(
weekly_forecasts,
timepoint_cols = c("epiweek", "epiyear"),
value_col = "weekly_hosp"
) %>%
dplyr::mutate(
location = forecasttools::loc_abbr_to_flusight_code(state)
) %>%
forecasttools::get_flusight_table(
reference_date,
horizons = horizons
)

readr::write_csv(formatted_output, output_path)
return(formatted_output)
}

generate_flusight_output <- function(file_paths, reference_date, horizons, output_path) {
for (file_path in file_paths) {
posterior_samples <- read_posterior_samples(file_path)
daily_forecasts <- pivot_forecast_to_long(posterior_samples)
weekly_forecasts <- aggregate_to_epiweekly(daily_forecasts)
formatted_output <- output_flusight_table(weekly_forecasts, reference_date, horizons, output_path)
}
}


file_paths <- c(
"posterior_predictive_forecasts_test_NY_2024-01-20.csv")
reference_date <- "2024-01-20"
horizons <- -1:3
output_path <- "flusight_forecast_output_PFL_test_NY_2024-01-20.csv"

generate_flusight_output(
file_paths,
reference_date,
horizons,
output_path)







# library(forecasttools)
# library(tibble)
# library(dplyr)
# library(readr)
# library(lubridate)

# pyrenew_flusight_forecast_from_csv <- function(
# spread_draws_csv,
# fitting_data_csv,
# output_path,
# reference_date,
# horizons = -1:3,
# seed = NULL
# ) {
# set.seed(seed)
# # read data from pyrenew-flu-light run
# spread_draws <- read_csv(spread_draws_csv)
# fitting_data <- read_csv(fitting_data_csv)
# # retrieve locations from fitting data
# locations <- unique(fitting_data$location)
# # parse over and collect forecasts and fitting data
# state_daily_forecast_list <- list()
# for (loc in locations) {
# state_fitting_data <- fitting_data %>% filter(location == loc)
# state_forecast <- spread_draws %>% filter(negbinom_rv_dim_0_index == loc)
# # forecasttools convert spread draws tiddy
# state_forecast_long <- state_forecast %>%
# dplyr::mutate(.draw = draw) %>%
# dplyr::select(.draw, date = negbinom_rv_dim_0_index, hosp = negbinom_rv)
# # go to epiweekly from daily
# state_weekly_forecasts <- forecasttools::daily_to_epiweekly(
# tidy_daily_trajectories = state_forecast_long,
# value_col = "hosp",
# date_col = "date",
# id_cols = ".draw"
# )
# state_daily_forecast_list[[loc]] <- state_weekly_forecasts
# }
# cli::cli_inform("Formatting output for FluSight...")
# # flusight formatting
# state_flusight_tables <- list()
# full_table <- tibble::tibble()
# for (state in names(state_daily_forecast_list)) {
# state_flusight_table <- forecasttools::trajectories_to_quantiles(
# state_daily_forecast_list[[state]],
# timepoint_cols = c("epiweek", "epiyear"),
# value_col = "weekly_hosp"
# ) %>%
# dplyr::mutate(
# location = forecasttools::loc_abbr_to_flusight_code(state)
# ) %>%
# forecasttools:::get_flusight_table(
# reference_date,
# horizons = horizons
# )

# full_table <- dplyr::bind_rows(
# full_table,
# state_flusight_table
# )
# }
# full_table <- full_table %>%
# dplyr::arrange(
# location,
# reference_date,
# horizon,
# output_type,
# output_type_id
# )
# # save
# readr::write_csv(
# full_table,
# output_path
# )
# return(full_table)
# }

# spread_draws_csv <- "AL_2024-03-30_28_NegBinRv.csv"
# fitting_data_csv <- "filtered_data_AL.csv"
# output_path <- "flusight_output_AL.csv"
# reference_date <- "2024-03-30"
# horizons <- -1:3

# run_forecast_with_csvs(
# spread_draws_csv,
# fitting_data_csv,
# output_path,
# reference_date,
# horizons,
# seed = 62352
# )
2 changes: 0 additions & 2 deletions pyrenew_flu_light/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

from checks import (
assert_historical_data_files_exist,
check_file_path_valid,
Expand Down
6 changes: 2 additions & 4 deletions pyrenew_flu_light/checks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

"""
Methods to verify pathing and existence
of certain files for use of pyrenew-flu-light.
Expand Down Expand Up @@ -62,7 +60,7 @@ def load_config(config_path: str) -> dict[str, any]:
return config


def ensure_output_directory(args: dict[str, any]): # numpydoc ignore=GL08
def ensure_output_directory(args: dict[str, any]):
output_directory = "./output/"
if not os.path.exists(output_directory):
os.makedirs(output_directory)
Expand All @@ -79,7 +77,7 @@ def ensure_output_directory(args: dict[str, any]): # numpydoc ignore=GL08

def assert_historical_data_files_exist(
reporting_date: str,
): # numpydoc ignore=GL08
):
data_directory = f"../model_comparison/data/{reporting_date}/"
assert os.path.exists(
data_directory
Expand Down
6 changes: 2 additions & 4 deletions pyrenew_flu_light/comp_inf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

"""
The infections process component in pyrenew-flu-light.
"""
Expand Down Expand Up @@ -36,7 +34,7 @@ def __init__(
self,
I0: ArrayLike,
susceptibility_prior: numpyro.distributions,
): # numpydoc ignore=GL08
):
logging.info("Initializing CFAEPIM_Infections")

self.I0 = I0
Expand Down Expand Up @@ -138,7 +136,7 @@ def sample(
# calculate initial susceptible population S_{v-1}
init_S = init_S_proportion * P

def update_infections(carry, Rt): # numpydoc ignore=GL08
def update_infections(carry, Rt):
S_t, I_recent = carry

# compute raw infections
Expand Down
4 changes: 1 addition & 3 deletions pyrenew_flu_light/comp_obs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

"""
The observation process component in pyrenew-flu-light.
"""
Expand Down Expand Up @@ -44,7 +42,7 @@ def __init__(
alpha_prior_dist,
coefficient_priors,
nb_concentration_prior,
): # numpydoc ignore=GL08
):
logging.info("Initializing CFAEPIM_Observation")

# CFAEPIM_Observation.validate(
Expand Down
10 changes: 4 additions & 6 deletions pyrenew_flu_light/comp_tran.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

"""
The transmission process (Rt) component in pyrenew-flu-light.
"""
Expand All @@ -18,14 +16,14 @@
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable


class CFAEPIM_Rt(RandomVariable): # numpydoc ignore=GL08
class CFAEPIM_Rt(RandomVariable):
def __init__(
self,
intercept_RW_prior: numpyro.distributions,
max_rt: float,
gamma_RW_prior_scale: float,
week_indices: ArrayLike,
): # numpydoc ignore=GL08
):
"""
Initialize the CFAEPIM_Rt class.
Expand Down Expand Up @@ -55,7 +53,7 @@ def validate(
max_rt: any,
gamma_RW_prior_scale: any,
week_indices: any,
) -> None: # numpydoc ignore=GL08
) -> None:
"""
Validate the parameters of the CFAEPIM_Rt class.
Expand Down Expand Up @@ -85,7 +83,7 @@ def validate(
f"week_indices must be an array-like structure; was type {type(week_indices)}"
)

def sample(self, n_steps: int, **kwargs) -> tuple: # numpydoc ignore=GL08
def sample(self, n_steps: int, **kwargs) -> tuple:
"""
Sample the Rt values using a random walk process
and broadcast them to daily values.
Expand Down
4 changes: 1 addition & 3 deletions pyrenew_flu_light/compare.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

"""
Comparisons made between posterior samples.
"""
Expand All @@ -18,7 +16,7 @@ def quantilize_forecasts(
fitting_data,
output_path,
reference_date,
): # numpydoc ignore=GL08
):
pandas2ri.activate()
forecasttools = importr("forecasttools")
# dplyr = importr("dplyr")
Expand Down
Loading

0 comments on commit 230e0f4

Please sign in to comment.