rename hhs to nhsn in public-facing api
elray1 committed Nov 5, 2024
1 parent 6b4228d commit 24e736a
Showing 2 changed files with 20 additions and 20 deletions.
28 changes: 14 additions & 14 deletions src/iddata/loader.py
@@ -302,7 +302,7 @@ def load_ilinet(self,
         return dat
 
 
-    def load_hhs(self, rates=True, drop_pandemic_seasons=True, as_of=None):
+    def load_nhsn(self, rates=True, drop_pandemic_seasons=True, as_of=None):
         if drop_pandemic_seasons:
             if as_of is None:
                 file_path = "influenza-hhs/hhs.csv"
@@ -316,7 +316,7 @@ def load_hhs(self, rates=True, drop_pandemic_seasons=True, as_of=None):
                 file_path = all_file_paths[-1]
         else:
             if as_of is not None:
-                raise NotImplementedError("Functionality for loading all seasons of HHS data with specified as_of date is not implemented.")
+                raise NotImplementedError("Functionality for loading all seasons of NHSN data with specified as_of date is not implemented.")
             file_path = "influenza-hhs/hhs_complete.csv"
 
         dat = pd.read_csv(self._construct_data_raw_url(file_path))
@@ -336,7 +336,7 @@ def load_hhs(self, rates=True, drop_pandemic_seasons=True, as_of=None):
 
         dat["agg_level"] = np.where(dat["location"] == "US", "national", "state")
         dat = dat[["agg_level", "location", "season", "season_week", "wk_end_date", "inc"]]
-        dat["source"] = "hhs"
+        dat["source"] = "nhsn"
         return dat
 
 
@@ -413,18 +413,18 @@ def load_agg_transform_flusurv(self, fips_mappings, **flusurvnet_kwargs):
         return df_flusurv
 
 
-    def load_data(self, sources=None, flusurvnet_kwargs=None, hhs_kwargs=None, ilinet_kwargs=None,
+    def load_data(self, sources=None, flusurvnet_kwargs=None, nhsn_kwargs=None, ilinet_kwargs=None,
                   power_transform="4rt"):
         """
         Load influenza data and transform to a scale suitable for input to models.
 
         Parameters
         ----------
         sources: None or list of sources
-            data sources to collect. Defaults to ['flusurvnet', 'hhs', 'ilinet'].
+            data sources to collect. Defaults to ['flusurvnet', 'nhsn', 'ilinet'].
            If provided as a list, must be a subset of the defaults.
         flusurvnet_kwargs: dictionary of keyword arguments to pass on to `load_flusurv_rates`
-        hhs_kwargs: dictionary of keyword arguments to pass on to `load_hhs`
+        nhsn_kwargs: dictionary of keyword arguments to pass on to `load_nhsn`
         ilinet_kwargs: dictionary of keyword arguments to pass on to `load_ilinet`
 
         power_transform: string specifying power transform to use: '4rt' or `None`
@@ -433,13 +433,13 @@ def load_data(self, sources=None, flusurvnet_kwargs=None, hhs_kwargs=None, ilinet_kwargs=None,
         Pandas DataFrame
         """
         if sources is None:
-            sources = ["flusurvnet", "hhs", "ilinet"]
+            sources = ["flusurvnet", "nhsn", "ilinet"]
 
         if flusurvnet_kwargs is None:
             flusurvnet_kwargs = {}
 
-        if hhs_kwargs is None:
-            hhs_kwargs = {}
+        if nhsn_kwargs is None:
+            nhsn_kwargs = {}
 
         if ilinet_kwargs is None:
             ilinet_kwargs = {}
@@ -450,11 +450,11 @@ def load_data(self, sources=None, flusurvnet_kwargs=None, hhs_kwargs=None, ilinet_kwargs=None,
         us_census = self.load_us_census()
         fips_mappings = pd.read_csv(self._construct_data_raw_url("fips-mappings/fips_mappings.csv"))
 
-        if "hhs" in sources:
-            df_hhs = self.load_hhs(**hhs_kwargs)
-            df_hhs["inc"] = df_hhs["inc"] + 0.75**4
+        if "nhsn" in sources:
+            df_nhsn = self.load_nhsn(**nhsn_kwargs)
+            df_nhsn["inc"] = df_nhsn["inc"] + 0.75**4
         else:
-            df_hhs = None
+            df_nhsn = None
 
         if "ilinet" in sources:
             df_ilinet = self.load_agg_transform_ilinet(fips_mappings=fips_mappings, **ilinet_kwargs)
@@ -467,7 +467,7 @@ def load_data(self, sources=None, flusurvnet_kwargs=None, hhs_kwargs=None, ilinet_kwargs=None,
             df_flusurv = None
 
         df = pd.concat(
-            [df_hhs, df_ilinet, df_flusurv],
+            [df_nhsn, df_ilinet, df_flusurv],
             axis=0).sort_values(["source", "location", "wk_end_date"])
 
         # log population
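
For orientation, a minimal usage sketch of the renamed public API follows. It is not part of the commit: the import path is an assumption based on the file location src/iddata/loader.py, and the keyword arguments mirror the updated tests below.

# A sketch, not part of this commit. The import path is assumed from
# the package layout implied by src/iddata/loader.py.
from iddata.loader import FluDataLoader

fdl = FluDataLoader()

# Before this commit: fdl.load_data(sources=["hhs"], hhs_kwargs={...})
# After this commit, the public-facing names use "nhsn" throughout:
df = fdl.load_data(
    sources=["nhsn"],
    nhsn_kwargs={"drop_pandemic_seasons": True},  # forwarded to load_nhsn
)

# Per the updated tests, every row is labeled with the new source name.
print(df["source"].unique())  # expected: ["nhsn"]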
12 changes: 6 additions & 6 deletions tests/iddata/unit/test_load_data.py
@@ -9,17 +9,17 @@ def test_load_data_sources():
     fdl = FluDataLoader()
 
     sources_options = [
-        ["hhs"],
-        ["hhs", "ilinet"],
+        ["nhsn"],
+        ["nhsn", "ilinet"],
         ["flusurvnet"],
-        ["flusurvnet", "hhs", "ilinet"]
+        ["flusurvnet", "nhsn", "ilinet"]
     ]
     for sources in sources_options:
         df = fdl.load_data(sources=sources)
         assert set(df["source"].unique()) == set(sources)
 
     df = fdl.load_data()
-    assert set(df["source"].unique()) == {"flusurvnet", "hhs", "ilinet"}
+    assert set(df["source"].unique()) == {"flusurvnet", "nhsn", "ilinet"}
 
 
 @pytest.mark.parametrize("test_kwargs, season_expected, wk_end_date_expected", [
@@ -28,10 +28,10 @@ def test_load_data_sources():
     ({"drop_pandemic_seasons": True, "as_of": datetime.date.fromisoformat("2023-12-30")},
      "2022/23", "2023-12-23")
 ])
-def test_load_data_hhs_kwargs(test_kwargs, season_expected, wk_end_date_expected):
+def test_load_data_nhsn_kwargs(test_kwargs, season_expected, wk_end_date_expected):
     fdl = FluDataLoader()
 
-    df = fdl.load_data(sources=["hhs"], hhs_kwargs=test_kwargs)
+    df = fdl.load_data(sources=["nhsn"], nhsn_kwargs=test_kwargs)
 
     assert df["season"].min() == season_expected
     wk_end_date_actual = str(df["wk_end_date"].max())[:10]
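
To check the rename end to end, one could run the updated test module directly (assuming pytest is the test runner, as the @pytest.mark.parametrize usage suggests):

pytest tests/iddata/unit/test_load_data.py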
