From c87a1014bec61d67fc84da1c916ccadb2e746b77 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 20 Nov 2024 16:22:25 -0500 Subject: [PATCH 1/6] support covid --- src/idmodels/gbqr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/idmodels/gbqr.py b/src/idmodels/gbqr.py index b838d33..10c80cf 100644 --- a/src/idmodels/gbqr.py +++ b/src/idmodels/gbqr.py @@ -32,7 +32,7 @@ def run(self, run_config): flusurvnet_kwargs = {"burden_adj": False} fdl = FluDataLoader() - df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date}, + df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, disease=run_config.disease}, ilinet_kwargs=ilinet_kwargs, flusurvnet_kwargs=flusurvnet_kwargs, sources=self.model_config.sources, From 872f956cf30449cadf2493ff5dc58a35c702fba1 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 20 Nov 2024 16:27:40 -0500 Subject: [PATCH 2/6] update iddata version in requirements --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index cc26fa6..71cb9cf 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -30,7 +30,7 @@ frozenlist==1.5.0 # aiosignal fsspec==2024.10.0 # via s3fs -iddata @ git+https://github.com/reichlab/iddata@b94caa9735d010059ea4117631c7b1908adff70d +iddata @ git+https://github.com/reichlab/iddata@dddfeddefff101f7fafda11245333772aabef7b0 # via idmodels (pyproject.toml) idna==3.10 # via yarl From 8323af1586f3523d34fba97850724aee9dd3b694 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 20 Nov 2024 18:14:14 -0500 Subject: [PATCH 3/6] fix covid support --- src/idmodels/gbqr.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/idmodels/gbqr.py b/src/idmodels/gbqr.py index 10c80cf..d30a2b9 100644 --- a/src/idmodels/gbqr.py +++ b/src/idmodels/gbqr.py @@ -3,7 +3,7 @@ import lightgbm as lgb import numpy as np import pandas as pd -from iddata.loader import FluDataLoader +from iddata.loader import DiseaseDataLoader from tqdm.autonotebook import tqdm from idmodels.preprocess import create_features_and_targets @@ -31,7 +31,7 @@ def run(self, run_config): ilinet_kwargs = {"scale_to_positive": False} flusurvnet_kwargs = {"burden_adj": False} - fdl = FluDataLoader() + fdl = DiseaseDataLoader() df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, disease=run_config.disease}, ilinet_kwargs=ilinet_kwargs, flusurvnet_kwargs=flusurvnet_kwargs, @@ -41,14 +41,20 @@ def run(self, run_config): df = df.loc[df["location"].isin(run_config.locations)] # augment data with features and target values + if run_config.disease == "flu": + init_feats = ["inc_trans_cs", "season_week", "log_pop"] + elif run_config.disease == "covid": + init_feats = ["inc_trans_cs", "log_pop"] + df, feat_names = create_features_and_targets( df = df, incl_level_feats=self.model_config.incl_level_feats, max_horizon=run_config.max_horizon, - curr_feat_names=["inc_trans_cs", "season_week", "log_pop"]) + curr_feat_names=init_feats) # keep only rows that are in-season - df = df.query("season_week >= 5 and season_week <= 45") + if run_config.disease == "flu": + df = df.query("season_week >= 5 and season_week <= 45") # "test set" df used to generate look-ahead predictions df_test = df.loc[df.wk_end_date == df.wk_end_date.max()] \ From c7982f9f37940790ea26f093aefcd57403a5aace Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 20 Nov 2024 18:24:29 -0500 Subject: [PATCH 4/6] fix covid support --- src/idmodels/gbqr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/idmodels/gbqr.py b/src/idmodels/gbqr.py index d30a2b9..3d74a0d 100644 --- a/src/idmodels/gbqr.py +++ b/src/idmodels/gbqr.py @@ -32,7 +32,7 @@ def run(self, run_config): flusurvnet_kwargs = {"burden_adj": False} fdl = DiseaseDataLoader() - df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, disease=run_config.disease}, + df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease}, ilinet_kwargs=ilinet_kwargs, flusurvnet_kwargs=flusurvnet_kwargs, sources=self.model_config.sources, From a09605fedb9cb41729053b785688a8d1e6a03544 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 20 Nov 2024 19:49:51 -0500 Subject: [PATCH 5/6] support covid for sarix --- src/idmodels/sarix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/idmodels/sarix.py b/src/idmodels/sarix.py index 8775791..3fd904d 100644 --- a/src/idmodels/sarix.py +++ b/src/idmodels/sarix.py @@ -1,7 +1,7 @@ import numpy as np import pandas as pd -from iddata.loader import FluDataLoader +from iddata.loader import DiseaseDataLoader from iddata.utils import get_holidays from sarix import sarix @@ -13,8 +13,8 @@ def __init__(self, model_config): self.model_config = model_config def run(self, run_config): - fdl = FluDataLoader() - df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date}, + fdl = DiseaseDataLoader() + df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease}, sources=self.model_config.sources, power_transform=self.model_config.power_transform) if run_config.locations is not None: From fc860f4e52e41fdfff005082e13ca5725043e424 Mon Sep 17 00:00:00 2001 From: "Evan L. Ray" Date: Tue, 26 Nov 2024 16:38:59 -0500 Subject: [PATCH 6/6] update to latest iddata api --- requirements/requirements-dev.txt | 2 +- requirements/requirements.txt | 2 +- tests/integration/test_gbqr.py | 7 +++++++ tests/integration/test_sarix.py | 7 +++++++ 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 76048a1..4a245a3 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -38,7 +38,7 @@ frozenlist==1.5.0 # aiosignal fsspec==2024.10.0 # via s3fs -iddata @ git+https://github.com/reichlab/iddata@b94caa9735d010059ea4117631c7b1908adff70d +iddata @ git+https://github.com/reichlab/iddata@7d85f006345c8e17139588d24650509a1eef6af0 # via idmodels (pyproject.toml) identify==2.6.1 # via pre-commit diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 71cb9cf..3558e4d 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -30,7 +30,7 @@ frozenlist==1.5.0 # aiosignal fsspec==2024.10.0 # via s3fs -iddata @ git+https://github.com/reichlab/iddata@dddfeddefff101f7fafda11245333772aabef7b0 +iddata @ git+https://github.com/reichlab/iddata@7d85f006345c8e17139588d24650509a1eef6af0 # via idmodels (pyproject.toml) idna==3.10 # via yarl diff --git a/tests/integration/test_gbqr.py b/tests/integration/test_gbqr.py index 981c95d..db615f1 100644 --- a/tests/integration/test_gbqr.py +++ b/tests/integration/test_gbqr.py @@ -34,10 +34,17 @@ def test_gbqr(tmp_path): run_config = SimpleNamespace( + disease="flu", ref_date=datetime.date.fromisoformat("2024-01-06"), output_root=tmp_path / "model-output", artifact_store_root=tmp_path / "artifact-store", save_feat_importance=False, + locations=["US", "01", "02", "04", "05", "06", "08", "09", "10", "11", + "12", "13", "15", "16", "17", "18", "19", "20", "21", "22", + "23", "24", "25", "26", "27", "28", "29", "30", "31", "32", + "33", "34", "35", "36", "37", "38", "39", "40", "41", "42", + "44", "45", "46", "47", "48", "49", "50", "51", "53", "54", + "55", "56", "72"], max_horizon=3, q_levels = [0.025, 0.50, 0.975], q_labels = ["0.025", "0.5", "0.975"], diff --git a/tests/integration/test_sarix.py b/tests/integration/test_sarix.py index be551f2..f445a48 100644 --- a/tests/integration/test_sarix.py +++ b/tests/integration/test_sarix.py @@ -38,10 +38,17 @@ def test_sarix(tmp_path): ) run_config = SimpleNamespace( + disease="flu", ref_date=datetime.date.fromisoformat("2024-01-06"), output_root=tmp_path / "model-output", artifact_store_root=tmp_path / "artifact-store", save_feat_importance=False, + locations=["US", "01", "02", "04", "05", "06", "08", "09", "10", "11", + "12", "13", "15", "16", "17", "18", "19", "20", "21", "22", + "23", "24", "25", "26", "27", "28", "29", "30", "31", "32", + "33", "34", "35", "36", "37", "38", "39", "40", "41", "42", + "44", "45", "46", "47", "48", "49", "50", "51", "53", "54", + "55", "56", "72"], max_horizon=3, q_levels = [0.025, 0.50, 0.975], q_labels = ["0.025", "0.5", "0.975"],