Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor summarization #45

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 40 additions & 45 deletions ringvax/summary.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,66 @@
from collections import Counter
from typing import Container, Sequence
from typing import Sequence

import numpy as np
import polars as pl

from ringvax import Simulation

# Column name -> polars dtype for a single infection record.
# Defined once, as a `pl.Schema`; the `.keys()` iteration and
# `DataFrame.cast(infection_schema)` uses elsewhere in this module expect a
# mapping of column names to dtypes, which a `pl.Schema` provides.
infection_schema = pl.Schema(
    {
        "infector": pl.String,
        "generation": pl.Int64,
        "t_exposed": pl.Float64,
        "t_infectious": pl.Float64,
        "t_recovered": pl.Float64,
        "infection_rate": pl.Float64,
        "detected": pl.Boolean,
        "detect_method": pl.String,
        "t_detected": pl.Float64,
        "infection_times": pl.List(pl.Float64),
    }
)
"""
An infection as a polars schema
"""


def prepare_for_df(infection: dict) -> dict:
    """
    Handle vector-valued infection properties for downstream use in pl.DataFrame

    Returns a new dictionary: numpy arrays are converted to lists of Python
    floats, and every other value is passed through unchanged. The input
    dictionary is not modified.

    Raises AssertionError if an array is stored under any key other than
    "infection_times", or if a non-array value is a non-string container.
    """
    dfable = {}
    for k, v in infection.items():
        if isinstance(v, np.ndarray):
            # Raised explicitly (not via `assert`) so the check is not
            # stripped when Python runs with -O.
            if k != "infection_times":
                raise AssertionError(
                    f"Unexpected array-valued infection property: {k!r}"
                )
            dfable[k] = [float(vv) for vv in v]
        else:
            # Strings are containers too, but are fine as scalar cell values.
            if not isinstance(v, str) and isinstance(v, Container):
                raise AssertionError(
                    f"Unexpected container-valued infection property: {k!r}"
                )
            dfable[k] = v
    return dfable


def get_all_person_properties(
    sims: Sequence[Simulation],
    exclude_termination_if: Sequence[str] = ("max_infections",),
) -> pl.DataFrame:
    """
    Get a dataframe of all properties of all infections

    The per-simulation frames are concatenated, each tagged with a
    `simulation` column holding that simulation's position in `sims`.
    Simulations whose termination criterion appears in
    `exclude_termination_if` are left out.

    Raises AssertionError unless all simulations share a single value of
    `n_generations` and a single value of `max_infections`.
    """
    # The default for `exclude_termination_if` is a tuple, so no mutable
    # object is shared between calls.
    # Raised explicitly (not via `assert`) so the checks survive `python -O`.
    for param in ("n_generations", "max_infections"):
        if len({sim.params[param] for sim in sims}) != 1:
            raise AssertionError(
                f"Aggregating simulations with different `{param}` is nonsensical"
            )

    # NOTE(review): if every simulation is excluded this list is empty;
    # confirm how pl.concat should behave / be guarded in that case.
    return pl.concat(
        [
            _get_person_properties(sim).with_columns(simulation=sim_idx)
            for sim_idx, sim in enumerate(sims)
            if sim.termination["criterion"] not in exclude_termination_if
        ]
    )


def _get_person_properties(sim: Simulation) -> pl.DataFrame:
    """Build a DataFrame from the infections of one simulation, one row each"""
    rows = [_prepare_for_df(record) for record in sim.infections.values()]
    return pl.from_dicts(rows, schema=infection_schema)


def _prepare_for_df(infection: dict) -> dict:
"""
Convert numpy arrays in a dictionary to lists, for DataFrame compatibility
"""
return {
k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in infection.items()
}

afmagee42 marked this conversation as resolved.
Show resolved Hide resolved

def summarize_detections(df: pl.DataFrame) -> pl.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def test_prep_for_df():
    """Arrays are converted to plain lists; other values pass through"""
    infection = {"infection_times": np.array([0, 1, 2]), "detected": False}
    assert ringvax.summary._prepare_for_df(infection) == {
        "infection_times": [0, 1, 2],
        "detected": False,
    }
Expand Down
Loading