diff --git a/poetry.lock b/poetry.lock index 8496cc4..202c7c5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "altair" @@ -747,6 +747,47 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "polars" +version = "1.17.1" +description = "Blazingly fast DataFrame library" +optional = false +python-versions = ">=3.9" +files = [ + {file = "polars-1.17.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d3a2172f7cf332010f0b034345111e9c86d59b5a5b0fc5aa0509121f40d9e43c"}, + {file = "polars-1.17.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:82e98c69197df0d8ddc341a6175008508ceaea88f723f32044027810bcdb43fa"}, + {file = "polars-1.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59abdab015ed2ecfa0c63862b960816c35096e1f4df057dde3c44cd973af5029"}, + {file = "polars-1.17.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:6d2f922c403b8900b3ae3c23a27b2cae3a2db40ad790cc4fc368402b92629b11"}, + {file = "polars-1.17.1-cp39-abi3-win_amd64.whl", hash = "sha256:d38156c8259554cbcb17874d91e6dfa9c404335f08a3307496aadfdee46baa31"}, + {file = "polars-1.17.1.tar.gz", hash = "sha256:5a3dac3cb7cbe174d1fa898cba9afbede0c08e8728feeeab515554d762127019"}, +] + +[package.extras] +adbc = ["adbc-driver-manager[dbapi]", "adbc-driver-sqlite[dbapi]"] +all = ["polars[async,cloudpickle,database,deltalake,excel,fsspec,graph,iceberg,numpy,pandas,plot,pyarrow,pydantic,style,timezone]"] +async = ["gevent"] +calamine = ["fastexcel (>=0.9)"] +cloudpickle = ["cloudpickle"] +connectorx = ["connectorx (>=0.3.2)"] +database = ["nest-asyncio", "polars[adbc,connectorx,sqlalchemy]"] +deltalake = ["deltalake (>=0.19.0)"] +excel = ["polars[calamine,openpyxl,xlsx2csv,xlsxwriter]"] +fsspec = ["fsspec"] +gpu = ["cudf-polars-cu12"] +graph = ["matplotlib"] +iceberg = ["pyiceberg (>=0.5.0)"] +numpy = ["numpy (>=1.16.0)"] +openpyxl = ["openpyxl (>=3.0.0)"] +pandas = ["pandas", "polars[pyarrow]"] +plot = ["altair (>=5.4.0)"] +pyarrow = ["pyarrow (>=7.0.0)"] +pydantic = ["pydantic"] +sqlalchemy = ["polars[pandas]", "sqlalchemy"] +style = ["great-tables (>=0.8.0)"] +timezone = ["backports-zoneinfo", "tzdata"] +xlsx2csv = ["xlsx2csv (>=0.8.0)"] +xlsxwriter = ["xlsxwriter"] + [[package]] name = "protobuf" version = "5.29.1" @@ -1296,4 +1337,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "12a0d963bbff6ff6a2ade750627a66df232fd6c06690e09167b1a51529e01607" +content-hash = "56438ac7f627fc645ee6ad0395bce3dccbda15dcef600e6321724fe9c10e4935" diff --git a/pyproject.toml b/pyproject.toml index b79e09d..7488d6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ python = "^3.10" numpy = "^2.2.0" streamlit = "^1.41.0" graphviz = "^0.20.3" +polars = "^1.17.1" [tool.poetry.group.dev.dependencies] diff --git a/ringvax/app.py b/ringvax/app.py index 23b79ec..cf9282e 100644 --- a/ringvax/app.py +++ b/ringvax/app.py @@ -1,10 +1,23 @@ +import time +from typing import List + +import altair as alt import graphviz +import polars as pl import streamlit as st from ringvax import Simulation +from ringvax.summary import ( + get_all_person_properties, + get_outbreak_size_df, + prob_control_by_gen, + summarize_detections, + summarize_infections, +) -def make_graph(sim: Simulation): +def make_graph(sim: Simulation) -> graphviz.Digraph: + """Make a transmission graph""" graph = graphviz.Digraph() for infectee in sim.query_people(): infector = sim.get_person_property(infectee, "infector") @@ -21,71 +34,127 @@ def make_graph(sim: Simulation): return graph +@st.fragment +def show_graph(sims: List[Simulation], pause: float = 0.1): + """Show a transmission graph. Wrap as st.fragment, to not re-run simulations. + + Args: + sims (List[Simulation]): list of simulations + pause (float, optional): Number of seconds to pause before displaying + new graph. Defaults to 0.1. + """ + idx = st.number_input( + "Simulation to plot", min_value=0, max_value=len(sims) - 1, value=0 + ) + placeholder = st.empty() + time.sleep(pause) + placeholder.graphviz_chart(make_graph(sims[idx])) + + +def format_control_gens(gen: int): + if gen == 0: + return "index_case" + if gen == 1: + return "contacts" + elif gen > 1: + return "".join(["contacts of "] * (gen - 1)) + "contacts" + else: + raise RuntimeError("Must specify `gen` >= 0.") + + +def format_duration(x: float, digits=3) -> str: + """Format a number of seconds duration into a string""" + assert x >= 0 + min_time = 10 ** (-digits) + if x < min_time: + return f"<{min_time} seconds" + else: + return f"{round(x, digits)} seconds" + + def app(): st.title("Ring vaccination") - latent_duration = st.slider( - "Latent duration", - min_value=0.0, - max_value=10.0, - value=1.0, - step=0.1, - format="%.1f days", - ) - infectious_duration = st.slider( - "Infectious duration", - min_value=0.0, - max_value=10.0, - value=3.0, - step=0.1, - format="%.1f days", - ) - infection_rate = st.slider( - "Infection rate", min_value=0.0, max_value=10.0, value=1.0, step=0.1 - ) - p_passive_detect = ( - st.slider( - "Passive detection probability", + with st.sidebar: + latent_duration = st.slider( + "Latent duration", min_value=0.0, - max_value=100.0, - value=0.5, - step=0.01, - format="%d%%", + max_value=10.0, + value=1.0, + step=0.1, + format="%.1f days", ) - / 100.0 - ) - passive_detection_delay = st.slider( - "Passive detection delay", - min_value=0.0, - max_value=10.0, - value=2.0, - step=0.1, - format="%.1f days", - ) - p_active_detect = ( - st.slider( - "Active detection probability", + infectious_duration = st.slider( + "Infectious duration", min_value=0.0, - max_value=100.0, - value=15.0, + max_value=10.0, + value=3.0, step=0.1, - format="%d%%", + format="%.1f days", + ) + infection_rate = st.slider( + "Infection rate", min_value=0.0, max_value=10.0, value=0.5, step=0.1 + ) + p_passive_detect = ( + st.slider( + "Passive detection probability", + min_value=0.0, + max_value=100.0, + value=50.0, + step=1.0, + format="%d%%", + ) + / 100.0 + ) + passive_detection_delay = st.slider( + "Passive detection delay", + min_value=0.0, + max_value=10.0, + value=2.0, + step=0.1, + format="%.1f days", + ) + p_active_detect = ( + st.slider( + "Active detection probability", + min_value=0.0, + max_value=100.0, + value=15.0, + step=1.0, + format="%d%%", + ) + / 100.0 + ) + active_detection_delay = st.slider( + "Active detection delay", + min_value=0.0, + max_value=10.0, + value=2.0, + step=0.1, + format="%.1f days", ) - / 100.0 - ) - active_detection_delay = st.slider( - "Active detection delay", - min_value=0.0, - max_value=10.0, - value=2.0, - step=0.1, - format="%.1f days", - ) - n_generations = st.number_input("Number of generations", value=4, step=1) - max_infections = st.number_input( - "Maximum number of infections", value=100, step=10, min_value=10 - ) - seed = st.number_input("Random seed", value=1234, step=1) + + with st.expander("Advanced Options"): + n_generations = st.number_input( + "Number of simulated generations", value=4, step=1 + ) + control_generations = st.number_input( + "Degree of contacts for checking control", + value=3, + step=1, + min_value=1, + max_value=n_generations + 1, + help="Successful control is defined as no infections in contacts at this degree. Set to 1 for contacts of the index case, 2 for contacts of contacts, etc. Equivalent to checking for extinction in the specified generation.", + ) + max_infections = st.number_input( + "Maximum number of infections", + value=100, + step=10, + min_value=100, + help="", + ) + seed = st.number_input("Random seed", value=1234, step=1) + nsim = st.number_input("Number of simulations", value=250, step=1) params = { "n_generations": n_generations, @@ -99,21 +168,72 @@ def app(): "max_infections": max_infections, } - st.subheader( - f"R0 is {infectious_duration * infection_rate:.2f}", - help="R0 is the average duration of infection multiplied by the infectious rate.", + sims = [] + with st.spinner("Running simulation..."): + tic = time.perf_counter() + for i in range(nsim): + sims.append(Simulation(params=params, seed=seed + i)) + sims[-1].run() + toc = time.perf_counter() + + st.write( + f"Ran {nsim} simulations in {format_duration(toc - tic)} with an $R_0$ of {infectious_duration * infection_rate:.2f} (the product of the average duration of infection and the infectious rate)." ) - s = Simulation(params=params, seed=seed) - s.run() + tab1, tab2 = st.tabs(["Simulation summary", "Per-simulation results"]) + with tab1: + sim_df = get_all_person_properties(sims) - st.header("Graph of infections") - st.graphviz_chart(make_graph(s)) + pr_control = prob_control_by_gen(sim_df, control_generations) + st.header( + f"Probability of control: {pr_control:.0%}", + help=f"The probability that there are no infections in the {format_control_gens(control_generations)}, or equivalently that the {format_control_gens(control_generations - 1)} do not produce any further infections.", + ) + + st.header("Number of infections") + st.write( + f"Distribution of the total number of infections seen in {n_generations} generations." + ) + st.altair_chart( + alt.Chart(get_outbreak_size_df(sim_df)) + .mark_bar() + .encode( + x=alt.X("size:Q", bin=True, title="Number of infections"), + y=alt.Y("count()", title="Count"), + ) + ) + + st.header("Summary of dynamics") + infection = summarize_infections(sim_df) + st.write( + f"In these simulations, the average duration of infectiousness was {infection['mean_infectious_duration'][0]:.2f} and $R_e$ was {infection['mean_n_infections'][0]:.2f}" + ) + + st.write( + "The following table provides summaries of marginal probabilities regarding detection. Aside from the marginal probability of active detection, these are the observed probabilities that any individual is detected in this manner. The marginal probability of active detection excludes index cases, which are not eligible for active detection." + ) + detection = summarize_detections(sim_df) + st.dataframe( + detection.select( + (pl.col(col) * 100).round(0).cast(pl.Int64) for col in detection.columns + ) + .with_columns( + pl.concat_str([pl.col(col), pl.lit("%")], separator="") + for col in detection.columns + ) + .rename( + { + "prob_detect": "Any detection", + "prob_active": "Active detection", + "prob_passive": "Passive detection", + "prob_detect_before_infectious": "Detection before onset of infectiousness", + } + ) + ) - st.header("Raw results") - for id, content in s.infections.items(): - st.text(id) - st.text(content) + with tab2: + st.header("Graph of infections") + show_graph(sims=sims) if __name__ == "__main__": diff --git a/ringvax/summary.py b/ringvax/summary.py new file mode 100644 index 0000000..c784a96 --- /dev/null +++ b/ringvax/summary.py @@ -0,0 +1,159 @@ +from collections import Counter +from typing import Container, Sequence + +import numpy as np +import polars as pl + +from ringvax import Simulation + +infection_schema = { + "infector": pl.String, + "generation": pl.Int64, + "t_exposed": pl.Float64, + "t_infectious": pl.Float64, + "t_recovered": pl.Float64, + "infection_rate": pl.Float64, + "detected": pl.Boolean, + "detect_method": pl.String, + "t_detected": pl.Float64, + "infection_times": pl.List(pl.Float64), +} +""" +An infection as a polars schema +""" + + +def prepare_for_df(infection: dict) -> dict: + """ + Handle vector-valued infection properties for downstream use in pl.DataFrame + """ + dfable = {} + for k, v in infection.items(): + if isinstance(v, np.ndarray): + assert k == "infection_times" + dfable |= {k: [float(vv) for vv in v]} + else: + assert isinstance(v, str) or not isinstance(v, Container) + dfable |= {k: v} + return dfable + + +def get_all_person_properties(sims: Sequence[Simulation]) -> pl.DataFrame: + """ + Get a dataframe of all properties of all infections + """ + g_max = [sim.params["n_generations"] for sim in sims] + assert ( + len(Counter(g_max).items()) == 1 + ), "Aggregating simulations with different `n_generations` is nonsensical" + + i_max = [sim.params["max_infections"] for sim in sims] + assert ( + len(Counter(i_max).items()) == 1 + ), "Aggregating simulations with different `max_infections` is nonsensical" + + per_sim = [] + for idx, sim in enumerate(sims): + sims_dict = {k: [] for k in infection_schema.keys()} | { + "simulation": [idx] * len(sim.infections) + } + for infection in sim.infections.values(): + prep = prepare_for_df(infection) + for k in infection_schema.keys(): + sims_dict[k].append(prep[k]) + per_sim.append(pl.DataFrame(sims_dict).cast(infection_schema)) # type: ignore + return pl.concat(per_sim) + + +def summarize_detections(df: pl.DataFrame) -> pl.DataFrame: + """ + Get marginal detection probabilities from simulations. + """ + nsims = len(df["simulation"].unique()) + n_infections = df.shape[0] + n_active_eligible = n_infections - nsims + detection_counts = df.select(pl.col("detect_method").value_counts()).unnest( + "detect_method" + ) + + count_nodetect = 0 + if detection_counts.filter(pl.col("detect_method").is_null()).shape[0] == 1: + count_nodetect = detection_counts.filter(pl.col("detect_method").is_null())[ + "count" + ] + count_active, count_passive = 0, 0 + if detection_counts.filter(pl.col("detect_method") == "active").shape[0] == 1: + count_active = detection_counts.filter(pl.col("detect_method") == "active")[ + "count" + ] + if detection_counts.filter(pl.col("detect_method") == "passive").shape[0] == 1: + count_passive = detection_counts.filter(pl.col("detect_method") == "passive")[ + "count" + ] + + return pl.DataFrame( + { + "prob_detect": 1.0 - count_nodetect / n_infections, + "prob_active": count_active / n_active_eligible, + "prob_passive": count_passive / n_infections, + "prob_detect_before_infectious": df.filter(pl.col("detected")) + .filter(pl.col("t_detected") < pl.col("t_infectious")) + .shape[0] + / n_infections, + } + ) + + +def summarize_infections(df: pl.DataFrame) -> pl.DataFrame: + """ + Get summaries of infectiousness from simulations. + """ + df = df.with_columns( + n_infections=pl.col("infection_times").list.len(), + t_noninfectious=pl.min_horizontal( + [pl.col("t_detected"), pl.col("t_recovered")] + ), + ).with_columns( + duration_infectious=(pl.col("t_noninfectious") - pl.col("t_infectious")) + ) + + return pl.DataFrame( + { + "mean_infectious_duration": df["duration_infectious"].mean(), + "sd_infectious_duration": df["duration_infectious"].std(), + # This is R_e + "mean_n_infections": df["n_infections"].mean(), + "sd_n_infections": df["n_infections"].std(), + } + ) + + +def prob_control_by_gen(df: pl.DataFrame, gen: int) -> float: + """ + Compute the probability of control in generation (probability extinct in or before this generation) for all simulations + """ + n_sim = df["simulation"].unique().len() + size_at_gen = ( + df.with_columns( + pl.col("generation") + 1, + n_infections=pl.col("infection_times").list.len(), + ) + .with_columns(size=pl.sum("n_infections").over("simulation", "generation")) + .unique(subset=["simulation", "generation"]) + .filter( + pl.col("generation") == gen, + pl.col("size") > 0, + ) + ) + return 1.0 - (size_at_gen.shape[0] / n_sim) + + +def get_outbreak_size_df(df: pl.DataFrame) -> pl.DataFrame: + """ + Get DataFrame of all total outbreak sizes from simulations + """ + return ( + df.group_by("simulation") + # length of anything in the grouped dataframe is number of infections + .agg(pl.col("t_exposed").len().alias("size")) + )