Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pyro.contrib.epidemiology experiment #13

Draft
wants to merge 26 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions 2020-06-compartmental/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
data/
results/
20 changes: 20 additions & 0 deletions 2020-06-compartmental/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Experiment automation. Script targets depend on FORCE so they always
# re-run; runner.py itself skips tasks whose cached results already exist.
# Fix: long_uni_synth was missing from .PHONY even though the target exists.
.PHONY: all lint clean mrclean short_uni_synth long_uni_synth

# Default target: run the linter only.
all: lint

lint: FORCE
	flake8

# Run the small synthetic experiment sweep.
short_uni_synth: FORCE
	python runner.py --experiment=short_uni_synth

# Run the large synthetic experiment sweep.
long_uni_synth: FORCE
	python runner.py --experiment=long_uni_synth

# Remove transient working files, keeping cached data and results.
clean: FORCE
	rm -rf temp logs errors

# Remove everything, including cached data and results.
mrclean: FORCE
	rm -rf data results

# Empty pseudo-target used to force a target's recipe to always run.
FORCE:
190 changes: 190 additions & 0 deletions 2020-06-compartmental/analyze.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import matplotlib.pyplot as plt\n",
"from runner import short_uni_synth, long_uni_synth\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_formats = ['svg']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results = list(short_uni_synth.results)\n",
"print(len(results))\n",
"for r in results:\n",
" if r[\"args\"].infer == \"mcmc\":\n",
" break\n",
"print(r.keys())\n",
"print(r[\"times\"].keys())\n",
"print(r[\"evaluate\"].keys())\n",
"print(r[\"evaluate\"][\"R0\"].keys())\n",
"print(r[\"infer\"].keys())\n",
"print(r[\"infer\"][\"R0\"].keys())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_accuracy(variable, metric, experiment):\n",
" view = defaultdict(list)\n",
" for result in experiment.results:\n",
" args = result['args']\n",
" view[args.infer, args.num_bins, args.svi_steps].append(result)\n",
" markers = [\"o\", \"d\", \"s\", \"<\", \"v\", \"^\", \">\"]\n",
" assert len(view) <= len(markers)\n",
"\n",
" plt.figure(figsize=(6, 5)).patch.set_color(\"white\")\n",
" for (key, value), marker in zip(sorted(view.items()), markers):\n",
" algo, num_bins, svi_steps = key\n",
" if algo == \"svi\":\n",
" label = f\"SVI steps={svi_steps}\"\n",
" elif algo == \"mcmc\":\n",
" if num_bins == 1:\n",
" label = \"MCMC relaxed\"\n",
" else:\n",
" label = f\"MCMC num_bins={num_bins}\"\n",
" X = [v[\"times\"][\"infer\"] for v in value]\n",
" Y = [v[\"evaluate\"][variable][metric] for v in value]\n",
" plt.scatter(X, Y, marker=marker, label=label, alpha=0.8)\n",
" plt.ylim(0, None)\n",
" plt.xscale(\"log\")\n",
" plt.xlabel(\"inference time (sec)\")\n",
" plt.ylabel(metric.upper())\n",
" plt.title(f\"{variable} accuracy ({experiment.__name__})\")\n",
" plt.legend(loc=\"best\", prop={'size': 8})\n",
" plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_accuracy(\"R0\", \"crps\", short_uni_synth)\n",
"plot_accuracy(\"rho\", \"crps\", short_uni_synth)\n",
"plot_accuracy(\"obs\", \"crps\", short_uni_synth)\n",
"plot_accuracy(\"I\", \"crps\", short_uni_synth)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_accuracy(\"R0\", \"crps\", long_uni_synth)\n",
"plot_accuracy(\"rho\", \"crps\", long_uni_synth)\n",
"plot_accuracy(\"obs\", \"crps\", long_uni_synth)\n",
"plot_accuracy(\"I\", \"crps\", long_uni_synth)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_convergence(variable, experiment, metrics=[\"n_eff\", \"r_hat\"]):\n",
" view = defaultdict(list)\n",
"    for result in experiment.results:\n",
" args = result['args']\n",
" if args.infer == \"mcmc\":\n",
" view[args.num_bins].append(result)\n",
" markers = [\"o\", \"d\", \"s\"]\n",
" assert len(view) <= len(markers)\n",
"\n",
" fig, axes = plt.subplots(len(metrics), 1, figsize=(6, 5), sharex=True)\n",
" fig.patch.set_color(\"white\")\n",
" for (num_bins, value), marker in zip(sorted(view.items()), markers):\n",
" if num_bins == 1:\n",
" label = \"MCMC relaxed\"\n",
" else:\n",
" label = f\"MCMC num_bins={num_bins}\"\n",
" X = [v[\"times\"][\"infer\"] for v in value]\n",
" for metric, ax in zip(metrics, axes):\n",
" Y = [v[\"infer\"][variable][metric] for v in value]\n",
" ax.scatter(X, Y, marker=marker, label=label, alpha=0.8)\n",
" ax.set_xscale(\"log\")\n",
" ax.set_yscale(\"log\")\n",
" ax.set_ylabel(metric)\n",
" axes[0].set_title(f\"{variable} convergence ({experiment.__name__})\")\n",
" axes[1].set_ylim(1, None)\n",
" axes[-1].legend(loc=\"best\", prop={'size': 8})\n",
" axes[-1].set_xlabel(\"inference time (sec)\")\n",
" plt.subplots_adjust(hspace=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_convergence(\"R0\", short_uni_synth)\n",
"plot_convergence(\"rho\", short_uni_synth)\n",
"plot_convergence(\"auxiliary_haar_split_0\", short_uni_synth)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_convergence(\"R0\", long_uni_synth)\n",
"plot_convergence(\"rho\", long_uni_synth)\n",
"plot_convergence(\"auxiliary_haar_split_0\", long_uni_synth)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
123 changes: 123 additions & 0 deletions 2020-06-compartmental/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import pickle
import subprocess
import sys
from importlib import import_module

from util import get_filename


class Experiment:
    """
    A named collection of tasks, each of which produces one datapoint by
    running a python script; datapoints are cached on disk as pickle
    files whose names are derived from the task's fingerprint.
    """

    def __init__(self, generate_tasks):
        self.__name__ = generate_tasks.__name__
        self.tasks = [[sys.executable] + task for task in generate_tasks()]
        self.files = []
        for cmd in self.tasks:
            script = cmd[1]
            # Each script module exposes a Parser; the parsed args
            # determine the fingerprinted cache filename for that task.
            module = import_module(script.replace(".py", ""))
            parsed = module.Parser().parse_args(cmd[2:])
            self.files.append(get_filename(script, parsed))

    @property
    def results(self):
        """
        Yields each cached result that exists on disk, silently skipping
        tasks that have not been run yet.
        """
        for path in self.files:
            if not os.path.exists(path):
                continue
            with open(path, "rb") as handle:
                yield pickle.load(handle)


@Experiment
def short_uni_synth():
    """Task sweep for a small synthetic dataset: SVI and MCMC configs."""
    base = [
        "uni_synth.py",
        "--population=1000",
        "--duration=20", "--forecast=10",
        "--R0=3", "--incubation-time=2", "--recovery-time=4",
    ]
    # SVI runs: one task per (step count, seed) pair.
    for svi_steps in (1000, 2000, 5000, 10000):
        for rng_seed in range(10):
            yield base + [
                "--svi",
                "--num-samples=1000",
                f"--svi-steps={svi_steps}",
                f"--rng-seed={rng_seed}",
            ]
    # MCMC runs: the relaxed model (num_bins=1) gets more seeds than
    # the quantized variants, which are slower.
    for num_bins in (1, 2, 4):
        for num_samples in (200, 500, 1000):
            num_warmup = int(round(0.4 * num_samples))
            num_seeds = 10 if num_bins == 1 else 2
            for rng_seed in range(num_seeds):
                yield base + [
                    "--mcmc",
                    f"--warmup-steps={num_warmup}",
                    f"--num-samples={num_samples}",
                    f"--num-bins={num_bins}",
                    f"--rng-seed={rng_seed}",
                ]


@Experiment
def long_uni_synth():
    """Task sweep for a larger synthetic dataset: SVI and MCMC configs."""
    base = [
        "uni_synth.py",
        "--population=100000",
        "--duration=100", "--forecast=30",
        "--R0=2.5", "--incubation-time=4", "--recovery-time=10",
    ]
    # SVI runs: one task per (step count, seed) pair.
    for svi_steps in (1000, 2000, 5000, 10000):
        for rng_seed in range(10):
            yield base + [
                "--svi",
                "--num-samples=1000",
                f"--svi-steps={svi_steps}",
                f"--rng-seed={rng_seed}",
            ]
    # Relaxed MCMC (num_bins=1): larger sample sizes, 10 seeds each.
    for num_samples in (200, 500, 1000, 2000, 5000):
        num_warmup = int(round(0.4 * num_samples))
        for rng_seed in range(10):
            yield base + [
                "--mcmc",
                "--num-bins=1",
                f"--warmup-steps={num_warmup}",
                f"--num-samples={num_samples}",
                f"--rng-seed={rng_seed}",
            ]
    # Quantized MCMC: slower, so fewer seeds and smaller sample sizes.
    for num_bins in (2, 4):
        for num_samples in (200, 500, 1000):
            num_warmup = int(round(0.4 * num_samples))
            for rng_seed in range(2):
                yield base + [
                    "--mcmc",
                    f"--warmup-steps={num_warmup}",
                    f"--num-samples={num_samples}",
                    f"--num-bins={num_bins}",
                    f"--rng-seed={rng_seed}",
                ]


def main(args):
    """
    Run all tasks of the selected experiment, skipping tasks whose
    cached output file already exists.

    Args:
        args: namespace with attributes ``experiment`` (name of an
            ``Experiment`` instance defined in this module) and
            ``dry_run`` (print commands without executing them).

    Raises:
        ValueError: if ``args.experiment`` is missing or does not name
            a known experiment in this module.
    """
    # Fail fast with a clear message instead of a cryptic KeyError
    # (globals()[args.experiment] also misbehaved when --experiment
    # was omitted, since the flag is not required).
    experiment = globals().get(args.experiment) if args.experiment else None
    if experiment is None:
        raise ValueError(f"Unknown experiment: {args.experiment!r}")
    for task, outfile in zip(experiment.tasks, experiment.files):
        print(" \\\n ".join(task))
        if args.dry_run or os.path.exists(outfile):
            continue
        subprocess.check_call(task)

    print("-------------------------")
    print("COMPLETED {} TASKS".format(len(experiment.tasks)))
    print("-------------------------")


if __name__ == "__main__":
    # Command-line entry point: select an experiment by name and
    # optionally do a dry run that only prints the task commands.
    parser = argparse.ArgumentParser(description="experiment runner")
    # Name of an Experiment instance defined in this module.
    parser.add_argument("--experiment")
    # Print each task command without executing it.
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    main(args)
6 changes: 6 additions & 0 deletions 2020-06-compartmental/setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Shared linter/formatter configuration for this experiment directory.

[flake8]
# Allow lines up to 120 characters (flake8's default is 79).
max-line-length = 120

[isort]
# Match flake8's line length; multi_line_output=3 selects the
# vertical-hanging-indent style for wrapped imports.
line_length = 120
multi_line_output=3
Loading