diff --git a/src/iddata/loader.py b/src/iddata/loader.py index 21ac167..9796149 100644 --- a/src/iddata/loader.py +++ b/src/iddata/loader.py @@ -10,7 +10,7 @@ from iddata import utils -class FluDataLoader(): +class DiseaseDataLoader(): def __init__(self) -> None: self.data_raw = "https://infectious-disease-data.s3.amazonaws.com/data-raw/" @@ -303,10 +303,7 @@ def load_ilinet(self, return dat - def load_nhsn(self, rates=True, drop_pandemic_seasons=True, as_of=None): - if not drop_pandemic_seasons: - raise NotImplementedError("Functionality for loading all seasons of NHSN data with specified as_of date is not implemented.") - + def load_nhsn(self, disease="flu", rates=True, drop_pandemic_seasons=True, as_of=None): if as_of is None: as_of = datetime.date.today() @@ -314,9 +311,15 @@ def load_nhsn(self, rates=True, drop_pandemic_seasons=True, as_of=None): as_of = datetime.date.fromisoformat(as_of) if as_of < datetime.date.fromisoformat("2024-11-15"): + if not drop_pandemic_seasons: + raise NotImplementedError("Functionality for loading all seasons of NHSN data with specified as_of date is not implemented.") + + if disease != "flu": + raise NotImplementedError(f"When loading NHSN data with an as_of date prior to 2024-11-15, only disease='flu' is supported; got {str(disease)}.") return self.load_nhsn_from_hhs(rates=rates, as_of=as_of) else: return self.load_nhsn_from_nhsn( + disease=disease, rates=rates, as_of=as_of, drop_pandemic_seasons=drop_pandemic_seasons @@ -353,7 +356,11 @@ def load_nhsn_from_hhs(self, rates=True, as_of=None): return dat - def load_nhsn_from_nhsn(self, rates=True, as_of=None, drop_pandemic_seasons=True): + def load_nhsn_from_nhsn(self, disease="flu", rates=True, as_of=None, drop_pandemic_seasons=True): + valid_diseases = ["flu", "covid"] + if disease not in valid_diseases: + raise ValueError("For NHSN data, the only supported diseases are 'flu' and 'covid'.") + # find the largest stored file dated on or before the as_of date as_of_file_path = f"influenza-nhsn/nhsn-{str(as_of)}.csv" glob_results = s3fs.S3FileSystem(anon=True) \ @@ -363,7 +370,11 @@ def load_nhsn_from_nhsn(self, rates=True, as_of=None, drop_pandemic_seasons=True file_path = all_file_paths[-1] dat = pd.read_csv(self._construct_data_raw_url(file_path)) - dat = dat[["Geographic aggregation", "Week Ending Date", "Total Influenza Admissions"]] + if disease == "flu": + inc_colname = "Total Influenza Admissions" + elif disease == "covid": + inc_colname = "Total COVID-19 Admissions" + dat = dat[["Geographic aggregation", "Week Ending Date"] + [inc_colname]] dat.columns = ["abbreviation", "wk_end_date", "inc"] # rename USA to US diff --git a/tests/iddata/unit/test_load_data.py b/tests/iddata/unit/test_load_data.py index d3c0efe..e7d1758 100644 --- a/tests/iddata/unit/test_load_data.py +++ b/tests/iddata/unit/test_load_data.py @@ -2,11 +2,11 @@ import numpy as np import pytest -from iddata.loader import FluDataLoader +from iddata.loader import DiseaseDataLoader def test_load_data_sources(): - fdl = FluDataLoader() + fdl = DiseaseDataLoader() sources_options = [ ["nhsn"], @@ -29,7 +29,7 @@ def test_load_data_sources(): "2022/23", "2023-12-23") ]) def test_load_data_nhsn_kwargs(test_kwargs, season_expected, wk_end_date_expected): - fdl = FluDataLoader() + fdl = DiseaseDataLoader() df = fdl.load_data(sources=["nhsn"], nhsn_kwargs=test_kwargs) @@ -47,7 +47,7 @@ def test_load_data_nhsn_kwargs(test_kwargs, season_expected, wk_end_date_expecte ({"drop_pandemic_seasons": True}, True) ]) def test_load_data_ilinet_kwargs(test_kwargs, expect_all_na): - fdl = FluDataLoader() + fdl = DiseaseDataLoader() df = fdl.load_data(sources=["ilinet"], ilinet_kwargs=test_kwargs) @@ -63,7 +63,7 @@ def test_load_data_ilinet_kwargs(test_kwargs, expect_all_na): ({"locations": ["California", "Colorado", "Connecticut"]}) ]) def test_load_data_flusurvnet_kwargs(test_kwargs): - fdl = FluDataLoader() + fdl = DiseaseDataLoader() #flusurv_kwargs df = fdl.load_data(sources=["flusurvnet"], flusurvnet_kwargs=test_kwargs)