Skip to content

Commit

Permalink
begin port of notebook into codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 committed Jan 22, 2025
1 parent beba8f3 commit b12ffc2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 248 deletions.
41 changes: 41 additions & 0 deletions forecasttools/idata_to_tidy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import re

import arviz as az
import polars as pl


def convert_idata_forecast_to_tidydraws(
idata: az.InferenceData,
groups: list[str]
) -> dict[str, pl.DataFrame]:
tidy_dfs = {}
idata_df = idata.to_dataframe()
for group in groups:
group_columns = [
col for col in idata_df.columns
if isinstance(col, tuple) and col[0] == group
]
meta_columns = ["chain", "draw"]
group_df = idata_df[meta_columns + group_columns]
group_df.columns = [
col[1] if isinstance(col, tuple) else col
for col in group_df.columns
]
group_pols_df = pl.from_pandas(group_df)
value_columns = [col for col in group_pols_df.columns if col not in meta_columns]
group_pols_df = group_pols_df.melt(
id_vars=meta_columns,
value_vars=value_columns,
variable_name="variable",
value_name="value"
)
group_pols_df = group_pols_df.with_columns(
pl.col("variable").map_elements(lambda x: re.sub(r"\[.*\]", "", x)).alias("variable")
)
group_pols_df = group_pols_df.with_columns(
((pl.col("draw") - 1) % group_pols_df["draw"].n_unique() + 1).alias(".iteration")
)
group_pols_df = group_pols_df.rename({"chain": ".chain", "draw": ".draw"})
tidy_dfs[group] = group_pols_df.select([".chain", ".draw", ".iteration", "variable", "value"])

return tidy_dfs
248 changes: 0 additions & 248 deletions notebooks/idata_to_tidy_draws.py

This file was deleted.

0 comments on commit b12ffc2

Please sign in to comment.