Skip to content

Commit

Permalink
Merge pull request #6 from reichlab/support_covid
Browse files Browse the repository at this point in the history
support covid
  • Loading branch information
elray1 authored Nov 26, 2024
2 parents 05a2f98 + fc860f4 commit 8f1aacf
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 10 deletions.
2 changes: 1 addition & 1 deletion requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ frozenlist==1.5.0
# aiosignal
fsspec==2024.10.0
# via s3fs
iddata @ git+https://github.com/reichlab/iddata@b94caa9735d010059ea4117631c7b1908adff70d
iddata @ git+https://github.com/reichlab/iddata@7d85f006345c8e17139588d24650509a1eef6af0
# via idmodels (pyproject.toml)
identify==2.6.1
# via pre-commit
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ frozenlist==1.5.0
# aiosignal
fsspec==2024.10.0
# via s3fs
iddata @ git+https://github.com/reichlab/iddata@b94caa9735d010059ea4117631c7b1908adff70d
iddata @ git+https://github.com/reichlab/iddata@7d85f006345c8e17139588d24650509a1eef6af0
# via idmodels (pyproject.toml)
idna==3.10
# via yarl
Expand Down
16 changes: 11 additions & 5 deletions src/idmodels/gbqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import lightgbm as lgb
import numpy as np
import pandas as pd
from iddata.loader import FluDataLoader
from iddata.loader import DiseaseDataLoader
from tqdm.autonotebook import tqdm

from idmodels.preprocess import create_features_and_targets
Expand Down Expand Up @@ -31,8 +31,8 @@ def run(self, run_config):
ilinet_kwargs = {"scale_to_positive": False}
flusurvnet_kwargs = {"burden_adj": False}

fdl = FluDataLoader()
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date},
fdl = DiseaseDataLoader()
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
ilinet_kwargs=ilinet_kwargs,
flusurvnet_kwargs=flusurvnet_kwargs,
sources=self.model_config.sources,
Expand All @@ -41,14 +41,20 @@ def run(self, run_config):
df = df.loc[df["location"].isin(run_config.locations)]

# augment data with features and target values
if run_config.disease == "flu":
init_feats = ["inc_trans_cs", "season_week", "log_pop"]
elif run_config.disease == "covid":
init_feats = ["inc_trans_cs", "log_pop"]

df, feat_names = create_features_and_targets(
df = df,
incl_level_feats=self.model_config.incl_level_feats,
max_horizon=run_config.max_horizon,
curr_feat_names=["inc_trans_cs", "season_week", "log_pop"])
curr_feat_names=init_feats)

# keep only rows that are in-season
df = df.query("season_week >= 5 and season_week <= 45")
if run_config.disease == "flu":
df = df.query("season_week >= 5 and season_week <= 45")

# "test set" df used to generate look-ahead predictions
df_test = df.loc[df.wk_end_date == df.wk_end_date.max()] \
Expand Down
6 changes: 3 additions & 3 deletions src/idmodels/sarix.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import numpy as np
import pandas as pd
from iddata.loader import FluDataLoader
from iddata.loader import DiseaseDataLoader
from iddata.utils import get_holidays
from sarix import sarix

Expand All @@ -13,8 +13,8 @@ def __init__(self, model_config):
self.model_config = model_config

def run(self, run_config):
fdl = FluDataLoader()
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date},
fdl = DiseaseDataLoader()
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
sources=self.model_config.sources,
power_transform=self.model_config.power_transform)
if run_config.locations is not None:
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/test_gbqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,17 @@ def test_gbqr(tmp_path):


run_config = SimpleNamespace(
disease="flu",
ref_date=datetime.date.fromisoformat("2024-01-06"),
output_root=tmp_path / "model-output",
artifact_store_root=tmp_path / "artifact-store",
save_feat_importance=False,
locations=["US", "01", "02", "04", "05", "06", "08", "09", "10", "11",
"12", "13", "15", "16", "17", "18", "19", "20", "21", "22",
"23", "24", "25", "26", "27", "28", "29", "30", "31", "32",
"33", "34", "35", "36", "37", "38", "39", "40", "41", "42",
"44", "45", "46", "47", "48", "49", "50", "51", "53", "54",
"55", "56", "72"],
max_horizon=3,
q_levels = [0.025, 0.50, 0.975],
q_labels = ["0.025", "0.5", "0.975"],
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/test_sarix.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,17 @@ def test_sarix(tmp_path):
)

run_config = SimpleNamespace(
disease="flu",
ref_date=datetime.date.fromisoformat("2024-01-06"),
output_root=tmp_path / "model-output",
artifact_store_root=tmp_path / "artifact-store",
save_feat_importance=False,
locations=["US", "01", "02", "04", "05", "06", "08", "09", "10", "11",
"12", "13", "15", "16", "17", "18", "19", "20", "21", "22",
"23", "24", "25", "26", "27", "28", "29", "30", "31", "32",
"33", "34", "35", "36", "37", "38", "39", "40", "41", "42",
"44", "45", "46", "47", "48", "49", "50", "51", "53", "54",
"55", "56", "72"],
max_horizon=3,
q_levels = [0.025, 0.50, 0.975],
q_labels = ["0.025", "0.5", "0.975"],
Expand Down

0 comments on commit 8f1aacf

Please sign in to comment.