Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More control over size histogram #72

Merged
merged 13 commits into from
Jan 15, 2025
44 changes: 36 additions & 8 deletions ringvax/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ringvax import Simulation
from ringvax.summary import (
get_all_person_properties,
get_total_infection_count_df,
get_infection_counts_by_generation,
prob_control_by_gen,
summarize_detections,
summarize_infections,
Expand Down Expand Up @@ -62,7 +62,7 @@ def render_percents(df: pl.DataFrame) -> pl.DataFrame:
.then(pl.lit("Not a number"))
.otherwise(
pl.col(col).map_elements(
lambda x: f"{round(x):.0f}%", return_dtype=pl.String
lambda x: f"{round(x * 100):.0f}%", return_dtype=pl.String
afmagee42 marked this conversation as resolved.
Show resolved Hide resolved
)
)
.alias(col)
Expand Down Expand Up @@ -297,6 +297,20 @@ def infectiousness_callback():
seed = st.number_input("Random seed", value=1234, step=1)
nsim = st.number_input("Number of simulations", value=250, step=1)

plot_gen = st.segmented_control(
"Generation to plot",
options=range(1, n_generations + 1),
default=n_generations,
)
cumulative = (
st.segmented_control(
"Show infections cumulatively or in specific generation?",
options=["Cumulative", "In generation"],
default="Cumulative",
)
== "Cumulative"
)

params = {
"n_generations": n_generations,
"latent_duration": latent_duration,
Expand Down Expand Up @@ -362,16 +376,30 @@ def infectiousness_callback():
help=f"The probability that there are no infections in the {format_control_gens(control_generations)}, or equivalently that the {format_control_gens(control_generations - 1)} do not produce any further infections.",
)

st.header("Number of infections")
st.write(
f"Distribution of the total number of infections seen in {n_generations} generations."
st.header(
"Number of infections",
help="This is a histogram describing the distribution of the number of infections. You can change what is plotted here in the Advanced Settings.",
)
generational_counts = get_infection_counts_by_generation(sim_df)

if cumulative:
counts = (
generational_counts.filter(pl.col("generation") <= plot_gen)
.group_by("simulation")
.agg(pl.col("num_infections").sum())
)
else:
counts = generational_counts.filter(pl.col("generation") == plot_gen)

x_lab = f"Number of infections in generation {plot_gen}"
if cumulative:
x_lab = f"Cumulative infections through generation {plot_gen}"
afmagee42 marked this conversation as resolved.
Show resolved Hide resolved
st.altair_chart(
alt.Chart(get_total_infection_count_df(sim_df))
alt.Chart(counts)
.mark_bar()
.encode(
x=alt.X("size:Q", bin=True, title="Number of infections"),
y=alt.Y("count()", title="Count"),
x=alt.X("num_infections:Q", bin=True, title=x_lab),
y=alt.Y("count()", title="Number of simulations"),
)
)

Expand Down
25 changes: 19 additions & 6 deletions ringvax/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,25 @@ def prob_control_by_gen(df: pl.DataFrame, gen: int) -> float:
return 1.0 - (size_at_gen.shape[0] / n_sim)


def get_total_infection_count_df(df: pl.DataFrame) -> pl.DataFrame:
def get_infection_counts_by_generation(df: pl.DataFrame) -> pl.DataFrame:
"""
Get DataFrame of all total outbreak sizes from simulations
Get DataFrame of number of infections in each generation from simulations.
"""
return (
df.group_by("simulation")
# length of anything in the grouped dataframe is number of infections
.agg(pl.col("t_exposed").len().alias("size"))
non_extinct = df.group_by("simulation", "generation").agg(num_infections=pl.len())

gmax = int(max(df["generation"]))
nsims = int(max(df["simulation"])) + 1

all_extinct = [
{"simulation": i, "generation": g, "num_infections": 0}
for i in range(nsims)
for g in range(gmax + 1)
]

all_extinct = pl.DataFrame(all_extinct).cast(
{"num_infections": pl.UInt32, "simulation": pl.Int32}
)

extinct = all_extinct.join(non_extinct, on=["simulation", "generation"], how="anti")
afmagee42 marked this conversation as resolved.
Show resolved Hide resolved

return pl.concat([non_extinct, extinct])
45 changes: 45 additions & 0 deletions tests/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,48 @@ def test_get_all_person_properties():
# result should be a data frame of length 4
assert isinstance(x, pl.DataFrame)
assert x.shape[0] == 4


def test_generational_counts():
params = {
"n_generations": 6,
"latent_duration": 1.0,
"infectious_duration": 3.0,
"infection_rate": 0.5,
"p_passive_detect": 0.5,
"passive_detection_delay": 2.0,
"p_active_detect": 0.15,
"active_detection_delay": 2.0,
"max_infections": 1000000,
}

n_sims = 3
sims = []
for i in range(n_sims):
sims.append(ringvax.Simulation(params=params, rng=np.random.default_rng(i)))
sims[-1].run()

all_sims = ringvax.summary.get_all_person_properties(sims)
max_obs_gen = [
max(sim.get_person_property(id, "generation") for id in sim.infections)
for sim in sims
]
obs_g_max = max(max_obs_gen)

gen_counts = ringvax.summary.get_infection_counts_by_generation(all_sims)

assert gen_counts.shape[0] == (obs_g_max + 1) * n_sims

for i, sim in enumerate(sims):
sim_counts = gen_counts.filter(pl.col("simulation") == i)
assert sim_counts.shape[0] == obs_g_max + 1

for g in range(max_obs_gen[i] + 1):
assert sim_counts.filter(pl.col("generation") == g)["num_infections"][
0
] == len(sim.query_people({"generation": g}))

for g in range(max_obs_gen[i] + 1, obs_g_max):
assert (
sim_counts.filter(pl.col("generation") == g)["num_infections"][0] == 0
)
Loading