Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infrastructure for running on vintaged NSSP data #87

Merged
merged 21 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions nssp_demo/batch/setup_job.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import datetime

import polars as pl
from azure.batch import models
Expand All @@ -25,6 +26,10 @@ def main(job_id, pool_id, container_image) -> None:
"source": "nssp-etl",
"target": "/pyrenew-hew/nssp_demo/nssp-etl",
},
{
"source": "nssp-archival-vintages",
"target": "/pyrenew-hew/nssp_demo/nssp-archival-vintages",
},
{
"source": "prod-param-estimates",
"target": "/pyrenew-hew/nssp_demo/params",
Expand All @@ -41,12 +46,16 @@ def main(job_id, pool_id, container_image) -> None:
"python nssp_demo/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days 365 "
"--n-training-days 180 "
"--n-warmup 1000 "
"--n-samples 500 "
"--nssp-data-dir nssp_demo/nssp-etl/gold "
"--facility-level-nssp-data-dir nssp_demo/nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp_demo/nssp-archival-vintages/gold "
"--param-data-dir nssp_demo/params "
"--output-data-dir nssp_demo/private_data"
"--output-data-dir nssp_demo/private_data "
"--report-date {report_date:%Y-%m-%d} "
"--last-training-date {last_data_date:%Y-%m-%d}"
"'"
)

Expand All @@ -63,14 +72,27 @@ def main(job_id, pool_id, container_image) -> None:
.to_list()
)

for disease in ["COVID-19", "Influenza"]:
for state in all_states:
task = get_task_config(
f"{job_id}-{state}-{disease}",
base_call=base_call.format(state=state, disease=disease),
container_settings=container_settings,
)
client.task.add(job_id, task)
report_dates = [
datetime.date(2023, 10, 11) + datetime.timedelta(weeks=x)
for x in range(30)
]
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved

for disease in ["Influenza"]:
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
for report_date in report_dates:
last_data_date = report_date - datetime.timedelta(days=5)
for state in all_states:
task = get_task_config(
f"{job_id}-{state}-{disease}-{report_date}",
base_call=base_call.format(
state=state,
disease=disease,
report_date=report_date,
last_data_date=last_data_date,
),
container_settings=container_settings,
)
client.task.add(job_id, task)
pass
pass
pass

Expand Down
1 change: 1 addition & 0 deletions nssp_demo/batch/setup_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def main(pool_name: str) -> None:
pool_config.mount_configuration = blob.get_node_mount_config(
storage_containers=[
"nssp-etl",
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
"nssp-archival-vintages",
"prod-param-estimates",
"pyrenew-test-output",
],
Expand Down
2 changes: 1 addition & 1 deletion nssp_demo/forecast_non_target_visits.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ n_forecast_days <- argv$n_forecast_days
n_samples <- argv$n_samples

disease_name_nssp_map <- c(
"covid-19" = "COVID-19/Omicron",
"covid-19" = "COVID-19",
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
"influenza" = "Influenza"
)

Expand Down
109 changes: 87 additions & 22 deletions nssp_demo/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,71 @@ def postprocess_forecast(model_run_dir: Path) -> None:
return None


def get_available_reports(
data_dir: str | Path, glob_pattern: str = "*.parquet"
):
return [
datetime.strptime(f.stem, "%Y-%m-%d").date()
for f in Path(data_dir).glob(glob_pattern)
]


def main(
disease: str,
report_date: str,
state: str,
nssp_data_dir: Path | str,
facility_level_nssp_data_dir: Path | str,
state_level_nssp_data_dir: Path | str,
param_data_dir: Path | str,
output_data_dir: Path | str,
n_training_days: int,
n_forecast_days: int,
n_chains: int,
n_warmup: int,
n_samples: int,
last_training_date: str,
exclude_last_n_days: int = 0,
):
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

available_facility_level_reports = get_available_reports(
facility_level_nssp_data_dir
)

available_state_level_reports = get_available_reports(
state_level_nssp_data_dir
)
first_available_state_report = min(available_state_level_reports)
last_available_state_report = max(available_state_level_reports)

if report_date == "latest":
report_date = max(
f.stem for f in Path(nssp_data_dir).glob("*.parquet")
report_date = max(available_facility_level_reports)
else:
report_date = datetime.strptime(report_date, "%Y-%m-%d").date()

if report_date in available_state_level_reports:
state_report_date = report_date
elif report_date > last_available_state_report:
state_report_date = last_available_state_report
elif report_date > first_available_state_report:
raise ValueError(
"Dataset appear to be missing some state-level "
f"reports. First entry is {first_available_state_report}, "
f"last is {last_available_state_report}, but no entry "
f"for {report_date}"
)
else:
raise ValueError(
"Requested report date is earlier than the first "
"state-level vintage. This is not currently supported"
)
report_date = datetime.strptime(report_date, "%Y-%m-%d").date()

logger.info(f"Report date: {report_date}")
if state_report_date is not None:
logger.info(f"Using state-level data as of: {state_report_date}")

if last_training_date == "latest":
# + 1 because max date in dataset is report_date - 1
last_training_date = report_date - timedelta(days=1)
else:
last_training_date = datetime.strptime(
last_training_date, "%Y-%m-%d"
).date()
# + 1 because max date in dataset is report_date - 1
last_training_date = report_date - timedelta(days=exclude_last_n_days + 1)

if last_training_date >= report_date:
raise ValueError(
Expand All @@ -91,8 +124,29 @@ def main(
days=n_training_days - 1
)

datafile = f"{report_date}.parquet"
nssp_data = pl.scan_parquet(Path(nssp_data_dir, datafile))
logger.info(f"First training date {first_training_date}")

facility_level_nssp_data, state_level_nssp_data = None, None

if report_date in available_facility_level_reports:
logger.info(
"Facility level data available for " "the given report date"
)
facility_datafile = f"{report_date}.parquet"
facility_level_nssp_data = pl.scan_parquet(
Path(facility_level_nssp_data_dir, facility_datafile)
)
if state_report_date in available_state_level_reports:
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
logger.info("State-level data available for the given report " "date.")
state_datafile = f"{state_report_date}.parquet"
state_level_nssp_data = pl.scan_parquet(
Path(state_level_nssp_data_dir, state_datafile)
)
if facility_level_nssp_data is None and state_level_nssp_data is None:
raise ValueError(
"No data available for the requested report date " f"{report_date}"
)

param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet"))
model_batch_dir_name = (
f"{disease.lower()}_r_{report_date}_f_"
Expand All @@ -109,8 +163,10 @@ def main(
process_and_save_state(
state_abb=state,
disease=disease,
nssp_data=nssp_data,
facility_level_nssp_data=facility_level_nssp_data,
state_level_nssp_data=state_level_nssp_data,
report_date=report_date,
state_level_report_date=state_report_date,
first_training_date=first_training_date,
last_training_date=last_training_date,
param_estimates=param_estimates,
Expand Down Expand Up @@ -176,10 +232,19 @@ def main(
)

parser.add_argument(
"--nssp-data-dir",
"--facility-level-nssp-data-dir",
type=Path,
default=Path("private_data", "nssp_etl_gold"),
help="Directory in which to look for NSSP input data.",
help=(
"Directory in which to look for facility-level NSSP " "ED visit data"
),
)

parser.add_argument(
"--state-level-nssp-data-dir",
type=Path,
default=Path("private_data", "nssp_state_level_gold"),
help=("Directory in which to look for state-level NSSP " "ED visit data."),
)

parser.add_argument(
Expand Down Expand Up @@ -239,12 +304,12 @@ def main(
)

parser.add_argument(
"--last-training-date",
type=str,
default="latest",
"--exclude-last-n-days",
type=int,
default=0,
help=(
"Last date to use for model training in "
"YYYY-MM-DD format or 'latest' (default: latest)."
"Optionally exclude the final n days of available training "
"data (Default: 0, i.e. exclude no available data"
),
)

Expand Down
2 changes: 1 addition & 1 deletion nssp_demo/postprocess_state_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ theme_set(theme_minimal_grid())

disease_name_formatter <- c("covid-19" = "COVID-19", "influenza" = "Flu")
disease_name_nssp_map <- c(
"covid-19" = "COVID-19/Omicron",
"covid-19" = "COVID-19",
"influenza" = "Influenza"
)

Expand Down
Loading
Loading