Skip to content

Commit

Permalink
Refactor summarization (#45)
Browse files Browse the repository at this point in the history
* refactoring

* use polars schema

* refactor lists
  • Loading branch information
swo authored Dec 19, 2024
1 parent 5725420 commit 2c594b7
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 46 deletions.
85 changes: 40 additions & 45 deletions ringvax/summary.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,66 @@
from collections import Counter
from typing import Container, Sequence
from typing import Sequence

import numpy as np
import polars as pl

from ringvax import Simulation

infection_schema = {
"infector": pl.String,
"generation": pl.Int64,
"t_exposed": pl.Float64,
"t_infectious": pl.Float64,
"t_recovered": pl.Float64,
"infection_rate": pl.Float64,
"detected": pl.Boolean,
"detect_method": pl.String,
"t_detected": pl.Float64,
"infection_times": pl.List(pl.Float64),
}
infection_schema = pl.Schema(
{
"infector": pl.String,
"generation": pl.Int64,
"t_exposed": pl.Float64,
"t_infectious": pl.Float64,
"t_recovered": pl.Float64,
"infection_rate": pl.Float64,
"detected": pl.Boolean,
"detect_method": pl.String,
"t_detected": pl.Float64,
"infection_times": pl.List(pl.Float64),
}
)
"""
An infection as a polars schema
"""


def prepare_for_df(infection: dict) -> dict:
"""
Handle vector-valued infection properties for downstream use in pl.DataFrame
"""
dfable = {}
for k, v in infection.items():
if isinstance(v, np.ndarray):
assert k == "infection_times"
dfable |= {k: [float(vv) for vv in v]}
else:
assert isinstance(v, str) or not isinstance(v, Container)
dfable |= {k: v}
return dfable


def get_all_person_properties(
sims: Sequence[Simulation], exclude_termination_if: list[str] = ["max_infections"]
) -> pl.DataFrame:
"""
Get a dataframe of all properties of all infections
"""
g_max = [sim.params["n_generations"] for sim in sims]
assert (
len(Counter(g_max).items()) == 1
len(set(sim.params["n_generations"] for sim in sims)) == 1
), "Aggregating simulations with different `n_generations` is nonsensical"

i_max = [sim.params["max_infections"] for sim in sims]
assert (
len(Counter(i_max).items()) == 1
len(set(sim.params["max_infections"] for sim in sims)) == 1
), "Aggregating simulations with different `max_infections` is nonsensical"

per_sim = []
for idx, sim in enumerate(sims):
if sim.termination["criterion"] not in exclude_termination_if:
sims_dict = {k: [] for k in infection_schema.keys()} | {
"simulation": [idx] * len(sim.infections)
}
for infection in sim.infections.values():
prep = prepare_for_df(infection)
for k in infection_schema.keys():
sims_dict[k].append(prep[k])
per_sim.append(pl.DataFrame(sims_dict).cast(infection_schema)) # type: ignore
return pl.concat(per_sim)
return pl.concat(
[
_get_person_properties(sim).with_columns(simulation=sim_idx)
for sim_idx, sim in enumerate(sims)
if sim.termination["criterion"] not in exclude_termination_if
]
)


def _get_person_properties(sim: Simulation) -> pl.DataFrame:
"""Get a DataFrame of all properties of all infections in a simulation"""
return pl.from_dicts(
[_prepare_for_df(x) for x in sim.infections.values()], schema=infection_schema
)


def _prepare_for_df(infection: dict) -> dict:
"""
Convert numpy arrays in a dictionary to lists, for DataFrame compatibility
"""
return {
k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in infection.items()
}


def summarize_detections(df: pl.DataFrame) -> pl.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def test_prep_for_df():
infection = {"infection_times": np.array([0, 1, 2]), "detected": False}
assert ringvax.summary.prepare_for_df(infection) == {
assert ringvax.summary._prepare_for_df(infection) == {
"infection_times": [0, 1, 2],
"detected": False,
}
Expand Down

0 comments on commit 2c594b7

Please sign in to comment.