Skip to content

Commit

Permalink
use unnest
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 committed Oct 11, 2024
1 parent 3e64e6d commit afbd528
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 31 deletions.
37 changes: 14 additions & 23 deletions forecasttools/daily_to_epiweekly.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import polars as pl


def calculate_epidate(date):
    """Return the epidemiological (week, year) pair for *date*.

    Parameters
    ----------
    date
        A ``datetime.date`` or ``datetime.datetime`` accepted by
        ``epiweeks.Week.fromdate``.

    Returns
    -------
    tuple
        ``(epiweek_number, epiweek_year)``.
    """
    epiweek = epiweeks.Week.fromdate(date)
    return epiweek.week, epiweek.year
def calculate_epidate(date: str):
    """Map an ISO-format ('%Y-%m-%d') date string to its epi week/year.

    Parameters
    ----------
    date : str
        Date formatted as ``YYYY-MM-DD``.

    Returns
    -------
    dict
        ``{"epiweek": <week number>, "epiyear": <epi year>}`` — a dict so
        the result can be stored as a struct column and unnested.
    """
    parsed = datetime.strptime(date, "%Y-%m-%d")
    epi_week = epiweeks.Week.fromdate(parsed)
    return {"epiweek": epi_week.week, "epiyear": epi_week.year}


def daily_to_epiweekly(
Expand All @@ -21,28 +21,19 @@ def daily_to_epiweekly(
Aggregate daily forecast draws to epiweekly.
"""
# check intended df columns are in received df
# forecast_df_cols = forecast_df.columns
# = [value_col, date_col] + id_cols
# assert set(required_cols).issubset(set(forecast_df_cols)),f"Column mismatch between require columns {required_cols} and forecast dateframe columns {forecast_df_cols}."
forecast_df_cols = forecast_df.columns
required_cols = [value_col, date_col] + id_cols
assert set(required_cols).issubset(
set(forecast_df_cols)
), f"Column mismatch between require columns {required_cols} and forecast dateframe columns {forecast_df_cols}."
# add epiweek and epiyear columns
forecast_df = forecast_df.with_columns(
[
pl.col(date_col)
.map_elements(
lambda x: calculate_epidate(datetime.strptime(x, "%Y-%m-%d"))[
0
]
)
.alias("epiweek"),
pl.col(date_col)
.map_elements(
lambda x: calculate_epidate(datetime.strptime(x, "%Y-%m-%d"))[
1
]
)
.alias("epiyear"),
]
)
pl.struct(["date"])
.map_elements(
lambda x: calculate_epidate(x["date"]), return_dtype=pl.Struct
)
.alias("epi_struct")
).unnest("epi_struct")
# group by epiweek, epiyear, and the id_cols
group_cols = ["epiweek", "epiyear"] + id_cols
grouped_df = forecast_df.group_by(group_cols)
Expand Down
8 changes: 2 additions & 6 deletions forecasttools/to_flusight.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
import polars as pl


def calculate_epidate(date):
    """Return the epidemiological week number and year of *date* as a tuple."""
    wk = epiweeks.Week.fromdate(date)
    return wk.week, wk.year


def get_flusight_target_end_dates(
reference_date: str, horizons=None
) -> pl.DataFrame:
Expand All @@ -20,7 +15,8 @@ def get_flusight_target_end_dates(
data = []
for horizon in horizons:
target_end_date = reference_date_dt + timedelta(weeks=horizon)
epiweek, epiyear = calculate_epidate(target_end_date)
epiweek = epiweeks.Week.fromdate(target_end_date)
epiweek, epiyear = epiweek.week, epiweek.year
data.append(
{
"reference_date": reference_date,
Expand Down
55 changes: 53 additions & 2 deletions notebooks/example_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,68 @@
# %% IMPORTS


import datetime as dt

import arviz as az
import epiweeks
import polars as pl
import xarray as xr

import forecasttools

# %% SIMPLE EPI COLUMNS


# create an artificial date dataframe to demonstrate adding
# epiweek/epiyear columns derived from a string date column
df_dict = {
    "date": ["2024-01-01", "2024-01-06", "2024-01-08"],
    "value": [1, 2, 2],
}
df = pl.DataFrame(df_dict)
print(df)

# method 01: parse the date column to pl.Date, then derive each epi
# column with its own map_elements call (note: the epiweeks.Week object
# is re-computed once per derived column)
new_df = df.with_columns(
    pl.col("date").str.strptime(pl.Date, "%Y-%m-%d")
).with_columns(
    pl.col("date")
    .map_elements(
        lambda d: epiweeks.Week.fromdate(d).week, return_dtype=pl.Int64
    )
    .alias("epiweek"),
    pl.col("date")
    .map_elements(
        lambda d: epiweeks.Week.fromdate(d).year, return_dtype=pl.Int64
    )
    .alias("epiyear"),
)


# method 02: use unnest w/ a function
def calculate_epidate(date: str):
    """Convert a ``YYYY-MM-DD`` string to its epi week and year.

    Returns a ``{"epiweek": ..., "epiyear": ...}`` dict, suitable for use
    as a struct column that can later be unnested.
    """
    as_datetime = dt.datetime.strptime(date, "%Y-%m-%d")
    wk = epiweeks.Week.fromdate(as_datetime)
    return {"epiweek": wk.week, "epiyear": wk.year}


# method 02: compute both epi fields in one pass — wrap the date column
# in a struct, map it through calculate_epidate (which returns a dict),
# then unnest the resulting struct column into epiweek/epiyear columns.
# NOTE(review): return_dtype=pl.Struct passes the bare Struct class
# rather than a Struct([...]) with declared fields — confirm polars
# infers the field dtypes as intended here.
newer_df = df.with_columns(
    pl.struct(["date"])
    .map_elements(
        lambda x: calculate_epidate(x["date"]), return_dtype=pl.Struct
    )
    .alias("epi_struct")
).unnest("epi_struct")
print(newer_df)


# %% USE OF PREDICTIONS IN ARVIZ

# load an example InferenceData object shipped with forecasttools
idata = forecasttools.nhsn_flu_forecast
print(idata.posterior_predictive)

# inspect the "obs" dimension coordinates across groups
# NOTE(review): assumes the loaded idata already contains a
# "predictions" group ("has predictions by default") — confirm, since
# only posterior_predictive is printed above
print(idata.predictions["obs_dim_0"])
print(idata.observed_data["obs_dim_0"])
print(idata.posterior_predictive["obs_dim_0"])

# get the posterior-predictive samples for the "obs" variable
postp_samps = idata.posterior_predictive["obs"]
Expand All @@ -34,7 +85,7 @@
# wrap the forecast dataset in its own InferenceData "predictions" group
predictions_idata = az.InferenceData(predictions=xr.Dataset(predictions_dict))


# edit the original idata object in place:
# replace the posterior-predictive samples with the fitted values, then
# merge in the predictions group.
idata.posterior_predictive["obs"] = fitted
# BUG FIX: InferenceData.extend mutates idata in place and returns None,
# so rebinding (idata = idata.extend(...)) would set idata to None.
idata.extend(predictions_idata)

Expand Down

0 comments on commit afbd528

Please sign in to comment.