From 903dc4823f6113bc450e3bf3c8aef38fa023ccf6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Jun 2020 20:22:36 -0700 Subject: [PATCH 01/26] Add first draft of epidemiology experiment --- 2020-06-compartmental/.gitignore | 4 + 2020-06-compartmental/Makefile | 23 +++ 2020-06-compartmental/runner.py | 93 +++++++++ 2020-06-compartmental/setup.cfg | 6 + 2020-06-compartmental/short_uni_synth.csv | 5 + 2020-06-compartmental/uni_synth.py | 221 ++++++++++++++++++++++ 6 files changed, 352 insertions(+) create mode 100644 2020-06-compartmental/.gitignore create mode 100644 2020-06-compartmental/Makefile create mode 100644 2020-06-compartmental/runner.py create mode 100644 2020-06-compartmental/setup.cfg create mode 100644 2020-06-compartmental/short_uni_synth.csv create mode 100644 2020-06-compartmental/uni_synth.py diff --git a/2020-06-compartmental/.gitignore b/2020-06-compartmental/.gitignore new file mode 100644 index 0000000..234a969 --- /dev/null +++ b/2020-06-compartmental/.gitignore @@ -0,0 +1,4 @@ +temp/ +logs/ +errors/ +results/ diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile new file mode 100644 index 0000000..206be39 --- /dev/null +++ b/2020-06-compartmental/Makefile @@ -0,0 +1,23 @@ +.PHONY: all lint clean mrclean short_uni_synth + +all: lint + +lint: FORCE + flake8 + +watch: FORCE + ls -t logs/* | head -n 1 | xargs tail -f + +short_uni_synth: FORCE + python runner.py \ + --script-filename=uni_synth.py \ + --args-filename=short_uni_synth.csv \ + --num-workers=1 + +clean: FORCE + rm -rf temp logs errors + +mrclean: FORCE + rm -rf temp logs errors results + +FORCE: diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py new file mode 100644 index 0000000..83de8d4 --- /dev/null +++ b/2020-06-compartmental/runner.py @@ -0,0 +1,93 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import csv +import multiprocessing +import os +import random +import subprocess +import sys +from collections import OrderedDict + +CPUS = multiprocessing.cpu_count() +ROOT = os.path.dirname(os.path.abspath(__file__)) +TEMP = os.path.join(ROOT, "temp") +LOGS = os.path.join(ROOT, "logs") +ERRORS = os.path.join(ROOT, "errors") +RESULTS = os.path.join(ROOT, "results") + +# Ensure directories exist. +for path in [TEMP, LOGS, ERRORS, RESULTS]: + if not os.path.exists(path): + try: + os.makedirs(path) + except OSError: + assert os.path.exists(path) + + +def work(task): + args, spec = task + basename = "_".join("{}={}".format(k, v) for k, v in spec.items()) + result_file = os.path.join(RESULTS, basename + ".pkl") + if os.path.exists(result_file): + return True + + temp_file = os.path.join(TEMP, basename + ".pkl") + log_file = os.path.join(LOGS, basename + ".txt") + spec["output"] = temp_file + command = ([sys.executable, args.script_filename] + + ["--{}={}".format(k, v) for k, v in spec.items()]) + print(" ".join(command)) + if args.dry_run: + return result_file + try: + with open(log_file, "w") as f: + subprocess.check_call(command, stderr=f, stdout=f) + os.rename(temp_file, result_file) # Use rename to make write atomic. + return result_file + except subprocess.CalledProcessError as e: + pdb_command = [sys.executable, "-m", "pdb", "-cc"] + command[1:-1] + msg = "{}\nTo reproduce, run:\n{}".format(e, " \\\n ".join(pdb_command)) + print(msg) + with open(os.path.join(ERRORS, basename + ".txt"), "w") as f: + f.write(msg) + return None + + +def main(args): + with open(args.args_filename) as f: + reader = csv.reader(f) + header = next(reader) + tasks = [] + for row in reader: + command_args = OrderedDict((k, v) for k, v in zip(header, row) if v) + tasks.append((args, command_args)) + if args.shuffle: + random.shuffle(tasks) + + if args.num_workers == 1: + map_ = map + else: + map_ = multiprocessing.Pool(args.num_workers).map + results = map_(work, tasks) + assert all(results) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="experiment runner") + parser.add_argument("-s", "--script-filename") + parser.add_argument("-a", "--args-filename") + parser.add_argument("-w", "--num-workers", type=int, default=CPUS) + parser.add_argument("-cpw", "--cores-per-worker", type=int) + parser.add_argument("--shuffle", action="store_true") + parser.add_argument("-f", "--force", action="store_true") + parser.add_argument("--dry-run", action="store_true") + args = parser.parse_args() + + if args.cores_per_worker: + args.workers = max(1, CPUS // args.cores_per_worker) + if args.dry_run: + args.workers = 1 + + main(args) diff --git a/2020-06-compartmental/setup.cfg b/2020-06-compartmental/setup.cfg new file mode 100644 index 0000000..370a859 --- /dev/null +++ b/2020-06-compartmental/setup.cfg @@ -0,0 +1,6 @@ +[flake8] +max-line-length = 120 + +[isort] +line_length = 120 +multi_line_output=3 diff --git a/2020-06-compartmental/short_uni_synth.csv b/2020-06-compartmental/short_uni_synth.csv new file mode 100644 index 0000000..733500f --- /dev/null +++ b/2020-06-compartmental/short_uni_synth.csv @@ -0,0 +1,5 @@ +population,duration,forecast,R0,incubation-time,recovery-time,rng-seed,infer,num-samples,num-bins,warmup-steps,svi-steps +1000,20,10,3,2,4,0,svi,1000,,,5000 +1000,20,10,3,2,4,0,mcmc,1000,1,200 +1000,20,10,3,2,4,0,mcmc,1000,2,200 +1000,20,10,3,2,4,0,mcmc,1000,4,200 diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py new file mode 100644 index 0000000..90e53eb --- /dev/null +++ b/2020-06-compartmental/uni_synth.py @@ -0,0 +1,221 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import logging +import math +import resource +import pickle +from timeit import default_timer + +import torch + +import pyro +from pyro.contrib.epidemiology.models import (HeterogeneousSIRModel, + OverdispersedSEIRModel, + OverdispersedSIRModel, + SimpleSEIRModel, SimpleSIRModel, + SuperspreadingSEIRModel, + SuperspreadingSIRModel) +from pyro.contrib.forecast.evaluate import eval_crps, eval_mae, eval_rmse +from pyro.infer.mcmc.util import summary + +fmt = '%(process)d %(message)s' +logging.getLogger("pyro").handlers[0].setFormatter(logging.Formatter(fmt)) +logging.basicConfig(format=fmt, level=logging.INFO) + + +def Model(args, data): + """Dispatch between different model classes.""" + if args.heterogeneous: + assert args.incubation_time == 0 + assert args.overdispersion == 0 + return HeterogeneousSIRModel(args.population, args.recovery_time, data) + elif args.incubation_time > 0: + assert args.incubation_time > 1 + if args.concentration < math.inf: + return SuperspreadingSEIRModel(args.population, args.incubation_time, + args.recovery_time, data) + elif args.overdispersion > 0: + return OverdispersedSEIRModel(args.population, args.incubation_time, + args.recovery_time, data) + else: + return SimpleSEIRModel(args.population, args.incubation_time, + args.recovery_time, data) + else: + if args.concentration < math.inf: + return SuperspreadingSIRModel(args.population, args.recovery_time, data) + elif args.overdispersion > 0: + return OverdispersedSIRModel(args.population, args.recovery_time, data) + else: + return SimpleSIRModel(args.population, args.recovery_time, data) + + +def generate_data(args): + extended_data = [None] * (args.duration + args.forecast) + model = Model(args, extended_data) + logging.info("Simulating from a {}".format(type(model).__name__)) + for attempt in range(100): + truth = model.generate({"R0": args.R0, + "rho": args.response_rate, + "k": args.concentration, + "od": args.overdispersion}) + obs = truth["obs"][:args.duration] + new_I = truth.get("S2I", truth.get("E2I")) + + obs_sum = int(obs.sum()) + new_I_sum = int(new_I[:args.duration].sum()) + assert 0 <= args.min_obs_portion < args.max_obs_portion <= 1 + min_obs = int(math.ceil(args.min_obs_portion * args.population)) + max_obs = int(math.floor(args.max_obs_portion * args.population)) + if min_obs <= obs_sum <= max_obs: + logging.info("Observed {:d}/{:d} infections:\n{}".format( + obs_sum, new_I_sum, " ".join(str(int(x)) for x in obs))) + return truth + + if obs_sum < min_obs: + raise ValueError("Failed to generate >={} observations. " + "Try decreasing --min-obs-portion (currently {})." + .format(min_obs, args.min_obs_portion)) + else: + raise ValueError("Failed to generate <={} observations. " + "Try increasing --max-obs-portion (currently {})." + .format(max_obs, args.max_obs_portion)) + + +def infer_mcmc(args, model): + parallel = args.num_chains > 1 + + mcmc = model.fit_mcmc(heuristic_num_particles=args.smc_particles, + warmup_steps=args.warmup_steps, + num_samples=args.num_samples, + num_chains=args.num_chains, + mp_context="spawn" if parallel else None, + max_tree_depth=args.max_tree_depth, + arrowhead_mass=args.arrowhead_mass, + num_quant_bins=args.num_bins, + haar=args.haar, + haar_full_mass=args.haar_full_mass, + jit_compile=args.jit) + + result = summary(mcmc._samples) + for key, value in result.items(): + if isinstance(value, torch.Tensor): + result[key] = value.reshape(-1).median().values.item() + return result + + +def infer_svi(args, model): + losses = model.fit_svi(heuristic_num_particles=args.smc_particles, + num_samples=args.num_samples, + num_steps=args.svi_steps, + num_particles=args.svi_particles, + haar=args.haar, + jit=args.jit) + + return {"loss_initial": losses[0], "loss_final": losses[-1]} + + +def evaluate(args, truth, model, samples): + metrics = [("mae", eval_mae), ("rmse", eval_rmse), ("crps", eval_crps)] + result = {} + for key, pred in samples.items(): + if key == "obs": + pred = pred[..., args.duration:] + + result[key] = {} + result[key]["mean"] = pred.mean().item() + result[key]["std"] = pred.std(dim=0).mean().item() + + if key in truth: + true = truth[key] + if key == "obs": + true = true[..., args.duration:] + for metric, fn in metrics: + result[key][metric] = fn(pred, true) + + return result + + +def main(args): + pyro.enable_validation(__debug__) + pyro.set_rng_seed(args.rng_seed + 20200617) + + result = {} + + truth = generate_data(args) + + t0 = default_timer() + + model = Model(args, data=truth["obs"][:args.duration]) + infer = {"mcmc": infer_mcmc, "svi": infer_svi}[args.infer] + result["infer"] = infer(args, model) + + t1 = default_timer() + + samples = model.predict(forecast=args.forecast) + + t2 = default_timer() + + result["evaluate"] = evaluate(args, truth, model, samples) + result["times"] = {"infer": t1 - t0, "predict": t2 - t1} + result["rusage"] = resource.getrusage(resource.RUSAGE_SELF) + + if args.output: + with open(args.output, "wb") as f: + pickle.dump(result, f) + return result + + +if __name__ == "__main__": + assert pyro.__version__.startswith('1.3.1') + parser = argparse.ArgumentParser(description="CompartmentalModel experiments") + parser.add_argument("--population", default=1000, type=float) + parser.add_argument("--min-obs-portion", default=0.1, type=float) + parser.add_argument("--max-obs-portion", default=0.3, type=float) + parser.add_argument("--duration", default=20, type=int) + parser.add_argument("--forecast", default=10, type=int) + parser.add_argument("--R0", default=1.5, type=float) + parser.add_argument("--recovery-time", default=7.0, type=float) + parser.add_argument("--incubation-time", default=0.0, type=float) + parser.add_argument("--concentration", default=math.inf, type=float) + parser.add_argument("--response-rate", default=0.5, type=float) + parser.add_argument("--overdispersion", default=0., type=float) + parser.add_argument("--heterogeneous", action="store_true") + parser.add_argument("--infer", default="mcmc") + parser.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer") + parser.add_argument("--svi", action="store_const", const="svi", dest="infer") + parser.add_argument("--haar", action="store_true") + parser.add_argument("--nohaar", action="store_const", const=False, dest="haar") + parser.add_argument("--haar-full-mass", default=10, type=int) + parser.add_argument("--num-samples", default=200, type=int) + parser.add_argument("--smc-particles", default=1024, type=int) + parser.add_argument("--svi-steps", default=5000, type=int) + parser.add_argument("--svi-particles", default=32, type=int) + parser.add_argument("--warmup-steps", type=int) + parser.add_argument("--num-chains", default=2, type=int) + parser.add_argument("--max-tree-depth", default=5, type=int) + parser.add_argument("--arrowhead-mass", action="store_true") + parser.add_argument("--rng-seed", default=0, type=int) + parser.add_argument("--num-bins", default=1, type=int) + parser.add_argument("--double", action="store_true", default=True) + parser.add_argument("--single", action="store_false", dest="double") + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--jit", action="store_true", default=True) + parser.add_argument("--nojit", action="store_false", dest="jit") + parser.add_argument("--verbose", action="store_true") + parser.add_argument("--output") + args = parser.parse_args() + args.population = int(args.population) # to allow e.g. --population=1e6 + + if args.warmup_steps is None: + args.warmup_steps = args.num_samples + if args.double: + if args.cuda: + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + else: + torch.set_default_dtype(torch.float64) + elif args.cuda: + torch.set_default_tensor_type(torch.cuda.FloatTensor) + + main(args) From 4c43dd7b8cdbc737ab90df7a1d0e6ba0d222b497 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Jun 2020 20:56:33 -0700 Subject: [PATCH 02/26] Enlarge experiment grid --- 2020-06-compartmental/Makefile | 4 +- 2020-06-compartmental/runner.py | 21 +++++-- 2020-06-compartmental/short_uni_synth.csv | 68 +++++++++++++++++++++-- 2020-06-compartmental/uni_synth.py | 6 +- 4 files changed, 84 insertions(+), 15 deletions(-) diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index 206be39..eaa8907 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -10,9 +10,11 @@ watch: FORCE short_uni_synth: FORCE python runner.py \ + --outfile=results/short_uni_synth \ --script-filename=uni_synth.py \ --args-filename=short_uni_synth.csv \ - --num-workers=1 + --cores-per-worker=4 \ + --shuffle clean: FORCE rm -rf temp logs errors diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index 83de8d4..ad9c3a5 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -28,14 +28,15 @@ def work(task): args, spec = task - basename = "_".join("{}={}".format(k, v) for k, v in spec.items()) + basename = (args.script_filename + "." + + "_".join("{}={}".format(k, v) for k, v in spec.items())) result_file = os.path.join(RESULTS, basename + ".pkl") - if os.path.exists(result_file): + if os.path.exists(result_file) and not args.force: return True temp_file = os.path.join(TEMP, basename + ".pkl") log_file = os.path.join(LOGS, basename + ".txt") - spec["output"] = temp_file + spec["outfile"] = temp_file command = ([sys.executable, args.script_filename] + ["--{}={}".format(k, v) for k, v in spec.items()]) print(" ".join(command)) @@ -61,7 +62,7 @@ def main(args): header = next(reader) tasks = [] for row in reader: - command_args = OrderedDict((k, v) for k, v in zip(header, row) if v) + command_args = OrderedDict(sorted((k, v) for k, v in zip(header, row) if v)) tasks.append((args, command_args)) if args.shuffle: random.shuffle(tasks) @@ -69,10 +70,17 @@ def main(args): if args.num_workers == 1: map_ = map else: + print("Running {} tasks on {} workers".format(len(tasks), args.num_workers)) map_ = multiprocessing.Pool(args.num_workers).map results = map_(work, tasks) assert all(results) + results.sort() + if args.outfile: + with open(args.outfile, "w") as f: + f.write("\n".join(results)) + return results + if __name__ == "__main__": parser = argparse.ArgumentParser(description="experiment runner") @@ -83,11 +91,12 @@ def main(args): parser.add_argument("--shuffle", action="store_true") parser.add_argument("-f", "--force", action="store_true") parser.add_argument("--dry-run", action="store_true") + parser.add_argument("--outfile") args = parser.parse_args() if args.cores_per_worker: - args.workers = max(1, CPUS // args.cores_per_worker) + args.num_workers = max(1, CPUS // args.cores_per_worker) if args.dry_run: - args.workers = 1 + args.num_workers = 1 main(args) diff --git a/2020-06-compartmental/short_uni_synth.csv b/2020-06-compartmental/short_uni_synth.csv index 733500f..a9f3b23 100644 --- a/2020-06-compartmental/short_uni_synth.csv +++ b/2020-06-compartmental/short_uni_synth.csv @@ -1,5 +1,63 @@ -population,duration,forecast,R0,incubation-time,recovery-time,rng-seed,infer,num-samples,num-bins,warmup-steps,svi-steps -1000,20,10,3,2,4,0,svi,1000,,,5000 -1000,20,10,3,2,4,0,mcmc,1000,1,200 -1000,20,10,3,2,4,0,mcmc,1000,2,200 -1000,20,10,3,2,4,0,mcmc,1000,4,200 +population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,max-tree-depth,svi-steps,rng-seed +1000,20,10,3,2,4,mcmc,1000,1,400,5,,0 +1000,20,10,3,2,4,mcmc,1000,1,400,5,,1 +1000,20,10,3,2,4,mcmc,1000,1,400,6,,0 +1000,20,10,3,2,4,mcmc,1000,1,400,6,,1 +1000,20,10,3,2,4,mcmc,1000,1,400,7,,0 +1000,20,10,3,2,4,mcmc,1000,1,400,7,,1 +1000,20,10,3,2,4,mcmc,1000,2,400,5,,0 +1000,20,10,3,2,4,mcmc,1000,2,400,5,,1 +1000,20,10,3,2,4,mcmc,1000,2,400,6,,0 +1000,20,10,3,2,4,mcmc,1000,2,400,6,,1 +1000,20,10,3,2,4,mcmc,1000,2,400,7,,0 +1000,20,10,3,2,4,mcmc,1000,2,400,7,,1 +1000,20,10,3,2,4,mcmc,1000,4,400,5,,0 +1000,20,10,3,2,4,mcmc,1000,4,400,5,,1 +1000,20,10,3,2,4,mcmc,1000,4,400,6,,0 +1000,20,10,3,2,4,mcmc,1000,4,400,6,,1 +1000,20,10,3,2,4,mcmc,1000,4,400,7,,0 +1000,20,10,3,2,4,mcmc,1000,4,400,7,,1 +1000,20,10,3,2,4,mcmc,2000,1,800,5,,0 +1000,20,10,3,2,4,mcmc,2000,1,800,5,,1 +1000,20,10,3,2,4,mcmc,2000,1,800,6,,0 +1000,20,10,3,2,4,mcmc,2000,1,800,6,,1 +1000,20,10,3,2,4,mcmc,2000,1,800,7,,0 +1000,20,10,3,2,4,mcmc,2000,1,800,7,,1 +1000,20,10,3,2,4,mcmc,2000,2,800,5,,0 +1000,20,10,3,2,4,mcmc,2000,2,800,5,,1 +1000,20,10,3,2,4,mcmc,2000,2,800,6,,0 +1000,20,10,3,2,4,mcmc,2000,2,800,6,,1 +1000,20,10,3,2,4,mcmc,2000,2,800,7,,0 +1000,20,10,3,2,4,mcmc,2000,2,800,7,,1 +1000,20,10,3,2,4,mcmc,2000,4,800,5,,0 +1000,20,10,3,2,4,mcmc,2000,4,800,5,,1 +1000,20,10,3,2,4,mcmc,2000,4,800,6,,0 +1000,20,10,3,2,4,mcmc,2000,4,800,6,,1 +1000,20,10,3,2,4,mcmc,2000,4,800,7,,0 +1000,20,10,3,2,4,mcmc,2000,4,800,7,,1 +1000,20,10,3,2,4,mcmc,500,1,200,5,,0 +1000,20,10,3,2,4,mcmc,500,1,200,5,,1 +1000,20,10,3,2,4,mcmc,500,1,200,6,,0 +1000,20,10,3,2,4,mcmc,500,1,200,6,,1 +1000,20,10,3,2,4,mcmc,500,1,200,7,,0 +1000,20,10,3,2,4,mcmc,500,1,200,7,,1 +1000,20,10,3,2,4,mcmc,500,2,200,5,,0 +1000,20,10,3,2,4,mcmc,500,2,200,5,,1 +1000,20,10,3,2,4,mcmc,500,2,200,6,,0 +1000,20,10,3,2,4,mcmc,500,2,200,6,,1 +1000,20,10,3,2,4,mcmc,500,2,200,7,,0 +1000,20,10,3,2,4,mcmc,500,2,200,7,,1 +1000,20,10,3,2,4,mcmc,500,4,200,5,,0 +1000,20,10,3,2,4,mcmc,500,4,200,5,,1 +1000,20,10,3,2,4,mcmc,500,4,200,6,,0 +1000,20,10,3,2,4,mcmc,500,4,200,6,,1 +1000,20,10,3,2,4,mcmc,500,4,200,7,,0 +1000,20,10,3,2,4,mcmc,500,4,200,7,,1 +1000,20,10,3,2,4,svi,1000,,,,1000,0 +1000,20,10,3,2,4,svi,1000,,,,1000,1 +1000,20,10,3,2,4,svi,1000,,,,10000,0 +1000,20,10,3,2,4,svi,1000,,,,10000,1 +1000,20,10,3,2,4,svi,1000,,,,2000,0 +1000,20,10,3,2,4,svi,1000,,,,2000,1 +1000,20,10,3,2,4,svi,1000,,,,5000,0 +1000,20,10,3,2,4,svi,1000,,,,5000,1 diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py index 90e53eb..266cf43 100644 --- a/2020-06-compartmental/uni_synth.py +++ b/2020-06-compartmental/uni_synth.py @@ -161,8 +161,8 @@ def main(args): result["times"] = {"infer": t1 - t0, "predict": t2 - t1} result["rusage"] = resource.getrusage(resource.RUSAGE_SELF) - if args.output: - with open(args.output, "wb") as f: + if args.outfile: + with open(args.outfile, "wb") as f: pickle.dump(result, f) return result @@ -204,7 +204,7 @@ def main(args): parser.add_argument("--jit", action="store_true", default=True) parser.add_argument("--nojit", action="store_false", dest="jit") parser.add_argument("--verbose", action="store_true") - parser.add_argument("--output") + parser.add_argument("--outfile") args = parser.parse_args() args.population = int(args.population) # to allow e.g. --population=1e6 From 72e845f45f3b36ddd1889fe476e218e31bbbff3e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 07:56:01 -0700 Subject: [PATCH 03/26] Simplify rng seed --- 2020-06-compartmental/Makefile | 4 +- 2020-06-compartmental/runner.py | 20 ++--- 2020-06-compartmental/short_uni_synth.csv | 95 ++++++++--------------- 3 files changed, 46 insertions(+), 73 deletions(-) diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index eaa8907..f4bda17 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -14,7 +14,9 @@ short_uni_synth: FORCE --script-filename=uni_synth.py \ --args-filename=short_uni_synth.csv \ --cores-per-worker=4 \ - --shuffle + --rng-seed=0,1 \ + --shuffle \ + --dry-run clean: FORCE rm -rf temp logs errors diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index ad9c3a5..a81f7b4 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -8,7 +8,6 @@ import random import subprocess import sys -from collections import OrderedDict CPUS = multiprocessing.cpu_count() ROOT = os.path.dirname(os.path.abspath(__file__)) @@ -62,8 +61,10 @@ def main(args): header = next(reader) tasks = [] for row in reader: - command_args = OrderedDict(sorted((k, v) for k, v in zip(header, row) if v)) - tasks.append((args, command_args)) + spec = {k: v for k, v in zip(header, row) if v} + for seed in args.rng_seed.split(","): + spec["rng-seed"] = seed + tasks.append((args, spec.copy())) if args.shuffle: random.shuffle(tasks) @@ -72,7 +73,7 @@ def main(args): else: print("Running {} tasks on {} workers".format(len(tasks), args.num_workers)) map_ = multiprocessing.Pool(args.num_workers).map - results = map_(work, tasks) + results = list(map_(work, tasks)) assert all(results) results.sort() @@ -84,12 +85,13 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="experiment runner") - parser.add_argument("-s", "--script-filename") - parser.add_argument("-a", "--args-filename") - parser.add_argument("-w", "--num-workers", type=int, default=CPUS) - parser.add_argument("-cpw", "--cores-per-worker", type=int) + parser.add_argument("--script-filename") + parser.add_argument("--args-filename") + parser.add_argument("--rng-seed", default="0") + parser.add_argument("--num-workers", type=int, default=CPUS) + parser.add_argument("--cores-per-worker", type=int) parser.add_argument("--shuffle", action="store_true") - parser.add_argument("-f", "--force", action="store_true") + parser.add_argument("--force", action="store_true") parser.add_argument("--dry-run", action="store_true") parser.add_argument("--outfile") args = parser.parse_args() diff --git a/2020-06-compartmental/short_uni_synth.csv b/2020-06-compartmental/short_uni_synth.csv index a9f3b23..3ab476c 100644 --- a/2020-06-compartmental/short_uni_synth.csv +++ b/2020-06-compartmental/short_uni_synth.csv @@ -1,63 +1,32 @@ -population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,max-tree-depth,svi-steps,rng-seed -1000,20,10,3,2,4,mcmc,1000,1,400,5,,0 -1000,20,10,3,2,4,mcmc,1000,1,400,5,,1 -1000,20,10,3,2,4,mcmc,1000,1,400,6,,0 -1000,20,10,3,2,4,mcmc,1000,1,400,6,,1 -1000,20,10,3,2,4,mcmc,1000,1,400,7,,0 -1000,20,10,3,2,4,mcmc,1000,1,400,7,,1 -1000,20,10,3,2,4,mcmc,1000,2,400,5,,0 -1000,20,10,3,2,4,mcmc,1000,2,400,5,,1 -1000,20,10,3,2,4,mcmc,1000,2,400,6,,0 -1000,20,10,3,2,4,mcmc,1000,2,400,6,,1 -1000,20,10,3,2,4,mcmc,1000,2,400,7,,0 -1000,20,10,3,2,4,mcmc,1000,2,400,7,,1 -1000,20,10,3,2,4,mcmc,1000,4,400,5,,0 -1000,20,10,3,2,4,mcmc,1000,4,400,5,,1 -1000,20,10,3,2,4,mcmc,1000,4,400,6,,0 -1000,20,10,3,2,4,mcmc,1000,4,400,6,,1 -1000,20,10,3,2,4,mcmc,1000,4,400,7,,0 -1000,20,10,3,2,4,mcmc,1000,4,400,7,,1 -1000,20,10,3,2,4,mcmc,2000,1,800,5,,0 -1000,20,10,3,2,4,mcmc,2000,1,800,5,,1 -1000,20,10,3,2,4,mcmc,2000,1,800,6,,0 -1000,20,10,3,2,4,mcmc,2000,1,800,6,,1 -1000,20,10,3,2,4,mcmc,2000,1,800,7,,0 -1000,20,10,3,2,4,mcmc,2000,1,800,7,,1 -1000,20,10,3,2,4,mcmc,2000,2,800,5,,0 -1000,20,10,3,2,4,mcmc,2000,2,800,5,,1 -1000,20,10,3,2,4,mcmc,2000,2,800,6,,0 -1000,20,10,3,2,4,mcmc,2000,2,800,6,,1 -1000,20,10,3,2,4,mcmc,2000,2,800,7,,0 -1000,20,10,3,2,4,mcmc,2000,2,800,7,,1 -1000,20,10,3,2,4,mcmc,2000,4,800,5,,0 -1000,20,10,3,2,4,mcmc,2000,4,800,5,,1 -1000,20,10,3,2,4,mcmc,2000,4,800,6,,0 -1000,20,10,3,2,4,mcmc,2000,4,800,6,,1 -1000,20,10,3,2,4,mcmc,2000,4,800,7,,0 -1000,20,10,3,2,4,mcmc,2000,4,800,7,,1 -1000,20,10,3,2,4,mcmc,500,1,200,5,,0 -1000,20,10,3,2,4,mcmc,500,1,200,5,,1 -1000,20,10,3,2,4,mcmc,500,1,200,6,,0 -1000,20,10,3,2,4,mcmc,500,1,200,6,,1 -1000,20,10,3,2,4,mcmc,500,1,200,7,,0 -1000,20,10,3,2,4,mcmc,500,1,200,7,,1 -1000,20,10,3,2,4,mcmc,500,2,200,5,,0 -1000,20,10,3,2,4,mcmc,500,2,200,5,,1 -1000,20,10,3,2,4,mcmc,500,2,200,6,,0 -1000,20,10,3,2,4,mcmc,500,2,200,6,,1 -1000,20,10,3,2,4,mcmc,500,2,200,7,,0 -1000,20,10,3,2,4,mcmc,500,2,200,7,,1 -1000,20,10,3,2,4,mcmc,500,4,200,5,,0 -1000,20,10,3,2,4,mcmc,500,4,200,5,,1 -1000,20,10,3,2,4,mcmc,500,4,200,6,,0 -1000,20,10,3,2,4,mcmc,500,4,200,6,,1 -1000,20,10,3,2,4,mcmc,500,4,200,7,,0 -1000,20,10,3,2,4,mcmc,500,4,200,7,,1 -1000,20,10,3,2,4,svi,1000,,,,1000,0 -1000,20,10,3,2,4,svi,1000,,,,1000,1 -1000,20,10,3,2,4,svi,1000,,,,10000,0 -1000,20,10,3,2,4,svi,1000,,,,10000,1 -1000,20,10,3,2,4,svi,1000,,,,2000,0 -1000,20,10,3,2,4,svi,1000,,,,2000,1 -1000,20,10,3,2,4,svi,1000,,,,5000,0 -1000,20,10,3,2,4,svi,1000,,,,5000,1 +population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,max-tree-depth,svi-steps +1000,20,10,3,2,4,mcmc,1000,1,400,5 +1000,20,10,3,2,4,mcmc,1000,1,400,6 +1000,20,10,3,2,4,mcmc,1000,1,400,7 +1000,20,10,3,2,4,mcmc,1000,2,400,5 +1000,20,10,3,2,4,mcmc,1000,2,400,6 +1000,20,10,3,2,4,mcmc,1000,2,400,7 +1000,20,10,3,2,4,mcmc,1000,4,400,5 +1000,20,10,3,2,4,mcmc,1000,4,400,6 +1000,20,10,3,2,4,mcmc,1000,4,400,7 +1000,20,10,3,2,4,mcmc,2000,1,800,5 +1000,20,10,3,2,4,mcmc,2000,1,800,6 +1000,20,10,3,2,4,mcmc,2000,1,800,7 +1000,20,10,3,2,4,mcmc,2000,2,800,5 +1000,20,10,3,2,4,mcmc,2000,2,800,6 +1000,20,10,3,2,4,mcmc,2000,2,800,7 +1000,20,10,3,2,4,mcmc,2000,4,800,5 +1000,20,10,3,2,4,mcmc,2000,4,800,6 +1000,20,10,3,2,4,mcmc,2000,4,800,7 +1000,20,10,3,2,4,mcmc,500,1,200,5 +1000,20,10,3,2,4,mcmc,500,1,200,6 +1000,20,10,3,2,4,mcmc,500,1,200,7 +1000,20,10,3,2,4,mcmc,500,2,200,5 +1000,20,10,3,2,4,mcmc,500,2,200,6 +1000,20,10,3,2,4,mcmc,500,2,200,7 +1000,20,10,3,2,4,mcmc,500,4,200,5 +1000,20,10,3,2,4,mcmc,500,4,200,6 +1000,20,10,3,2,4,mcmc,500,4,200,7 +1000,20,10,3,2,4,svi,1000,,,,1000 +1000,20,10,3,2,4,svi,1000,,,,10000 +1000,20,10,3,2,4,svi,1000,,,,2000 +1000,20,10,3,2,4,svi,1000,,,,5000 From 890432b51fa310a00a6b1fe7fdc3055a4c526be2 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 08:26:20 -0700 Subject: [PATCH 04/26] Fix bugs --- 2020-06-compartmental/Makefile | 3 +-- 2020-06-compartmental/runner.py | 4 ++++ 2020-06-compartmental/uni_synth.py | 13 ++++++++++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index f4bda17..200041f 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -15,8 +15,7 @@ short_uni_synth: FORCE --args-filename=short_uni_synth.csv \ --cores-per-worker=4 \ --rng-seed=0,1 \ - --shuffle \ - --dry-run + --shuffle clean: FORCE rm -rf temp logs errors diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index a81f7b4..56f97b7 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -80,6 +80,10 @@ def main(args): if args.outfile: with open(args.outfile, "w") as f: f.write("\n".join(results)) + + print("-------------------------") + print("COMPLETED {} TASKS".format(len(tasks))) + print("-------------------------") return results diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py index 266cf43..9cf9dac 100644 --- a/2020-06-compartmental/uni_synth.py +++ b/2020-06-compartmental/uni_synth.py @@ -83,6 +83,15 @@ def generate_data(args): .format(max_obs, args.max_obs_portion)) +def _item(x): + if isinstance(x, torch.Tensor): + x = x.reshape(-1).median().item() + elif isinstance(x, dict): + for key, value in x.items(): + x[key] = _item(value) + return x + + def infer_mcmc(args, model): parallel = args.num_chains > 1 @@ -99,9 +108,7 @@ def infer_mcmc(args, model): jit_compile=args.jit) result = summary(mcmc._samples) - for key, value in result.items(): - if isinstance(value, torch.Tensor): - result[key] = value.reshape(-1).median().values.item() + result = _item(result) return result From 588a20e83ee45de58b5b8e2e9f61416951fc6b91 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 11:00:04 -0700 Subject: [PATCH 05/26] Add more metadata, shrink grid --- 2020-06-compartmental/runner.py | 13 +++++++++---- 2020-06-compartmental/short_uni_synth.csv | 9 --------- 2020-06-compartmental/uni_synth.py | 21 +++++++++++++-------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index 56f97b7..09382a1 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -56,10 +56,10 @@ def work(task): def main(args): + tasks = [] with open(args.args_filename) as f: reader = csv.reader(f) header = next(reader) - tasks = [] for row in reader: spec = {k: v for k, v in zip(header, row) if v} for seed in args.rng_seed.split(","): @@ -67,6 +67,7 @@ def main(args): tasks.append((args, spec.copy())) if args.shuffle: random.shuffle(tasks) + num_tasks = len(tasks) if args.num_workers == 1: map_ = map @@ -74,7 +75,10 @@ def main(args): print("Running {} tasks on {} workers".format(len(tasks), args.num_workers)) map_ = multiprocessing.Pool(args.num_workers).map results = list(map_(work, tasks)) - assert all(results) + if args.skip: + tasks = [t for t in tasks if t] + else: + assert all(results) results.sort() if args.outfile: @@ -82,7 +86,7 @@ def main(args): f.write("\n".join(results)) print("-------------------------") - print("COMPLETED {} TASKS".format(len(tasks))) + print("COMPLETED {}/{} TASKS".format(len(tasks), num_tasks)) print("-------------------------") return results @@ -91,13 +95,14 @@ def main(args): parser = argparse.ArgumentParser(description="experiment runner") parser.add_argument("--script-filename") parser.add_argument("--args-filename") + parser.add_argument("--outfile") parser.add_argument("--rng-seed", default="0") parser.add_argument("--num-workers", type=int, default=CPUS) parser.add_argument("--cores-per-worker", type=int) parser.add_argument("--shuffle", action="store_true") parser.add_argument("--force", action="store_true") + parser.add_argument("--skip", action="store_true") parser.add_argument("--dry-run", action="store_true") - parser.add_argument("--outfile") args = parser.parse_args() if args.cores_per_worker: diff --git a/2020-06-compartmental/short_uni_synth.csv b/2020-06-compartmental/short_uni_synth.csv index 3ab476c..4fc17f0 100644 --- a/2020-06-compartmental/short_uni_synth.csv +++ b/2020-06-compartmental/short_uni_synth.csv @@ -1,22 +1,13 @@ population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,max-tree-depth,svi-steps 1000,20,10,3,2,4,mcmc,1000,1,400,5 1000,20,10,3,2,4,mcmc,1000,1,400,6 -1000,20,10,3,2,4,mcmc,1000,1,400,7 1000,20,10,3,2,4,mcmc,1000,2,400,5 1000,20,10,3,2,4,mcmc,1000,2,400,6 -1000,20,10,3,2,4,mcmc,1000,2,400,7 1000,20,10,3,2,4,mcmc,1000,4,400,5 1000,20,10,3,2,4,mcmc,1000,4,400,6 -1000,20,10,3,2,4,mcmc,1000,4,400,7 1000,20,10,3,2,4,mcmc,2000,1,800,5 -1000,20,10,3,2,4,mcmc,2000,1,800,6 -1000,20,10,3,2,4,mcmc,2000,1,800,7 1000,20,10,3,2,4,mcmc,2000,2,800,5 -1000,20,10,3,2,4,mcmc,2000,2,800,6 -1000,20,10,3,2,4,mcmc,2000,2,800,7 1000,20,10,3,2,4,mcmc,2000,4,800,5 -1000,20,10,3,2,4,mcmc,2000,4,800,6 -1000,20,10,3,2,4,mcmc,2000,4,800,7 1000,20,10,3,2,4,mcmc,500,1,200,5 1000,20,10,3,2,4,mcmc,500,1,200,6 1000,20,10,3,2,4,mcmc,500,1,200,7 diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py index 9cf9dac..abd63ac 100644 --- a/2020-06-compartmental/uni_synth.py +++ b/2020-06-compartmental/uni_synth.py @@ -4,19 +4,23 @@ import argparse import logging import math -import resource import pickle +import resource +import sys from timeit import default_timer import torch import pyro -from pyro.contrib.epidemiology.models import (HeterogeneousSIRModel, - OverdispersedSEIRModel, - OverdispersedSIRModel, - SimpleSEIRModel, SimpleSIRModel, - SuperspreadingSEIRModel, - SuperspreadingSIRModel) +from pyro.contrib.epidemiology.models import ( + HeterogeneousSIRModel, + OverdispersedSEIRModel, + OverdispersedSIRModel, + SimpleSEIRModel, + SimpleSIRModel, + SuperspreadingSEIRModel, + SuperspreadingSIRModel +) from pyro.contrib.forecast.evaluate import eval_crps, eval_mae, eval_rmse from pyro.infer.mcmc.util import summary @@ -148,7 +152,7 @@ def main(args): pyro.enable_validation(__debug__) pyro.set_rng_seed(args.rng_seed + 20200617) - result = {} + result = {"args": args, "file": __file__, "argv": sys.argv} truth = generate_data(args) @@ -171,6 +175,7 @@ def main(args): if args.outfile: with open(args.outfile, "wb") as f: pickle.dump(result, f) + logging.info("DONE") return result From 692ce917451f3eeeca29103cf81665bbf25adcd7 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 11:35:59 -0700 Subject: [PATCH 06/26] Allow --force to update metadata --- 2020-06-compartmental/runner.py | 2 +- 2020-06-compartmental/uni_synth.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index 09382a1..b88d3f4 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -31,7 +31,7 @@ def work(task): "_".join("{}={}".format(k, v) for k, v in spec.items())) result_file = os.path.join(RESULTS, basename + ".pkl") if os.path.exists(result_file) and not args.force: - return True + return result_file temp_file = os.path.join(TEMP, basename + ".pkl") log_file = os.path.join(LOGS, basename + ".txt") diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py index abd63ac..c352825 100644 --- a/2020-06-compartmental/uni_synth.py +++ b/2020-06-compartmental/uni_synth.py @@ -4,6 +4,7 @@ import argparse import logging import math +import os import pickle import resource import sys @@ -153,6 +154,14 @@ def main(args): pyro.set_rng_seed(args.rng_seed + 20200617) result = {"args": args, "file": __file__, "argv": sys.argv} + if args.outfile and os.path.exists(args.outfile): + # Simply update metadata. + with open(args.outfile, "rb") as f: + result.update(pickle.load(f)) + with open(args.outfile, "wb") as f: + pickle.dump(result, f) + logging.info("DONE") + return result truth = generate_data(args) From e95477cb24d4a730f0a05a623204c12775c4825e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 16:22:50 -0700 Subject: [PATCH 07/26] Split mcmc versus svi --- 2020-06-compartmental/Makefile | 12 ++++++-- 2020-06-compartmental/runner.py | 11 ++++--- 2020-06-compartmental/short_uni_synth.csv | 37 +++++++++-------------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index 200041f..ed611ea 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -10,11 +10,17 @@ watch: FORCE short_uni_synth: FORCE python runner.py \ - --outfile=results/short_uni_synth \ + --outfile=results/short_uni_synth_mcmc \ --script-filename=uni_synth.py \ --args-filename=short_uni_synth.csv \ - --cores-per-worker=4 \ - --rng-seed=0,1 \ + --rng-seed=0,1,2,3,4,5,6,7,8,9 \ + --shuffle + python runner.py \ + --outfile=results/short_uni_synth_svi \ + --script-filename=uni_synth.py \ + --args-filename=short_uni_synth.csv \ + --cores-per-worker=1 \ + --rng-seed=0,1,2 \ --shuffle clean: FORCE diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index b88d3f4..50ae70d 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -10,6 +10,7 @@ import sys CPUS = multiprocessing.cpu_count() +ENV = os.environ.copy() ROOT = os.path.dirname(os.path.abspath(__file__)) TEMP = os.path.join(ROOT, "temp") LOGS = os.path.join(ROOT, "logs") @@ -32,6 +33,8 @@ def work(task): result_file = os.path.join(RESULTS, basename + ".pkl") if os.path.exists(result_file) and not args.force: return result_file + elif args.skip: + return None temp_file = os.path.join(TEMP, basename + ".pkl") log_file = os.path.join(LOGS, basename + ".txt") @@ -43,7 +46,7 @@ def work(task): return result_file try: with open(log_file, "w") as f: - subprocess.check_call(command, stderr=f, stdout=f) + subprocess.check_call(command, stderr=f, stdout=f, env=ENV) os.rename(temp_file, result_file) # Use rename to make write atomic. return result_file except subprocess.CalledProcessError as e: @@ -67,7 +70,6 @@ def main(args): tasks.append((args, spec.copy())) if args.shuffle: random.shuffle(tasks) - num_tasks = len(tasks) if args.num_workers == 1: map_ = map @@ -76,7 +78,7 @@ def main(args): map_ = multiprocessing.Pool(args.num_workers).map results = list(map_(work, tasks)) if args.skip: - tasks = [t for t in tasks if t] + results = [r for r in results if r is not None] else: assert all(results) @@ -86,7 +88,7 @@ def main(args): f.write("\n".join(results)) print("-------------------------") - print("COMPLETED {}/{} TASKS".format(len(tasks), num_tasks)) + print("COMPLETED {}/{} TASKS".format(len(results), len(tasks))) print("-------------------------") return results @@ -107,6 +109,7 @@ def main(args): if args.cores_per_worker: args.num_workers = max(1, CPUS // args.cores_per_worker) + ENV["OMP_NUM_THREAD"] = min(CPUS, 2 * args.cores_per_worker) if args.dry_run: args.num_workers = 1 diff --git a/2020-06-compartmental/short_uni_synth.csv b/2020-06-compartmental/short_uni_synth.csv index 4fc17f0..c7e2152 100644 --- a/2020-06-compartmental/short_uni_synth.csv +++ b/2020-06-compartmental/short_uni_synth.csv @@ -1,23 +1,14 @@ -population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,max-tree-depth,svi-steps -1000,20,10,3,2,4,mcmc,1000,1,400,5 -1000,20,10,3,2,4,mcmc,1000,1,400,6 -1000,20,10,3,2,4,mcmc,1000,2,400,5 -1000,20,10,3,2,4,mcmc,1000,2,400,6 -1000,20,10,3,2,4,mcmc,1000,4,400,5 -1000,20,10,3,2,4,mcmc,1000,4,400,6 -1000,20,10,3,2,4,mcmc,2000,1,800,5 -1000,20,10,3,2,4,mcmc,2000,2,800,5 -1000,20,10,3,2,4,mcmc,2000,4,800,5 -1000,20,10,3,2,4,mcmc,500,1,200,5 -1000,20,10,3,2,4,mcmc,500,1,200,6 -1000,20,10,3,2,4,mcmc,500,1,200,7 -1000,20,10,3,2,4,mcmc,500,2,200,5 -1000,20,10,3,2,4,mcmc,500,2,200,6 -1000,20,10,3,2,4,mcmc,500,2,200,7 -1000,20,10,3,2,4,mcmc,500,4,200,5 -1000,20,10,3,2,4,mcmc,500,4,200,6 -1000,20,10,3,2,4,mcmc,500,4,200,7 -1000,20,10,3,2,4,svi,1000,,,,1000 -1000,20,10,3,2,4,svi,1000,,,,10000 -1000,20,10,3,2,4,svi,1000,,,,2000 -1000,20,10,3,2,4,svi,1000,,,,5000 +population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,svi-steps +1000,20,10,3,2,4,mcmc,1000,1,400 +1000,20,10,3,2,4,mcmc,1000,2,400 +1000,20,10,3,2,4,mcmc,1000,4,400 +1000,20,10,3,2,4,mcmc,2000,1,800 +1000,20,10,3,2,4,mcmc,2000,2,800 +1000,20,10,3,2,4,mcmc,2000,4,800 +1000,20,10,3,2,4,mcmc,500,1,200 +1000,20,10,3,2,4,mcmc,500,2,200 +1000,20,10,3,2,4,mcmc,500,4,200 +1000,20,10,3,2,4,svi,1000,,,1000 +1000,20,10,3,2,4,svi,1000,,,10000 +1000,20,10,3,2,4,svi,1000,,,2000 +1000,20,10,3,2,4,svi,1000,,,5000 From f8fc452149ac45b59de5224b58409becabadfda6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 16:23:13 -0700 Subject: [PATCH 08/26] Fix typo --- 2020-06-compartmental/Makefile | 4 ++-- 2020-06-compartmental/short_uni_synth_mcmc.csv | 10 ++++++++++ 2020-06-compartmental/short_uni_synth_svi.csv | 5 +++++ 3 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 2020-06-compartmental/short_uni_synth_mcmc.csv create mode 100644 2020-06-compartmental/short_uni_synth_svi.csv diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index ed611ea..a15243d 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -10,13 +10,13 @@ watch: FORCE short_uni_synth: FORCE python runner.py \ - --outfile=results/short_uni_synth_mcmc \ + --outfile=results/short_uni_synth_svi \ --script-filename=uni_synth.py \ --args-filename=short_uni_synth.csv \ --rng-seed=0,1,2,3,4,5,6,7,8,9 \ --shuffle python runner.py \ - --outfile=results/short_uni_synth_svi \ + --outfile=results/short_uni_synth_mcmc \ --script-filename=uni_synth.py \ --args-filename=short_uni_synth.csv \ --cores-per-worker=1 \ diff --git a/2020-06-compartmental/short_uni_synth_mcmc.csv b/2020-06-compartmental/short_uni_synth_mcmc.csv new file mode 100644 index 0000000..6bdb7a6 --- /dev/null +++ b/2020-06-compartmental/short_uni_synth_mcmc.csv @@ -0,0 +1,10 @@ +population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,svi-steps +1000,20,10,3,2,4,mcmc,1000,1,400 +1000,20,10,3,2,4,mcmc,1000,2,400 +1000,20,10,3,2,4,mcmc,1000,4,400 +1000,20,10,3,2,4,mcmc,2000,1,800 +1000,20,10,3,2,4,mcmc,2000,2,800 +1000,20,10,3,2,4,mcmc,2000,4,800 +1000,20,10,3,2,4,mcmc,500,1,200 +1000,20,10,3,2,4,mcmc,500,2,200 +1000,20,10,3,2,4,mcmc,500,4,200 diff --git a/2020-06-compartmental/short_uni_synth_svi.csv b/2020-06-compartmental/short_uni_synth_svi.csv new file mode 100644 index 0000000..8a8d75e --- /dev/null +++ b/2020-06-compartmental/short_uni_synth_svi.csv @@ -0,0 +1,5 @@ +population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,svi-steps +1000,20,10,3,2,4,svi,1000,,,1000 +1000,20,10,3,2,4,svi,1000,,,10000 +1000,20,10,3,2,4,svi,1000,,,2000 +1000,20,10,3,2,4,svi,1000,,,5000 From 53139d5a87049a8a918109af9e477f92c824ebba Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 16:24:28 -0700 Subject: [PATCH 09/26] Fix typo --- 2020-06-compartmental/Makefile | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index a15243d..83cc0d5 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -10,15 +10,13 @@ watch: FORCE short_uni_synth: FORCE python runner.py \ - --outfile=results/short_uni_synth_svi \ --script-filename=uni_synth.py \ - --args-filename=short_uni_synth.csv \ + --args-filename=short_uni_synth_svi.csv \ --rng-seed=0,1,2,3,4,5,6,7,8,9 \ --shuffle python runner.py \ - --outfile=results/short_uni_synth_mcmc \ --script-filename=uni_synth.py \ - --args-filename=short_uni_synth.csv \ + --args-filename=short_uni_synth_mcmc.csv \ --cores-per-worker=1 \ --rng-seed=0,1,2 \ --shuffle From 9c538e807e37be3892dd5a555d4da4f7d241ebb6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 16:27:05 -0700 Subject: [PATCH 10/26] Reduce cost --- 2020-06-compartmental/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index 83cc0d5..d068b64 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -12,7 +12,7 @@ short_uni_synth: FORCE python runner.py \ --script-filename=uni_synth.py \ --args-filename=short_uni_synth_svi.csv \ - --rng-seed=0,1,2,3,4,5,6,7,8,9 \ + --rng-seed=0,1,2,3,4 \ --shuffle python runner.py \ --script-filename=uni_synth.py \ From 3d3f0aa279c871d781bd30422f26d3683fbfdb09 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 17:33:26 -0700 Subject: [PATCH 11/26] Simplify --- 2020-06-compartmental/short_uni_synth.csv | 14 -------------- 2020-06-compartmental/short_uni_synth_mcmc.csv | 10 ---------- 2020-06-compartmental/short_uni_synth_svi.csv | 5 ----- 3 files changed, 29 deletions(-) delete mode 100644 2020-06-compartmental/short_uni_synth.csv delete mode 100644 2020-06-compartmental/short_uni_synth_mcmc.csv delete mode 100644 2020-06-compartmental/short_uni_synth_svi.csv diff --git a/2020-06-compartmental/short_uni_synth.csv b/2020-06-compartmental/short_uni_synth.csv deleted file mode 100644 index c7e2152..0000000 --- a/2020-06-compartmental/short_uni_synth.csv +++ /dev/null @@ -1,14 +0,0 @@ -population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,svi-steps -1000,20,10,3,2,4,mcmc,1000,1,400 -1000,20,10,3,2,4,mcmc,1000,2,400 -1000,20,10,3,2,4,mcmc,1000,4,400 -1000,20,10,3,2,4,mcmc,2000,1,800 -1000,20,10,3,2,4,mcmc,2000,2,800 -1000,20,10,3,2,4,mcmc,2000,4,800 -1000,20,10,3,2,4,mcmc,500,1,200 -1000,20,10,3,2,4,mcmc,500,2,200 -1000,20,10,3,2,4,mcmc,500,4,200 -1000,20,10,3,2,4,svi,1000,,,1000 -1000,20,10,3,2,4,svi,1000,,,10000 -1000,20,10,3,2,4,svi,1000,,,2000 -1000,20,10,3,2,4,svi,1000,,,5000 diff --git a/2020-06-compartmental/short_uni_synth_mcmc.csv b/2020-06-compartmental/short_uni_synth_mcmc.csv deleted file mode 100644 index 6bdb7a6..0000000 --- a/2020-06-compartmental/short_uni_synth_mcmc.csv +++ /dev/null @@ -1,10 +0,0 @@ -population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,svi-steps -1000,20,10,3,2,4,mcmc,1000,1,400 -1000,20,10,3,2,4,mcmc,1000,2,400 -1000,20,10,3,2,4,mcmc,1000,4,400 -1000,20,10,3,2,4,mcmc,2000,1,800 -1000,20,10,3,2,4,mcmc,2000,2,800 -1000,20,10,3,2,4,mcmc,2000,4,800 -1000,20,10,3,2,4,mcmc,500,1,200 -1000,20,10,3,2,4,mcmc,500,2,200 -1000,20,10,3,2,4,mcmc,500,4,200 diff --git a/2020-06-compartmental/short_uni_synth_svi.csv b/2020-06-compartmental/short_uni_synth_svi.csv deleted file mode 100644 index 8a8d75e..0000000 --- a/2020-06-compartmental/short_uni_synth_svi.csv +++ /dev/null @@ -1,5 +0,0 @@ -population,duration,forecast,R0,incubation-time,recovery-time,infer,num-samples,num-bins,warmup-steps,svi-steps -1000,20,10,3,2,4,svi,1000,,,1000 -1000,20,10,3,2,4,svi,1000,,,10000 -1000,20,10,3,2,4,svi,1000,,,2000 -1000,20,10,3,2,4,svi,1000,,,5000 From 8422b0a8109e4fd18c226da08b0d51a4edfde15f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 17:50:38 -0700 Subject: [PATCH 12/26] Add mcmc grid points --- 2020-06-compartmental/.gitignore | 3 - 2020-06-compartmental/Makefile | 17 +--- 2020-06-compartmental/runner.py | 122 +++++++---------------------- 2020-06-compartmental/uni_synth.py | 39 +++++---- 4 files changed, 53 insertions(+), 128 deletions(-) diff --git a/2020-06-compartmental/.gitignore b/2020-06-compartmental/.gitignore index 234a969..fbca225 100644 --- a/2020-06-compartmental/.gitignore +++ b/2020-06-compartmental/.gitignore @@ -1,4 +1 @@ -temp/ -logs/ -errors/ results/ diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index d068b64..bb98437 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -5,26 +5,13 @@ all: lint lint: FORCE flake8 -watch: FORCE - ls -t logs/* | head -n 1 | xargs tail -f - short_uni_synth: FORCE - python runner.py \ - --script-filename=uni_synth.py \ - --args-filename=short_uni_synth_svi.csv \ - --rng-seed=0,1,2,3,4 \ - --shuffle - python runner.py \ - --script-filename=uni_synth.py \ - --args-filename=short_uni_synth_mcmc.csv \ - --cores-per-worker=1 \ - --rng-seed=0,1,2 \ - --shuffle + python runner.py --experiment=short_uni_synth clean: FORCE rm -rf temp logs errors mrclean: FORCE - rm -rf temp logs errors results + rm -rf results FORCE: diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index 50ae70d..bb59e67 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -2,115 +2,49 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -import csv -import multiprocessing -import os -import random import subprocess import sys -CPUS = multiprocessing.cpu_count() -ENV = os.environ.copy() -ROOT = os.path.dirname(os.path.abspath(__file__)) -TEMP = os.path.join(ROOT, "temp") -LOGS = os.path.join(ROOT, "logs") -ERRORS = os.path.join(ROOT, "errors") -RESULTS = os.path.join(ROOT, "results") -# Ensure directories exist. -for path in [TEMP, LOGS, ERRORS, RESULTS]: - if not os.path.exists(path): - try: - os.makedirs(path) - except OSError: - assert os.path.exists(path) - - -def work(task): - args, spec = task - basename = (args.script_filename + "." + - "_".join("{}={}".format(k, v) for k, v in spec.items())) - result_file = os.path.join(RESULTS, basename + ".pkl") - if os.path.exists(result_file) and not args.force: - return result_file - elif args.skip: - return None - - temp_file = os.path.join(TEMP, basename + ".pkl") - log_file = os.path.join(LOGS, basename + ".txt") - spec["outfile"] = temp_file - command = ([sys.executable, args.script_filename] + - ["--{}={}".format(k, v) for k, v in spec.items()]) - print(" ".join(command)) - if args.dry_run: - return result_file - try: - with open(log_file, "w") as f: - subprocess.check_call(command, stderr=f, stdout=f, env=ENV) - os.rename(temp_file, result_file) # Use rename to make write atomic. - return result_file - except subprocess.CalledProcessError as e: - pdb_command = [sys.executable, "-m", "pdb", "-cc"] + command[1:-1] - msg = "{}\nTo reproduce, run:\n{}".format(e, " \\\n ".join(pdb_command)) - print(msg) - with open(os.path.join(ERRORS, basename + ".txt"), "w") as f: - f.write(msg) - return None +def short_uni_synth(): + base = [ + sys.executable, + "uni_synth.py", + "--population=1000", + "--duration=20", "--forecast=10", + "--R0=3", "--incubation-time=2", "--recovery-time=4", + ] + for svi_steps in [1000, 2000, 5000, 10000]: + for rng_seed in range(5): + yield base + ["--svi", + "--num-samples=1000", + f"--svi-steps={svi_steps}", + f"--rng-seed={rng_seed}"] + for num_bins in [1, 2, 4]: + for num_samples in [200, 500, 1000]: + for rng_seed in range(1): + yield base + ["--mcmc", + "--warmup-steps=200", + f"--num-samples={num_samples}", + f"--rng-seed={rng_seed}"] def main(args): - tasks = [] - with open(args.args_filename) as f: - reader = csv.reader(f) - header = next(reader) - for row in reader: - spec = {k: v for k, v in zip(header, row) if v} - for seed in args.rng_seed.split(","): - spec["rng-seed"] = seed - tasks.append((args, spec.copy())) - if args.shuffle: - random.shuffle(tasks) - - if args.num_workers == 1: - map_ = map - else: - print("Running {} tasks on {} workers".format(len(tasks), args.num_workers)) - map_ = multiprocessing.Pool(args.num_workers).map - results = list(map_(work, tasks)) - if args.skip: - results = [r for r in results if r is not None] - else: - assert all(results) - - results.sort() - if args.outfile: - with open(args.outfile, "w") as f: - f.write("\n".join(results)) + tasks = list(globals()[args.experiment]()) + for task in tasks: + print(" \\\n ".join(task)) + if not args.dry_run: + subprocess.check_call(task) print("-------------------------") - print("COMPLETED {}/{} TASKS".format(len(results), len(tasks))) + print("COMPLETED {} TASKS".format(len(tasks))) print("-------------------------") - return results if __name__ == "__main__": parser = argparse.ArgumentParser(description="experiment runner") - parser.add_argument("--script-filename") - parser.add_argument("--args-filename") - parser.add_argument("--outfile") - parser.add_argument("--rng-seed", default="0") - parser.add_argument("--num-workers", type=int, default=CPUS) - parser.add_argument("--cores-per-worker", type=int) - parser.add_argument("--shuffle", action="store_true") - parser.add_argument("--force", action="store_true") - parser.add_argument("--skip", action="store_true") + parser.add_argument("--experiment") parser.add_argument("--dry-run", action="store_true") args = parser.parse_args() - if args.cores_per_worker: - args.num_workers = max(1, CPUS // args.cores_per_worker) - ENV["OMP_NUM_THREAD"] = min(CPUS, 2 * args.cores_per_worker) - if args.dry_run: - args.num_workers = 1 - main(args) diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py index c352825..e485d38 100644 --- a/2020-06-compartmental/uni_synth.py +++ b/2020-06-compartmental/uni_synth.py @@ -8,6 +8,7 @@ import pickle import resource import sys +from hashlib import sha1 from timeit import default_timer import torch @@ -29,6 +30,16 @@ logging.getLogger("pyro").handlers[0].setFormatter(logging.Formatter(fmt)) logging.basicConfig(format=fmt, level=logging.INFO) +ROOT = os.path.dirname(os.path.abspath(__file__)) +RESULTS = os.path.join(ROOT, "results") + +# Ensure directories exist. +if not os.path.exists(RESULTS): + try: + os.makedirs(RESULTS) + except OSError: + assert os.path.exists(RESULTS) + def Model(args, data): """Dispatch between different model classes.""" @@ -153,15 +164,7 @@ def main(args): pyro.enable_validation(__debug__) pyro.set_rng_seed(args.rng_seed + 20200617) - result = {"args": args, "file": __file__, "argv": sys.argv} - if args.outfile and os.path.exists(args.outfile): - # Simply update metadata. - with open(args.outfile, "rb") as f: - result.update(pickle.load(f)) - with open(args.outfile, "wb") as f: - pickle.dump(result, f) - logging.info("DONE") - return result + result = {"file": __file__, "args": args, "argv": sys.argv} truth = generate_data(args) @@ -180,10 +183,6 @@ def main(args): result["evaluate"] = evaluate(args, truth, model, samples) result["times"] = {"infer": t1 - t0, "predict": t2 - t1} result["rusage"] = resource.getrusage(resource.RUSAGE_SELF) - - if args.outfile: - with open(args.outfile, "wb") as f: - pickle.dump(result, f) logging.info("DONE") return result @@ -225,10 +224,10 @@ def main(args): parser.add_argument("--jit", action="store_true", default=True) parser.add_argument("--nojit", action="store_false", dest="jit") parser.add_argument("--verbose", action="store_true") - parser.add_argument("--outfile") + + # Parse args. args = parser.parse_args() args.population = int(args.population) # to allow e.g. --population=1e6 - if args.warmup_steps is None: args.warmup_steps = args.num_samples if args.double: @@ -239,4 +238,12 @@ def main(args): elif args.cuda: torch.set_default_tensor_type(torch.cuda.FloatTensor) - main(args) + # Cache output. + unique = __file__, sorted(args.__dict__.items()) + fingerprint = sha1(str(unique).encode()).hexdigest() + outfile = os.path.join(RESULTS, fingerprint + ".pkl") + if not os.path.exists(outfile): + result = main(args) + with open(outfile, "wb") as f: + pickle.dump(result, f) + logging.info("Saved {}".format(outfile)) From 361ab0eaa38cc7483a57d86ed6f607e63328d15f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 18 Jun 2020 21:34:08 -0700 Subject: [PATCH 13/26] Add analysis notebook and long_uni_synth experiment --- 2020-06-compartmental/analyze.ipynb | 210 ++++++++++++++++++++++++++++ 2020-06-compartmental/runner.py | 41 +++++- 2 files changed, 249 insertions(+), 2 deletions(-) create mode 100644 2020-06-compartmental/analyze.ipynb diff --git a/2020-06-compartmental/analyze.ipynb b/2020-06-compartmental/analyze.ipynb new file mode 100644 index 0000000..6c74aaa --- /dev/null +++ b/2020-06-compartmental/analyze.ipynb @@ -0,0 +1,210 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import glob\n", + "import json\n", + "import pickle\n", + "from collections import defaultdict\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results = []\n", + "for filename in glob.glob(\"results/*.pkl\"):\n", + " with open(filename, \"rb\") as f:\n", + " results.append(pickle.load(f))\n", + "print(len(results))\n", + "print(results[0].keys())\n", + "print(results[0][\"times\"].keys())\n", + "print(results[0][\"evaluate\"].keys())\n", + "print(results[0][\"evaluate\"][\"R0\"].keys())\n", + "print(results[0][\"infer\"].keys())\n", + "print(results[0][\"infer\"][\"R0\"].keys())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_accuracy(variable, metric):\n", + " view = defaultdict(list)\n", + " for result in results:\n", + " args = result['args']\n", + " view[args.infer, args.num_bins, args.svi_steps].append(result)\n", + " markers = [\"o\", \"d\", \"s\", \"<\", \"v\", \"^\", \">\"]\n", + " assert len(view) == len(markers)\n", + "\n", + " plt.figure(figsize=(6, 5)).patch.set_color(\"white\")\n", + " for (key, value), marker in zip(sorted(view.items()), markers):\n", + " algo, num_bins, svi_steps = key\n", + " if algo == \"svi\":\n", + " label = \"SVI steps={}\".format(svi_steps)\n", + " elif algo == \"mcmc\":\n", + " if num_bins == 1:\n", + " label = \"MCMC relaxed\"\n", + " else:\n", + " label = \"MCMC num_bins={}\".format(num_bins)\n", + " X = [v[\"times\"][\"infer\"] for v in value]\n", + " Y = [v[\"evaluate\"][variable][metric] for v in value]\n", + " plt.scatter(X, Y, marker=marker, label=label, alpha=0.8)\n", + " plt.ylim(0, None)\n", + " plt.xscale(\"log\")\n", + " plt.xlabel(\"inference time (sec)\")\n", + " plt.ylabel(metric)\n", + " plt.title(f\"{variable} accuracy\")\n", + " plt.legend(loc=\"best\", prop={'size': 8})\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_accuracy(\"R0\", \"crps\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_accuracy(\"rho\", \"crps\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "plot_accuracy(\"obs\", \"crps\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_accuracy(\"I\", \"crps\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_convergence(variable, metrics=[\"n_eff\", \"r_hat\"]):\n", + " view = defaultdict(list)\n", + " for result in results:\n", + " args = result['args']\n", + " if args.infer == \"mcmc\":\n", + " view[args.num_bins].append(result)\n", + " markers = [\"o\", \"d\", \"s\"]\n", + " assert len(view) == len(markers)\n", + "\n", + " fig, axes = plt.subplots(len(metrics), 1, figsize=(6, 5), sharex=True)\n", + " fig.patch.set_color(\"white\")\n", + " for (num_bins, value), marker in zip(sorted(view.items()), markers):\n", + " if num_bins == 1:\n", + " label = \"MCMC relaxed\"\n", + " else:\n", + " label = \"MCMC num_bins={}\".format(num_bins)\n", + " X = [v[\"times\"][\"infer\"] for v in value]\n", + " for metric, ax in zip(metrics, axes):\n", + " Y = [v[\"infer\"][variable][metric] for v in value]\n", + " ax.scatter(X, Y, marker=marker, label=label, alpha=0.8)\n", + " ax.set_xscale(\"log\")\n", + " ax.set_ylabel(metric)\n", + " axes[0].set_title(f\"{variable} convergence\")\n", + " axes[-1].legend(loc=\"best\", prop={'size': 8})\n", + " axes[-1].set_xlabel(\"inference time (sec)\")\n", + " plt.subplots_adjust(hspace=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_convergence(\"R0\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_convergence(\"rho\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_convergence(\"auxiliary_haar_split_0\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index bb59e67..f5b66b3 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -15,17 +15,54 @@ def short_uni_synth(): "--R0=3", "--incubation-time=2", "--recovery-time=4", ] for svi_steps in [1000, 2000, 5000, 10000]: - for rng_seed in range(5): + for rng_seed in range(10): yield base + ["--svi", "--num-samples=1000", f"--svi-steps={svi_steps}", f"--rng-seed={rng_seed}"] for num_bins in [1, 2, 4]: for num_samples in [200, 500, 1000]: - for rng_seed in range(1): + if num_bins > 1: + num_seeds = 2 + elif num_samples > 200: + num_seeds = 5 + else: + num_seeds = 10 + for rng_seed in range(num_seeds): yield base + ["--mcmc", "--warmup-steps=200", f"--num-samples={num_samples}", + f"--num-bins={num_bins}", + f"--rng-seed={rng_seed}"] + + +def long_uni_synth(): + base = [ + sys.executable, + "uni_synth.py", + "--population=100000", + "--duration=100", "--forecast=30", + "--R0=2.5", "--incubation-time=4", "--recovery-time=10", + ] + for svi_steps in [1000, 2000, 5000, 10000]: + for rng_seed in range(10): + yield base + ["--svi", + "--num-samples=1000", + f"--svi-steps={svi_steps}", + f"--rng-seed={rng_seed}"] + for num_bins in [1, 2, 4]: + for num_samples in [200, 500, 1000]: + if num_bins > 1: + num_seeds = 2 + elif num_samples > 200: + num_seeds = 5 + else: + num_seeds = 10 + for rng_seed in range(num_seeds): + yield base + ["--mcmc", + "--warmup-steps=200", + f"--num-samples={num_samples}", + f"--num-bins={num_bins}", f"--rng-seed={rng_seed}"] From add527f4b0fa6c0e63cd7d68c529c9a668fd07b8 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 19 Jun 2020 07:29:54 -0700 Subject: [PATCH 14/26] Add more grid points --- 2020-06-compartmental/Makefile | 3 ++ 2020-06-compartmental/analyze.ipynb | 79 ++++++++++++----------------- 2020-06-compartmental/runner.py | 2 - 3 files changed, 36 insertions(+), 48 deletions(-) diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index bb98437..c481616 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -8,6 +8,9 @@ lint: FORCE short_uni_synth: FORCE python runner.py --experiment=short_uni_synth +long_uni_synth: FORCE + python runner.py --experiment=long_uni_synth + clean: FORCE rm -rf temp logs errors diff --git a/2020-06-compartmental/analyze.ipynb b/2020-06-compartmental/analyze.ipynb index 6c74aaa..8695a7f 100644 --- a/2020-06-compartmental/analyze.ipynb +++ b/2020-06-compartmental/analyze.ipynb @@ -49,32 +49,34 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_accuracy(variable, metric):\n", + "def plot_accuracy(variable, metric, **kwargs):\n", " view = defaultdict(list)\n", " for result in results:\n", " args = result['args']\n", - " view[args.infer, args.num_bins, args.svi_steps].append(result)\n", + " if all(getattr(args, k) == v for k, v in kwargs.items()):\n", + " view[args.infer, args.num_bins, args.svi_steps].append(result)\n", " markers = [\"o\", \"d\", \"s\", \"<\", \"v\", \"^\", \">\"]\n", - " assert len(view) == len(markers)\n", + " assert len(view) <= len(markers)\n", "\n", " plt.figure(figsize=(6, 5)).patch.set_color(\"white\")\n", " for (key, value), marker in zip(sorted(view.items()), markers):\n", " algo, num_bins, svi_steps = key\n", " if algo == \"svi\":\n", - " label = \"SVI steps={}\".format(svi_steps)\n", + " label = f\"SVI steps={svi_steps}\"\n", " elif algo == \"mcmc\":\n", " if num_bins == 1:\n", " label = \"MCMC relaxed\"\n", " else:\n", - " label = \"MCMC num_bins={}\".format(num_bins)\n", + " label = f\"MCMC num_bins={num_bins}\"\n", " X = [v[\"times\"][\"infer\"] for v in value]\n", " Y = [v[\"evaluate\"][variable][metric] for v in value]\n", " plt.scatter(X, Y, marker=marker, label=label, alpha=0.8)\n", " plt.ylim(0, None)\n", " plt.xscale(\"log\")\n", " plt.xlabel(\"inference time (sec)\")\n", - " plt.ylabel(metric)\n", - " plt.title(f\"{variable} accuracy\")\n", + " plt.ylabel(metric.upper())\n", + " plt.title(\", \".join([f\"{variable} accuracy\"] +\n", + " [f\"{k}={v}\" for k, v in sorted(kwargs.items())]))\n", " plt.legend(loc=\"best\", prop={'size': 8})\n", " plt.tight_layout()" ] @@ -85,7 +87,10 @@ "metadata": {}, "outputs": [], "source": [ - "plot_accuracy(\"R0\", \"crps\")" + "plot_accuracy(\"R0\", \"crps\", duration=20)\n", + "plot_accuracy(\"rho\", \"crps\", duration=20)\n", + "plot_accuracy(\"obs\", \"crps\", duration=20)\n", + "plot_accuracy(\"I\", \"crps\", duration=20)" ] }, { @@ -94,27 +99,10 @@ "metadata": {}, "outputs": [], "source": [ - "plot_accuracy(\"rho\", \"crps\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "plot_accuracy(\"obs\", \"crps\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot_accuracy(\"I\", \"crps\")" + "plot_accuracy(\"R0\", \"crps\", duration=100)\n", + "plot_accuracy(\"rho\", \"crps\", duration=100)\n", + "plot_accuracy(\"obs\", \"crps\", duration=100)\n", + "plot_accuracy(\"I\", \"crps\", duration=100)" ] }, { @@ -123,14 +111,15 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_convergence(variable, metrics=[\"n_eff\", \"r_hat\"]):\n", + "def plot_convergence(variable, metrics=[\"n_eff\", \"r_hat\"], **kwargs):\n", " view = defaultdict(list)\n", " for result in results:\n", " args = result['args']\n", - " if args.infer == \"mcmc\":\n", - " view[args.num_bins].append(result)\n", + " if all(getattr(args, k) == v for k, v in kwargs.items()):\n", + " if args.infer == \"mcmc\":\n", + " view[args.num_bins].append(result)\n", " markers = [\"o\", \"d\", \"s\"]\n", - " assert len(view) == len(markers)\n", + " assert len(view) <= len(markers)\n", "\n", " fig, axes = plt.subplots(len(metrics), 1, figsize=(6, 5), sharex=True)\n", " fig.patch.set_color(\"white\")\n", @@ -138,14 +127,17 @@ " if num_bins == 1:\n", " label = \"MCMC relaxed\"\n", " else:\n", - " label = \"MCMC num_bins={}\".format(num_bins)\n", + " label = f\"MCMC num_bins={num_bins}\"\n", " X = [v[\"times\"][\"infer\"] for v in value]\n", " for metric, ax in zip(metrics, axes):\n", " Y = [v[\"infer\"][variable][metric] for v in value]\n", " ax.scatter(X, Y, marker=marker, label=label, alpha=0.8)\n", " ax.set_xscale(\"log\")\n", " ax.set_ylabel(metric)\n", - " axes[0].set_title(f\"{variable} convergence\")\n", + " axes[0].set_title(\", \".join([f\"{variable} convergence\"] +\n", + " [f\"{k}={v}\" for k, v in sorted(kwargs.items())]))\n", + " axes[1].set_yscale(\"log\")\n", + " axes[1].set_ylim(1, None)\n", " axes[-1].legend(loc=\"best\", prop={'size': 8})\n", " axes[-1].set_xlabel(\"inference time (sec)\")\n", " plt.subplots_adjust(hspace=0)" @@ -157,16 +149,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot_convergence(\"R0\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot_convergence(\"rho\")" + "plot_convergence(\"R0\", duration=20)\n", + "plot_convergence(\"rho\", duration=20)\n", + "plot_convergence(\"auxiliary_haar_split_0\", duration=20)" ] }, { @@ -175,7 +160,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot_convergence(\"auxiliary_haar_split_0\")" + "plot_convergence(\"R0\", duration=100)\n", + "plot_convergence(\"rho\", duration=100)\n", + "plot_convergence(\"auxiliary_haar_split_0\", duration=100)" ] }, { diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index f5b66b3..a893c30 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -54,8 +54,6 @@ def long_uni_synth(): for num_samples in [200, 500, 1000]: if num_bins > 1: num_seeds = 2 - elif num_samples > 200: - num_seeds = 5 else: num_seeds = 10 for rng_seed in range(num_seeds): From 149eb31a3885731c8a0c4ede17e7884a724ce597 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 19 Jun 2020 09:15:11 -0700 Subject: [PATCH 15/26] Add more test points --- 2020-06-compartmental/runner.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index a893c30..ebb68a4 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -22,12 +22,10 @@ def short_uni_synth(): f"--rng-seed={rng_seed}"] for num_bins in [1, 2, 4]: for num_samples in [200, 500, 1000]: - if num_bins > 1: - num_seeds = 2 - elif num_samples > 200: - num_seeds = 5 - else: + if num_bins == 1: num_seeds = 10 + else: + num_seeds = 2 for rng_seed in range(num_seeds): yield base + ["--mcmc", "--warmup-steps=200", @@ -50,13 +48,16 @@ def long_uni_synth(): "--num-samples=1000", f"--svi-steps={svi_steps}", f"--rng-seed={rng_seed}"] - for num_bins in [1, 2, 4]: + for num_samples in [200, 500, 1000, 2000, 5000]: + for rng_seed in range(10): + yield base + ["--mcmc", + "--warmup-steps=200", + "--num-bins=1", + f"--num-samples={num_samples}", + f"--rng-seed={rng_seed}"] + for num_bins in [2, 4]: for num_samples in [200, 500, 1000]: - if num_bins > 1: - num_seeds = 2 - else: - num_seeds = 10 - for rng_seed in range(num_seeds): + for rng_seed in range(2): yield base + ["--mcmc", "--warmup-steps=200", f"--num-samples={num_samples}", From 6d934fec44bd97115846bba8edced30d92862af1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 19 Jun 2020 11:05:13 -0700 Subject: [PATCH 16/26] Increase warmup_steps, speed up cache check --- 2020-06-compartmental/runner.py | 33 ++++++++-- 2020-06-compartmental/uni_synth.py | 102 +++++++++++++++-------------- 2 files changed, 82 insertions(+), 53 deletions(-) diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index ebb68a4..e13e9d0 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -1,9 +1,15 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import os import argparse import subprocess import sys +from hashlib import sha1 +from importlib import import_module + +ROOT = os.path.dirname(os.path.abspath(__file__)) +RESULTS = os.path.join(ROOT, "results") def short_uni_synth(): @@ -22,13 +28,14 @@ def short_uni_synth(): f"--rng-seed={rng_seed}"] for num_bins in [1, 2, 4]: for num_samples in [200, 500, 1000]: + num_warmup = max(200, int(round(0.4 * num_samples))) if num_bins == 1: num_seeds = 10 else: num_seeds = 2 for rng_seed in range(num_seeds): yield base + ["--mcmc", - "--warmup-steps=200", + f"--warmup-steps={num_warmup}", f"--num-samples={num_samples}", f"--num-bins={num_bins}", f"--rng-seed={rng_seed}"] @@ -49,17 +56,19 @@ def long_uni_synth(): f"--svi-steps={svi_steps}", f"--rng-seed={rng_seed}"] for num_samples in [200, 500, 1000, 2000, 5000]: + num_warmup = max(200, int(round(0.4 * num_samples))) for rng_seed in range(10): yield base + ["--mcmc", - "--warmup-steps=200", "--num-bins=1", + f"--warmup-steps={num_warmup}", f"--num-samples={num_samples}", f"--rng-seed={rng_seed}"] for num_bins in [2, 4]: for num_samples in [200, 500, 1000]: + num_warmup = max(200, int(round(0.4 * num_samples))) for rng_seed in range(2): yield base + ["--mcmc", - "--warmup-steps=200", + f"--warmup-steps={num_warmup}", f"--num-samples={num_samples}", f"--num-bins={num_bins}", f"--rng-seed={rng_seed}"] @@ -69,8 +78,22 @@ def main(args): tasks = list(globals()[args.experiment]()) for task in tasks: print(" \\\n ".join(task)) - if not args.dry_run: - subprocess.check_call(task) + if args.dry_run: + continue + + # Optimization: Parse args to compute output filename and check for + # previous completion. This is equivalent to but much cheaper than + # creating a new process and checking in the process. + script = task[1] + parser = import_module(script.replace(".py", "")).Parser() + args_dict = parser.parse_args(task[2:]).__dict__ + unique = script, sorted(args_dict.items()) + fingerprint = sha1(str(unique).encode()).hexdigest() + outfile = os.path.join(RESULTS, fingerprint + ".pkl") + if os.path.exists(outfile): + continue + + subprocess.check_call(task) print("-------------------------") print("COMPLETED {} TASKS".format(len(tasks))) diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py index e485d38..47df97c 100644 --- a/2020-06-compartmental/uni_synth.py +++ b/2020-06-compartmental/uni_synth.py @@ -187,56 +187,62 @@ def main(args): return result +class Parser(argparse.ArgumentParser): + def __init__(self): + super().__init__(description="CompartmentalModel experiments") + self.add_argument("--population", default=1000, type=float) + self.add_argument("--min-obs-portion", default=0.1, type=float) + self.add_argument("--max-obs-portion", default=0.3, type=float) + self.add_argument("--duration", default=20, type=int) + self.add_argument("--forecast", default=10, type=int) + self.add_argument("--R0", default=1.5, type=float) + self.add_argument("--recovery-time", default=7.0, type=float) + self.add_argument("--incubation-time", default=0.0, type=float) + self.add_argument("--concentration", default=math.inf, type=float) + self.add_argument("--response-rate", default=0.5, type=float) + self.add_argument("--overdispersion", default=0., type=float) + self.add_argument("--heterogeneous", action="store_true") + self.add_argument("--infer", default="mcmc") + self.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer") + self.add_argument("--svi", action="store_const", const="svi", dest="infer") + self.add_argument("--haar", action="store_true") + self.add_argument("--nohaar", action="store_const", const=False, dest="haar") + self.add_argument("--haar-full-mass", default=10, type=int) + self.add_argument("--num-samples", default=200, type=int) + self.add_argument("--smc-particles", default=1024, type=int) + self.add_argument("--svi-steps", default=5000, type=int) + self.add_argument("--svi-particles", default=32, type=int) + self.add_argument("--warmup-steps", type=int) + self.add_argument("--num-chains", default=2, type=int) + self.add_argument("--max-tree-depth", default=5, type=int) + self.add_argument("--arrowhead-mass", action="store_true") + self.add_argument("--rng-seed", default=0, type=int) + self.add_argument("--num-bins", default=1, type=int) + self.add_argument("--double", action="store_true", default=True) + self.add_argument("--single", action="store_false", dest="double") + self.add_argument("--cuda", action="store_true") + self.add_argument("--jit", action="store_true", default=True) + self.add_argument("--nojit", action="store_false", dest="jit") + self.add_argument("--verbose", action="store_true") + + def parse_args(self, *args, **kwargs): + args = super().parse_args(*args, **kwargs) + args.population = int(args.population) # to allow e.g. --population=1e6 + if args.warmup_steps is None: + args.warmup_steps = args.num_samples + if args.double: + if args.cuda: + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + else: + torch.set_default_dtype(torch.float64) + elif args.cuda: + torch.set_default_tensor_type(torch.cuda.FloatTensor) + return args + + if __name__ == "__main__": assert pyro.__version__.startswith('1.3.1') - parser = argparse.ArgumentParser(description="CompartmentalModel experiments") - parser.add_argument("--population", default=1000, type=float) - parser.add_argument("--min-obs-portion", default=0.1, type=float) - parser.add_argument("--max-obs-portion", default=0.3, type=float) - parser.add_argument("--duration", default=20, type=int) - parser.add_argument("--forecast", default=10, type=int) - parser.add_argument("--R0", default=1.5, type=float) - parser.add_argument("--recovery-time", default=7.0, type=float) - parser.add_argument("--incubation-time", default=0.0, type=float) - parser.add_argument("--concentration", default=math.inf, type=float) - parser.add_argument("--response-rate", default=0.5, type=float) - parser.add_argument("--overdispersion", default=0., type=float) - parser.add_argument("--heterogeneous", action="store_true") - parser.add_argument("--infer", default="mcmc") - parser.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer") - parser.add_argument("--svi", action="store_const", const="svi", dest="infer") - parser.add_argument("--haar", action="store_true") - parser.add_argument("--nohaar", action="store_const", const=False, dest="haar") - parser.add_argument("--haar-full-mass", default=10, type=int) - parser.add_argument("--num-samples", default=200, type=int) - parser.add_argument("--smc-particles", default=1024, type=int) - parser.add_argument("--svi-steps", default=5000, type=int) - parser.add_argument("--svi-particles", default=32, type=int) - parser.add_argument("--warmup-steps", type=int) - parser.add_argument("--num-chains", default=2, type=int) - parser.add_argument("--max-tree-depth", default=5, type=int) - parser.add_argument("--arrowhead-mass", action="store_true") - parser.add_argument("--rng-seed", default=0, type=int) - parser.add_argument("--num-bins", default=1, type=int) - parser.add_argument("--double", action="store_true", default=True) - parser.add_argument("--single", action="store_false", dest="double") - parser.add_argument("--cuda", action="store_true") - parser.add_argument("--jit", action="store_true", default=True) - parser.add_argument("--nojit", action="store_false", dest="jit") - parser.add_argument("--verbose", action="store_true") - - # Parse args. - args = parser.parse_args() - args.population = int(args.population) # to allow e.g. --population=1e6 - if args.warmup_steps is None: - args.warmup_steps = args.num_samples - if args.double: - if args.cuda: - torch.set_default_tensor_type(torch.cuda.DoubleTensor) - else: - torch.set_default_dtype(torch.float64) - elif args.cuda: - torch.set_default_tensor_type(torch.cuda.FloatTensor) + args = Parser().parse_args() # Cache output. unique = __file__, sorted(args.__dict__.items()) From e7793035c61ba68580f83f2462ee33e5462dc5e0 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 19 Jun 2020 11:23:28 -0700 Subject: [PATCH 17/26] Refactor --- 2020-06-compartmental/runner.py | 10 ++++------ 2020-06-compartmental/uni_synth.py | 14 +++++--------- 2020-06-compartmental/util.py | 24 ++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 15 deletions(-) create mode 100644 2020-06-compartmental/util.py diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index e13e9d0..3b31ffd 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -1,13 +1,14 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -import os import argparse +import os import subprocess import sys -from hashlib import sha1 from importlib import import_module +from util import get_filename + ROOT = os.path.dirname(os.path.abspath(__file__)) RESULTS = os.path.join(ROOT, "results") @@ -86,10 +87,7 @@ def main(args): # creating a new process and checking in the process. script = task[1] parser = import_module(script.replace(".py", "")).Parser() - args_dict = parser.parse_args(task[2:]).__dict__ - unique = script, sorted(args_dict.items()) - fingerprint = sha1(str(unique).encode()).hexdigest() - outfile = os.path.join(RESULTS, fingerprint + ".pkl") + outfile = get_filename(script, parser.parse_args(task[2:])) if os.path.exists(outfile): continue diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py index 47df97c..b47ea68 100644 --- a/2020-06-compartmental/uni_synth.py +++ b/2020-06-compartmental/uni_synth.py @@ -8,7 +8,6 @@ import pickle import resource import sys -from hashlib import sha1 from timeit import default_timer import torch @@ -26,19 +25,18 @@ from pyro.contrib.forecast.evaluate import eval_crps, eval_mae, eval_rmse from pyro.infer.mcmc.util import summary +from util import RESULTS, get_filename + fmt = '%(process)d %(message)s' logging.getLogger("pyro").handlers[0].setFormatter(logging.Formatter(fmt)) logging.basicConfig(format=fmt, level=logging.INFO) -ROOT = os.path.dirname(os.path.abspath(__file__)) -RESULTS = os.path.join(ROOT, "results") - # Ensure directories exist. if not os.path.exists(RESULTS): try: os.makedirs(RESULTS) - except OSError: - assert os.path.exists(RESULTS) + except FileExistsError: + pass def Model(args, data): @@ -245,9 +243,7 @@ def parse_args(self, *args, **kwargs): args = Parser().parse_args() # Cache output. - unique = __file__, sorted(args.__dict__.items()) - fingerprint = sha1(str(unique).encode()).hexdigest() - outfile = os.path.join(RESULTS, fingerprint + ".pkl") + outfile = get_filename(__file__, args) if not os.path.exists(outfile): result = main(args) with open(outfile, "wb") as f: diff --git a/2020-06-compartmental/util.py b/2020-06-compartmental/util.py new file mode 100644 index 0000000..26e6add --- /dev/null +++ b/2020-06-compartmental/util.py @@ -0,0 +1,24 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import os +from hashlib import sha1 + +ROOT = os.path.dirname(os.path.abspath(__file__)) +DATA = os.path.join(ROOT, "data") +RESULTS = os.path.join(ROOT, "results") + +# Ensure directories exist. +for path in [DATA, RESULTS]: + if not os.path.exists(path): + try: + os.makedirs(path) + except FileExistsError: + pass + + +def get_filename(script, args): + unique = script, sorted(args.__dict__.items()) + fingerprint = sha1(str(unique).encode()).hexdigest() + cachefile = os.path.join(RESULTS, fingerprint + ".pkl") + return cachefile From c536aa0abc6c848bd891c01200b5c1ebd504195d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 19 Jun 2020 12:57:26 -0700 Subject: [PATCH 18/26] Refactor to create an Experiment class --- 2020-06-compartmental/.gitignore | 1 + 2020-06-compartmental/Makefile | 2 +- 2020-06-compartmental/analyze.ipynb | 75 +++++++++++++---------------- 2020-06-compartmental/runner.py | 58 ++++++++++++++-------- 2020-06-compartmental/uni_synth.py | 9 +--- 5 files changed, 74 insertions(+), 71 deletions(-) diff --git a/2020-06-compartmental/.gitignore b/2020-06-compartmental/.gitignore index fbca225..2385b03 100644 --- a/2020-06-compartmental/.gitignore +++ b/2020-06-compartmental/.gitignore @@ -1 +1,2 @@ +data/ results/ diff --git a/2020-06-compartmental/Makefile b/2020-06-compartmental/Makefile index c481616..912e74e 100644 --- a/2020-06-compartmental/Makefile +++ b/2020-06-compartmental/Makefile @@ -15,6 +15,6 @@ clean: FORCE rm -rf temp logs errors mrclean: FORCE - rm -rf results + rm -rf data results FORCE: diff --git a/2020-06-compartmental/analyze.ipynb b/2020-06-compartmental/analyze.ipynb index 8695a7f..a9cb51a 100644 --- a/2020-06-compartmental/analyze.ipynb +++ b/2020-06-compartmental/analyze.ipynb @@ -13,13 +13,10 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import glob\n", - "import json\n", - "import pickle\n", "from collections import defaultdict\n", - "\n", "import matplotlib.pyplot as plt\n", + "from runner import short_uni_synth, long_uni_synth\n", + "\n", "%matplotlib inline\n", "%config InlineBackend.figure_formats = ['svg']" ] @@ -30,17 +27,17 @@ "metadata": {}, "outputs": [], "source": [ - "results = []\n", - "for filename in glob.glob(\"results/*.pkl\"):\n", - " with open(filename, \"rb\") as f:\n", - " results.append(pickle.load(f))\n", + "results = list(short_uni_synth.results)\n", "print(len(results))\n", - "print(results[0].keys())\n", - "print(results[0][\"times\"].keys())\n", - "print(results[0][\"evaluate\"].keys())\n", - "print(results[0][\"evaluate\"][\"R0\"].keys())\n", - "print(results[0][\"infer\"].keys())\n", - "print(results[0][\"infer\"][\"R0\"].keys())" + "for r in results:\n", + " if r[\"args\"].infer == \"mcmc\":\n", + " break\n", + "print(r.keys())\n", + "print(r[\"times\"].keys())\n", + "print(r[\"evaluate\"].keys())\n", + "print(r[\"evaluate\"][\"R0\"].keys())\n", + "print(r[\"infer\"].keys())\n", + "print(r[\"infer\"][\"R0\"].keys())" ] }, { @@ -49,12 +46,11 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_accuracy(variable, metric, **kwargs):\n", + "def plot_accuracy(variable, metric, experiment):\n", " view = defaultdict(list)\n", - " for result in results:\n", + " for result in experiment.results:\n", " args = result['args']\n", - " if all(getattr(args, k) == v for k, v in kwargs.items()):\n", - " view[args.infer, args.num_bins, args.svi_steps].append(result)\n", + " view[args.infer, args.num_bins, args.svi_steps].append(result)\n", " markers = [\"o\", \"d\", \"s\", \"<\", \"v\", \"^\", \">\"]\n", " assert len(view) <= len(markers)\n", "\n", @@ -75,8 +71,7 @@ " plt.xscale(\"log\")\n", " plt.xlabel(\"inference time (sec)\")\n", " plt.ylabel(metric.upper())\n", - " plt.title(\", \".join([f\"{variable} accuracy\"] +\n", - " [f\"{k}={v}\" for k, v in sorted(kwargs.items())]))\n", + " plt.title(f\"{variable} accuracy ({experiment.__name__})\")\n", " plt.legend(loc=\"best\", prop={'size': 8})\n", " plt.tight_layout()" ] @@ -87,10 +82,10 @@ "metadata": {}, "outputs": [], "source": [ - "plot_accuracy(\"R0\", \"crps\", duration=20)\n", - "plot_accuracy(\"rho\", \"crps\", duration=20)\n", - "plot_accuracy(\"obs\", \"crps\", duration=20)\n", - "plot_accuracy(\"I\", \"crps\", duration=20)" + "plot_accuracy(\"R0\", \"crps\", short_uni_synth)\n", + "plot_accuracy(\"rho\", \"crps\", short_uni_synth)\n", + "plot_accuracy(\"obs\", \"crps\", short_uni_synth)\n", + "plot_accuracy(\"I\", \"crps\", short_uni_synth)" ] }, { @@ -99,10 +94,10 @@ "metadata": {}, "outputs": [], "source": [ - "plot_accuracy(\"R0\", \"crps\", duration=100)\n", - "plot_accuracy(\"rho\", \"crps\", duration=100)\n", - "plot_accuracy(\"obs\", \"crps\", duration=100)\n", - "plot_accuracy(\"I\", \"crps\", duration=100)" + "plot_accuracy(\"R0\", \"crps\", long_uni_synth)\n", + "plot_accuracy(\"rho\", \"crps\", long_uni_synth)\n", + "plot_accuracy(\"obs\", \"crps\", long_uni_synth)\n", + "plot_accuracy(\"I\", \"crps\", long_uni_synth)" ] }, { @@ -111,13 +106,12 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_convergence(variable, metrics=[\"n_eff\", \"r_hat\"], **kwargs):\n", + "def plot_convergence(variable, experiment, metrics=[\"n_eff\", \"r_hat\"]):\n", " view = defaultdict(list)\n", " for result in results:\n", " args = result['args']\n", - " if all(getattr(args, k) == v for k, v in kwargs.items()):\n", - " if args.infer == \"mcmc\":\n", - " view[args.num_bins].append(result)\n", + " if args.infer == \"mcmc\":\n", + " view[args.num_bins].append(result)\n", " markers = [\"o\", \"d\", \"s\"]\n", " assert len(view) <= len(markers)\n", "\n", @@ -134,8 +128,7 @@ " ax.scatter(X, Y, marker=marker, label=label, alpha=0.8)\n", " ax.set_xscale(\"log\")\n", " ax.set_ylabel(metric)\n", - " axes[0].set_title(\", \".join([f\"{variable} convergence\"] +\n", - " [f\"{k}={v}\" for k, v in sorted(kwargs.items())]))\n", + " axes[0].set_title(f\"{variable} convergence ({experiment.__name__})\")\n", " axes[1].set_yscale(\"log\")\n", " axes[1].set_ylim(1, None)\n", " axes[-1].legend(loc=\"best\", prop={'size': 8})\n", @@ -149,9 +142,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot_convergence(\"R0\", duration=20)\n", - "plot_convergence(\"rho\", duration=20)\n", - "plot_convergence(\"auxiliary_haar_split_0\", duration=20)" + "plot_convergence(\"R0\", short_uni_synth)\n", + "plot_convergence(\"rho\", short_uni_synth)\n", + "plot_convergence(\"auxiliary_haar_split_0\", short_uni_synth)" ] }, { @@ -160,9 +153,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot_convergence(\"R0\", duration=100)\n", - "plot_convergence(\"rho\", duration=100)\n", - "plot_convergence(\"auxiliary_haar_split_0\", duration=100)" + "plot_convergence(\"R0\", long_uni_synth)\n", + "plot_convergence(\"rho\", long_uni_synth)\n", + "plot_convergence(\"auxiliary_haar_split_0\", long_uni_synth)" ] }, { diff --git a/2020-06-compartmental/runner.py b/2020-06-compartmental/runner.py index 3b31ffd..e1380e7 100644 --- a/2020-06-compartmental/runner.py +++ b/2020-06-compartmental/runner.py @@ -3,19 +3,45 @@ import argparse import os +import pickle import subprocess import sys from importlib import import_module from util import get_filename -ROOT = os.path.dirname(os.path.abspath(__file__)) -RESULTS = os.path.join(ROOT, "results") +class Experiment: + """ + An experiment consists of a collection of tasks. + Each task generates a datapoint by running a python script. + Result datapoints are cached in pickle files named by fingerprint. + """ + def __init__(self, generate_tasks): + self.__name__ = generate_tasks.__name__ + self.tasks = [[sys.executable] + task for task in generate_tasks()] + self.files = [] + for task in self.tasks: + script = task[1] + parser = import_module(script.replace(".py", "")).Parser() + outfile = get_filename(script, parser.parse_args(task[2:])) + self.files.append(outfile) + @property + def results(self): + """ + Iterates over the subset of experiment results that have been generated. + """ + for outfile in self.files: + if os.path.exists(outfile): + with open(outfile, "rb") as f: + result = pickle.load(f) + yield result + + +@Experiment def short_uni_synth(): base = [ - sys.executable, "uni_synth.py", "--population=1000", "--duration=20", "--forecast=10", @@ -29,7 +55,7 @@ def short_uni_synth(): f"--rng-seed={rng_seed}"] for num_bins in [1, 2, 4]: for num_samples in [200, 500, 1000]: - num_warmup = max(200, int(round(0.4 * num_samples))) + num_warmup = int(round(0.4 * num_samples)) if num_bins == 1: num_seeds = 10 else: @@ -42,9 +68,9 @@ def short_uni_synth(): f"--rng-seed={rng_seed}"] +@Experiment def long_uni_synth(): base = [ - sys.executable, "uni_synth.py", "--population=100000", "--duration=100", "--forecast=30", @@ -57,7 +83,7 @@ def long_uni_synth(): f"--svi-steps={svi_steps}", f"--rng-seed={rng_seed}"] for num_samples in [200, 500, 1000, 2000, 5000]: - num_warmup = max(200, int(round(0.4 * num_samples))) + num_warmup = int(round(0.4 * num_samples)) for rng_seed in range(10): yield base + ["--mcmc", "--num-bins=1", @@ -66,7 +92,7 @@ def long_uni_synth(): f"--rng-seed={rng_seed}"] for num_bins in [2, 4]: for num_samples in [200, 500, 1000]: - num_warmup = max(200, int(round(0.4 * num_samples))) + num_warmup = int(round(0.4 * num_samples)) for rng_seed in range(2): yield base + ["--mcmc", f"--warmup-steps={num_warmup}", @@ -76,25 +102,15 @@ def long_uni_synth(): def main(args): - tasks = list(globals()[args.experiment]()) - for task in tasks: + experiment = globals()[args.experiment] + for task, outfile in zip(experiment.tasks, experiment.files): print(" \\\n ".join(task)) - if args.dry_run: + if args.dry_run or os.path.exists(outfile): continue - - # Optimization: Parse args to compute output filename and check for - # previous completion. This is equivalent to but much cheaper than - # creating a new process and checking in the process. - script = task[1] - parser = import_module(script.replace(".py", "")).Parser() - outfile = get_filename(script, parser.parse_args(task[2:])) - if os.path.exists(outfile): - continue - subprocess.check_call(task) print("-------------------------") - print("COMPLETED {} TASKS".format(len(tasks))) + print("COMPLETED {} TASKS".format(len(experiment.tasks))) print("-------------------------") diff --git a/2020-06-compartmental/uni_synth.py b/2020-06-compartmental/uni_synth.py index b47ea68..5bb3ccd 100644 --- a/2020-06-compartmental/uni_synth.py +++ b/2020-06-compartmental/uni_synth.py @@ -25,19 +25,12 @@ from pyro.contrib.forecast.evaluate import eval_crps, eval_mae, eval_rmse from pyro.infer.mcmc.util import summary -from util import RESULTS, get_filename +from util import get_filename fmt = '%(process)d %(message)s' logging.getLogger("pyro").handlers[0].setFormatter(logging.Formatter(fmt)) logging.basicConfig(format=fmt, level=logging.INFO) -# Ensure directories exist. -if not os.path.exists(RESULTS): - try: - os.makedirs(RESULTS) - except FileExistsError: - pass - def Model(args, data): """Dispatch between different model classes.""" From 34cb9f0220864ddb0564c03b45ff5d071f12709a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 19 Jun 2020 18:45:32 -0700 Subject: [PATCH 19/26] Add first draft of real data experiment --- 2020-06-compartmental/uni_real.py | 393 ++++++++++++++++++++++++++++++ 1 file changed, 393 insertions(+) create mode 100644 2020-06-compartmental/uni_real.py diff --git a/2020-06-compartmental/uni_real.py b/2020-06-compartmental/uni_real.py new file mode 100644 index 0000000..6208c50 --- /dev/null +++ b/2020-06-compartmental/uni_real.py @@ -0,0 +1,393 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import logging +import os +import pickle +import resource +import sys +import urllib.request +from timeit import default_timer +from collections import OrderedDict + +import pandas as pd +import torch + +import pyro +import pyro.distributions as dist +from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist +from pyro.contrib.forecast.evaluate import eval_crps, eval_mae, eval_rmse +from pyro.infer.mcmc.util import summary +from pyro.ops.tensor_utils import convolve + +from util import DATA, get_filename + +fmt = '%(process)d %(message)s' +logging.getLogger("pyro").handlers[0].setFormatter(logging.Formatter(fmt)) +logging.basicConfig(format=fmt, level=logging.INFO) + + +# Bay area county populations. +counties = OrderedDict([ + ("Santa Clara", 1763000), + ("Alameda", 1495000), + ("Contra Costa", 1038000), + ("San Francisco", 871000), + ("San Mateo", 712000), + ("Sonoma", 479000), + ("Solano", 412000), + ("Marin", 251000), + ("Napa", 135000), +]) + + +def load_df(basename): + url = ("https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/" + "csse_covid_19_data/csse_covid_19_time_series/") + local_path = os.path.join(DATA, basename) + if not os.path.exists(local_path): + urllib.request.urlretrieve(url + basename, local_path) + return pd.read_csv(local_path) + + +def load_data(args): + cum_cases_df = load_df("time_series_covid19_confirmed_US.csv") + cum_deaths_df = load_df("time_series_covid19_deaths_US.csv") + + # Convert to torch.Tensor. + cum_cases = [] + cum_deaths = [] + for county in counties: + i = list(cum_cases_df["Admin2"]).index(county) + cum_cases.append(cum_cases_df.iloc[i, 11:]) + i = list(cum_deaths_df["Admin2"]).index(county) + cum_deaths.append(cum_deaths_df.iloc[i, 12:]) + cum_cases = torch.tensor(cum_cases, dtype=torch.get_default_dtype()).T.contiguous() + cum_deaths = torch.tensor(cum_deaths, dtype=torch.get_default_dtype()).T.contiguous() + assert cum_cases.shape == cum_deaths.shape + logging.info(f"Data shape = {tuple(cum_cases.shape)}, {tuple(cum_deaths.shape)}") + + # Convert from cumulative to difference data, and convolve to ensure positivity. + T = len(cum_cases) + for window in range(1, 100): + kernel = torch.ones(window) / window + smooth_cases = convolve(cum_cases.T, kernel).T[:T].round() + smooth_deaths = convolve(cum_deaths.T, kernel).T[:T].round() + new_cases = smooth_cases[1:] - smooth_cases[:-1] + new_deaths = smooth_deaths[1:] - smooth_deaths[:-1] + if (new_cases >= 0).all() and (new_deaths >= 0).all(): + break + logging.info(f"window = {window}, shape = {tuple(new_cases.shape)}") + + # Truncate and select a single county. + new_cases = new_cases[args.truncate:, args.county].contiguous() + new_deaths = new_deaths[args.truncate:, args.county].contiguous() + population = list(counties.values())[args.county] + + return {"population": population, + "new_cases": new_cases, + "new_deaths": new_deaths} + + +class Model(CompartmentalModel): + def __init__(self, args, population, new_cases, new_deaths): + assert new_cases.dim() == 1 + assert new_cases.shape == new_deaths.shape + duration = len(new_cases) + compartments = ("S", "E", "I") # R is implicit. + super().__init__(compartments, duration, population) + + self.incubation_time = args.incubation_time + self.recovery_time = args.recovery_time + self.new_cases = new_cases + self.new_deaths = new_deaths + + def global_model(self): + tau_e = self.incubation_time + tau_i = self.recovery_time + R0 = pyro.sample("R0", dist.LogNormal(1., 0.5)) # Weak prior. + external_rate = pyro.sample("external_rate", dist.LogNormal(-2, 2)) + rho = pyro.sample("rho", dist.Beta(10, 10)) # About 50% response rate. + mu = pyro.sample("mu", dist.Beta(1, 100)) # About 1% mortality rate. + drift = pyro.sample("drift", dist.LogNormal(-3, 1.)) + od = pyro.sample("od", dist.Beta(2, 6)) + + return R0, external_rate, tau_e, tau_i, rho, mu, drift, od + + def initialize(self, params): + R0, external_rate, tau_e, tau_i, rho, mu, drift, od = params + + # Start with no local infections and close to basic reproductive number. + return {"S": self.population, "E": 0, "I": 0, + "R_factor": torch.tensor(0.98)} + + def transition(self, params, state, t): + R0, external_rate, tau_e, tau_i, rho, mu, drift, od = params + + # Assume effective reproductive number Rt varies in time. + sigmoid = torch.distributions.transforms.SigmoidTransform() + R_factor = pyro.sample("R_factor_{}".format(t), + dist.TransformedDistribution( + dist.Normal(sigmoid.inv(state["R_factor"]), drift), + sigmoid)) + Rt = pyro.deterministic("Rt_{}".format(t), R0 * R_factor, event_dim=0) + I_external = external_rate * tau_i / Rt + + # Sample flows between compartments. + S2E = pyro.sample("S2E_{}".format(t), + infection_dist(individual_rate=Rt / tau_i, + num_susceptible=state["S"], + num_infectious=state["I"] + I_external, + population=self.population, + overdispersion=od)) + E2I = pyro.sample("E2I_{}".format(t), + binomial_dist(state["E"], 1 / tau_e, + overdispersion=od)) + I2R = pyro.sample("I2R_{}".format(t), + binomial_dist(state["I"], 1 / tau_i, + overdispersion=od)) + + # Update compartments and heterogeneous variables. + state["S"] = state["S"] - S2E + state["E"] = state["E"] + S2E - E2I + state["I"] = state["I"] + E2I - I2R + state["R_factor"] = R_factor + + # Condition on observations. + t_is_observed = isinstance(t, slice) or t < self.duration + pyro.sample("new_cases_{}".format(t), + binomial_dist(S2E, rho, overdispersion=od), + obs=self.new_cases[t] if t_is_observed else None) + pyro.sample("new_deaths_{}".format(t), + binomial_dist(I2R, mu, overdispersion=od), + obs=self.new_deaths[t] if t_is_observed else None) + + +def _item(x): + if isinstance(x, torch.Tensor): + x = x.reshape(-1).median().item() + elif isinstance(x, dict): + for key, value in x.items(): + x[key] = _item(value) + return x + + +def infer_mcmc(args, model): + parallel = args.num_chains > 1 + + mcmc = model.fit_mcmc(heuristic_num_particles=args.smc_particles, + warmup_steps=args.warmup_steps, + num_samples=args.num_samples, + num_chains=args.num_chains, + mp_context="spawn" if parallel else None, + max_tree_depth=args.max_tree_depth, + num_quant_bins=args.num_bins, + haar=True, + haar_full_mass=args.haar_full_mass, + jit_compile=args.jit) + + result = summary(mcmc._samples) + result = _item(result) + return result + + +def infer_svi(args, model): + losses = model.fit_svi(heuristic_num_particles=args.smc_particles, + num_samples=args.num_samples, + num_steps=args.svi_steps, + num_particles=args.svi_particles, + jit=args.jit) + + return {"loss_initial": losses[0], "loss_final": losses[-1]} + + +def predict(args, model, truth): + samples = model.predict(forecast=args.forecast) + + if args.plot: + import matplotlib.pyplot as plt + fig, axes = plt.subplots(3, 1, figsize=(6, 8), sharex=True) + + # Plot forecasted series. + for name, ax in zip(["new_cases", "new_deaths"], axes): + pred = samples[name][..., model.duration:] + time = torch.arange(model.duration + args.forecast) + median = pred.median(dim=0).values + p05 = pred.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values + p95 = pred.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values + ax.fill_between(time[model.duration:], p05, p95, color="red", alpha=0.3, + label="90% CI") + ax.plot(time[model.duration:], median, "r-", label="median") + ax.plot(time, truth[name], "k--", label="truth") + ax.axvline(model.duration - 0.5, color="gray", lw=1) + ax.set_yscale("log") + ax.set_ylim(1, None) + ax.set_ylabel(f"{name} / day") + ax.legend(loc="upper left") + + # Plot Rt time series. + Rt = samples["Rt"] + median = Rt.median(dim=0).values + p05 = Rt.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values + p95 = Rt.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values + axes[2].fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") + axes[2].plot(time, median, "r-", label="median") + axes[2].axvline(model.duration - 0.5, color="gray", lw=1) + axes[2].set_ylim(0, None) + axes[2].set_ylabel("Rt") + axes[2].legend(loc="best") + + axes[-1].set_xlim(0, len(time) - 1) + axes[-1].set_xlabel("day") + axes[0].set_title("{}, population {}".format( + list(counties)[args.county], truth["population"])) + plt.tight_layout() + plt.subplots_adjust(hspace=0) + + return samples + + +def evaluate(args, truth, model, samples): + metrics = [("mae", eval_mae), ("rmse", eval_rmse), ("crps", eval_crps)] + result = {} + for key, pred in samples.items(): + if key in ("new_cases", "new_deaths"): + pred = pred[..., model.duration:] + + result[key] = {} + result[key]["mean"] = pred.mean().item() + result[key]["std"] = pred.std(dim=0).mean().item() + + if key in truth: + true = truth[key][..., model.duration:] + for metric, fn in metrics: + result[key][metric] = fn(pred, true) + + # Print estimated values. + covariates = [(name, value.squeeze()) + for name, value in sorted(samples.items()) + if value[0].numel() == 1] + for name, value in covariates: + mean = value.mean().item() + std = value.std().item() + logging.info(f"{name} = {mean:0.3g} \u00B1 {std:0.3g}") + + if args.plot: + # Plot pairwise joint distributions for selected variables. + import matplotlib.pyplot as plt + N = len(covariates) + fig, axes = plt.subplots(N, N, figsize=(8, 8), sharex="col", sharey="row") + for i in range(N): + axes[i][0].set_ylabel(covariates[i][0]) + axes[0][i].set_xlabel(covariates[i][0]) + axes[0][i].xaxis.set_label_position("top") + for j in range(N): + ax = axes[i][j] + ax.set_xticks(()) + ax.set_yticks(()) + ax.scatter(covariates[j][1], -covariates[i][1], + lw=0, color="darkblue", alpha=0.3) + plt.tight_layout() + plt.subplots_adjust(wspace=0, hspace=0) + + return result + + +def main(args): + pyro.enable_validation(__debug__) + pyro.set_rng_seed(args.rng_seed + 20200619) + + result = {"file": __file__, "args": args, "argv": sys.argv} + + truth = load_data(args) + result["data"] = { + "population": truth["population"], + "total_cases": truth["new_cases"].sum().item(), + "total_deaths": truth["new_deaths"].sum().item(), + "max_cases": truth["new_cases"].max().item(), + "max_deaths": truth["new_deaths"].max().item(), + } + + t0 = default_timer() + + model = Model(args, truth["population"], + truth["new_cases"][:-args.forecast], + truth["new_deaths"][:-args.forecast]) + infer = {"mcmc": infer_mcmc, "svi": infer_svi}[args.infer] + result["infer"] = infer(args, model) + + t1 = default_timer() + + samples = predict(args, model, truth) + + t2 = default_timer() + + result["evaluate"] = evaluate(args, truth, model, samples) + result["times"] = {"infer": t1 - t0, "predict": t2 - t1} + result["rusage"] = resource.getrusage(resource.RUSAGE_SELF) + logging.info("DONE") + return result + + +class Parser(argparse.ArgumentParser): + def __init__(self): + super().__init__(description="CompartmentalModel experiments") + self.add_argument("--county", default=0, type=int, + help="which SF Bay Area county, 0-8") + self.add_argument("--truncate", default=30, type=int) + self.add_argument("--forecast", default=14, type=int) + self.add_argument("--recovery-time", default=14.0, type=float) + self.add_argument("--incubation-time", default=5.5, type=float) + self.add_argument("--infer", default="svi") + self.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer") + self.add_argument("--svi", action="store_const", const="svi", dest="infer") + self.add_argument("--haar-full-mass", default=10, type=int) + self.add_argument("--num-samples", default=200, type=int) + self.add_argument("--smc-particles", default=1024, type=int) + self.add_argument("--svi-steps", default=5000, type=int) + self.add_argument("--svi-particles", default=32, type=int) + self.add_argument("--warmup-steps", type=int) + self.add_argument("--num-chains", default=2, type=int) + self.add_argument("--max-tree-depth", default=5, type=int) + self.add_argument("--rng-seed", default=0, type=int) + self.add_argument("--num-bins", default=1, type=int) + self.add_argument("--double", action="store_true", default=True) + self.add_argument("--single", action="store_false", dest="double") + self.add_argument("--cuda", action="store_true") + self.add_argument("--jit", action="store_true", default=True) + self.add_argument("--nojit", action="store_false", dest="jit") + self.add_argument("--plot", action="store_true") + + def parse_args(self, *args, **kwargs): + args = super().parse_args(*args, **kwargs) + assert args.forecast > 0 + if args.warmup_steps is None: + args.warmup_steps = args.num_samples + if args.double: + if args.cuda: + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + else: + torch.set_default_dtype(torch.float64) + elif args.cuda: + torch.set_default_tensor_type(torch.cuda.FloatTensor) + return args + + +if __name__ == "__main__": + assert pyro.__version__.startswith('1.3.1') + args = Parser().parse_args() + + args.plot = True # DEBUG + if args.plot: + main(args) + import matplotlib.pyplot as plt + plt.show() + else: + # Cache output. + outfile = get_filename(__file__, args) + if not os.path.exists(outfile): + result = main(args) + with open(outfile, "wb") as f: + pickle.dump(result, f) + logging.info("Saved {}".format(outfile)) From df95828e938b0b831a020cd93a81e63c92262f76 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 20 Jun 2020 18:22:12 -0700 Subject: [PATCH 20/26] Expose more SVI parameters --- 2020-06-compartmental/analyze.ipynb | 2 +- 2020-06-compartmental/uni_real.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/2020-06-compartmental/analyze.ipynb b/2020-06-compartmental/analyze.ipynb index a9cb51a..ffd90b5 100644 --- a/2020-06-compartmental/analyze.ipynb +++ b/2020-06-compartmental/analyze.ipynb @@ -127,9 +127,9 @@ " Y = [v[\"infer\"][variable][metric] for v in value]\n", " ax.scatter(X, Y, marker=marker, label=label, alpha=0.8)\n", " ax.set_xscale(\"log\")\n", + " ax.set_yscale(\"log\")\n", " ax.set_ylabel(metric)\n", " axes[0].set_title(f\"{variable} convergence ({experiment.__name__})\")\n", - " axes[1].set_yscale(\"log\")\n", " axes[1].set_ylim(1, None)\n", " axes[-1].legend(loc=\"best\", prop={'size': 8})\n", " axes[-1].set_xlabel(\"inference time (sec)\")\n", diff --git a/2020-06-compartmental/uni_real.py b/2020-06-compartmental/uni_real.py index 6208c50..59a8245 100644 --- a/2020-06-compartmental/uni_real.py +++ b/2020-06-compartmental/uni_real.py @@ -109,7 +109,7 @@ def global_model(self): R0 = pyro.sample("R0", dist.LogNormal(1., 0.5)) # Weak prior. external_rate = pyro.sample("external_rate", dist.LogNormal(-2, 2)) rho = pyro.sample("rho", dist.Beta(10, 10)) # About 50% response rate. - mu = pyro.sample("mu", dist.Beta(1, 100)) # About 1% mortality rate. + mu = pyro.sample("mu", dist.Beta(2, 100)) # About 2% mortality rate. drift = pyro.sample("drift", dist.LogNormal(-3, 1.)) od = pyro.sample("od", dist.Beta(2, 6)) @@ -197,6 +197,10 @@ def infer_svi(args, model): num_samples=args.num_samples, num_steps=args.svi_steps, num_particles=args.svi_particles, + learning_rate=args.learning_rate, + learning_rate_decay=args.learning_rate_decay, + betas=args.betas, + init_scale=args.init_scale, jit=args.jit) return {"loss_initial": losses[0], "loss_final": losses[-1]} @@ -347,6 +351,10 @@ def __init__(self): self.add_argument("--smc-particles", default=1024, type=int) self.add_argument("--svi-steps", default=5000, type=int) self.add_argument("--svi-particles", default=32, type=int) + self.add_argument("--learning-rate", default=0.1, type=float) + self.add_argument("--learning-rate-decay", default=0.01, type=float) + self.add_argument("--betas", default="0.8,0.99") + self.add_argument("--init-scale", default=0.1, type=float) self.add_argument("--warmup-steps", type=int) self.add_argument("--num-chains", default=2, type=int) self.add_argument("--max-tree-depth", default=5, type=int) @@ -361,6 +369,7 @@ def __init__(self): def parse_args(self, *args, **kwargs): args = super().parse_args(*args, **kwargs) + args.betas = tuple(map(float, args.betas.split(","))) assert args.forecast > 0 if args.warmup_steps is None: args.warmup_steps = args.num_samples From d900498f887f272791611c66ba3139e9fab1353f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sun, 21 Jun 2020 10:28:25 -0700 Subject: [PATCH 21/26] Add zero inflation and more plots --- 2020-06-compartmental/uni_real.py | 130 +++++++++++++++++++----------- 1 file changed, 82 insertions(+), 48 deletions(-) diff --git a/2020-06-compartmental/uni_real.py b/2020-06-compartmental/uni_real.py index 59a8245..7222104 100644 --- a/2020-06-compartmental/uni_real.py +++ b/2020-06-compartmental/uni_real.py @@ -2,15 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import datetime import logging import os import pickle import resource import sys import urllib.request -from timeit import default_timer from collections import OrderedDict +from timeit import default_timer +import numpy as np import pandas as pd import torch @@ -19,8 +21,6 @@ from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist from pyro.contrib.forecast.evaluate import eval_crps, eval_mae, eval_rmse from pyro.infer.mcmc.util import summary -from pyro.ops.tensor_utils import convolve - from util import DATA, get_filename fmt = '%(process)d %(message)s' @@ -66,28 +66,26 @@ def load_data(args): cum_cases = torch.tensor(cum_cases, dtype=torch.get_default_dtype()).T.contiguous() cum_deaths = torch.tensor(cum_deaths, dtype=torch.get_default_dtype()).T.contiguous() assert cum_cases.shape == cum_deaths.shape - logging.info(f"Data shape = {tuple(cum_cases.shape)}, {tuple(cum_deaths.shape)}") - - # Convert from cumulative to difference data, and convolve to ensure positivity. - T = len(cum_cases) - for window in range(1, 100): - kernel = torch.ones(window) / window - smooth_cases = convolve(cum_cases.T, kernel).T[:T].round() - smooth_deaths = convolve(cum_deaths.T, kernel).T[:T].round() - new_cases = smooth_cases[1:] - smooth_cases[:-1] - new_deaths = smooth_deaths[1:] - smooth_deaths[:-1] - if (new_cases >= 0).all() and (new_deaths >= 0).all(): - break - logging.info(f"window = {window}, shape = {tuple(new_cases.shape)}") + start_date = datetime.datetime.strptime(cum_cases_df.columns[11], "%m/%d/%y") + + # Convert from cumulative to difference data, and clamp to ensure positivity. + new_cases = (cum_cases[1:] - cum_cases[:-1]).clamp(min=0) + new_deaths = (cum_deaths[1:] - cum_deaths[:-1]).clamp(min=0) + start_date += datetime.timedelta(days=1) # Truncate and select a single county. - new_cases = new_cases[args.truncate:, args.county].contiguous() - new_deaths = new_deaths[args.truncate:, args.county].contiguous() + truncate = (datetime.datetime.strptime(args.start_date, "%m/%d/%y") - start_date).days + assert truncate > 0, "start date is too early" + new_cases = new_cases[truncate:, args.county].contiguous() + new_deaths = new_deaths[truncate:, args.county].contiguous() population = list(counties.values())[args.county] + start_date += datetime.timedelta(days=truncate) + logging.info(f"Data shape = {tuple(new_cases.shape)}") return {"population": population, "new_cases": new_cases, - "new_deaths": new_deaths} + "new_deaths": new_deaths, + "start_date": start_date} class Model(CompartmentalModel): @@ -111,19 +109,23 @@ def global_model(self): rho = pyro.sample("rho", dist.Beta(10, 10)) # About 50% response rate. mu = pyro.sample("mu", dist.Beta(2, 100)) # About 2% mortality rate. drift = pyro.sample("drift", dist.LogNormal(-3, 1.)) - od = pyro.sample("od", dist.Beta(2, 6)) - return R0, external_rate, tau_e, tau_i, rho, mu, drift, od + # Assume observations are overdispersed and zero-inflated. + # od = pyro.sample("od", dist.Beta(2, 6)) + od = 0.6 # FIXME fix overdispersion inference. + zi = pyro.sample("zi", dist.Beta(2, 6)) + + return R0, external_rate, tau_e, tau_i, rho, mu, drift, od, zi def initialize(self, params): - R0, external_rate, tau_e, tau_i, rho, mu, drift, od = params + R0, external_rate, tau_e, tau_i, rho, mu, drift, od, zi = params # Start with no local infections and close to basic reproductive number. return {"S": self.population, "E": 0, "I": 0, "R_factor": torch.tensor(0.98)} def transition(self, params, state, t): - R0, external_rate, tau_e, tau_i, rho, mu, drift, od = params + R0, external_rate, tau_e, tau_i, rho, mu, drift, od, zi = params # Assume effective reproductive number Rt varies in time. sigmoid = torch.distributions.transforms.SigmoidTransform() @@ -157,10 +159,14 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration pyro.sample("new_cases_{}".format(t), - binomial_dist(S2E, rho, overdispersion=od), + dist.ZeroInflatedDistribution( + zi, binomial_dist(S2E, (rho / (1 - zi)).clamp(max=1), + overdispersion=od)), obs=self.new_cases[t] if t_is_observed else None) pyro.sample("new_deaths_{}".format(t), - binomial_dist(I2R, mu, overdispersion=od), + dist.ZeroInflatedDistribution( + zi, binomial_dist(I2R, (mu / (1 - zi)).clamp(max=1), + overdispersion=od)), obs=self.new_deaths[t] if t_is_observed else None) @@ -197,10 +203,11 @@ def infer_svi(args, model): num_samples=args.num_samples, num_steps=args.svi_steps, num_particles=args.svi_particles, + guide_rank=args.guide_rank, + init_scale=args.init_scale, learning_rate=args.learning_rate, learning_rate_decay=args.learning_rate_decay, betas=args.betas, - init_scale=args.init_scale, jit=args.jit) return {"loss_initial": losses[0], "loss_final": losses[-1]} @@ -211,41 +218,67 @@ def predict(args, model, truth): if args.plot: import matplotlib.pyplot as plt - fig, axes = plt.subplots(3, 1, figsize=(6, 8), sharex=True) + import matplotlib.dates as mdates + fig, axes = plt.subplots(4, 1, figsize=(8, 8), sharex=True) + shelter_in_place = datetime.datetime.strptime("3/16/20", "%m/%d/%y") + axes[-1].text(shelter_in_place + datetime.timedelta(days=1), 0.2, + "shelter in place") + for ax in axes: + ax.axvline(shelter_in_place, color="gray", linestyle=":", lw=1) + ax.axvline(truth["start_date"] + datetime.timedelta(days=model.duration), + color="gray", lw=1) + axes[0].set_title("{}, population {}".format( + list(counties)[args.county], truth["population"])) + time = np.array([truth["start_date"] + datetime.timedelta(days=t) + for t in range(model.duration + args.forecast)]) # Plot forecasted series. + num_samples = samples["R0"].size(0) for name, ax in zip(["new_cases", "new_deaths"], axes): pred = samples[name][..., model.duration:] - time = torch.arange(model.duration + args.forecast) median = pred.median(dim=0).values - p05 = pred.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values - p95 = pred.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values + p05 = pred.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values + p95 = pred.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values ax.fill_between(time[model.duration:], p05, p95, color="red", alpha=0.3, label="90% CI") ax.plot(time[model.duration:], median, "r-", label="median") ax.plot(time, truth[name], "k--", label="truth") - ax.axvline(model.duration - 0.5, color="gray", lw=1) ax.set_yscale("log") ax.set_ylim(1, None) - ax.set_ylabel(f"{name} / day") + ax.set_ylabel(f"{name.replace('_', ' ')} / day") ax.legend(loc="upper left") + # Plot the latent time series. + ax = axes[2] + for name, color in zip(["E", "I"], ["red", "blue"]): + value = samples[name] + median = value.median(dim=0).values + p05 = value.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values + p95 = value.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values + ax.fill_between(time, p05, p95, color=color, alpha=0.3) + ax.plot(time, median, color=color, label=name) + ax.set_yscale("log") + ax.set_ylim(1, None) + ax.set_ylabel("# people") + ax.legend(loc="best") + # Plot Rt time series. Rt = samples["Rt"] median = Rt.median(dim=0).values - p05 = Rt.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values - p95 = Rt.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values - axes[2].fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") - axes[2].plot(time, median, "r-", label="median") - axes[2].axvline(model.duration - 0.5, color="gray", lw=1) - axes[2].set_ylim(0, None) - axes[2].set_ylabel("Rt") - axes[2].legend(loc="best") - - axes[-1].set_xlim(0, len(time) - 1) - axes[-1].set_xlabel("day") - axes[0].set_title("{}, population {}".format( - list(counties)[args.county], truth["population"])) + p05 = Rt.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values + p95 = Rt.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values + ax = axes[3] + ax.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") + ax.plot(time, median, "r-", label="median") + ax.set_ylim(0, None) + ax.set_ylabel("Rt") + ax.legend(loc="best") + + ax.set_xlim(time[0], time[-1]) + locator = mdates.AutoDateLocator(minticks=5, maxticks=15) + formatter = mdates.ConciseDateFormatter(locator) + ax.xaxis.set_major_locator(locator) + ax.xaxis.set_major_formatter(formatter) plt.tight_layout() plt.subplots_adjust(hspace=0) @@ -277,7 +310,7 @@ def evaluate(args, truth, model, samples): std = value.std().item() logging.info(f"{name} = {mean:0.3g} \u00B1 {std:0.3g}") - if args.plot: + if args.plot and args.infer == "mcmc": # Plot pairwise joint distributions for selected variables. import matplotlib.pyplot as plt N = len(covariates) @@ -339,7 +372,7 @@ def __init__(self): super().__init__(description="CompartmentalModel experiments") self.add_argument("--county", default=0, type=int, help="which SF Bay Area county, 0-8") - self.add_argument("--truncate", default=30, type=int) + self.add_argument("--start-date", default="2/1/20") self.add_argument("--forecast", default=14, type=int) self.add_argument("--recovery-time", default=14.0, type=float) self.add_argument("--incubation-time", default=5.5, type=float) @@ -347,6 +380,8 @@ def __init__(self): self.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer") self.add_argument("--svi", action="store_const", const="svi", dest="infer") self.add_argument("--haar-full-mass", default=10, type=int) + self.add_argument("--guide-rank", type=int) + self.add_argument("--init-scale", default=0.01, type=float) self.add_argument("--num-samples", default=200, type=int) self.add_argument("--smc-particles", default=1024, type=int) self.add_argument("--svi-steps", default=5000, type=int) @@ -354,7 +389,6 @@ def __init__(self): self.add_argument("--learning-rate", default=0.1, type=float) self.add_argument("--learning-rate-decay", default=0.01, type=float) self.add_argument("--betas", default="0.8,0.99") - self.add_argument("--init-scale", default=0.1, type=float) self.add_argument("--warmup-steps", type=int) self.add_argument("--num-chains", default=2, type=int) self.add_argument("--max-tree-depth", default=5, type=int) From 8381b825fd2551d35ae97a2bc3efdbd5424b7dba Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 22 Jun 2020 13:03:49 -0700 Subject: [PATCH 22/26] Add intervention date --- 2020-06-compartmental/uni_real.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/2020-06-compartmental/uni_real.py b/2020-06-compartmental/uni_real.py index 7222104..1498da9 100644 --- a/2020-06-compartmental/uni_real.py +++ b/2020-06-compartmental/uni_real.py @@ -74,7 +74,7 @@ def load_data(args): start_date += datetime.timedelta(days=1) # Truncate and select a single county. - truncate = (datetime.datetime.strptime(args.start_date, "%m/%d/%y") - start_date).days + truncate = (args.start_date - start_date).days assert truncate > 0, "start date is too early" new_cases = new_cases[truncate:, args.county].contiguous() new_deaths = new_deaths[truncate:, args.county].contiguous() @@ -101,6 +101,11 @@ def __init__(self, args, population, new_cases, new_deaths): self.new_cases = new_cases self.new_deaths = new_deaths + # Intervene via a step function. + t1 = (args.intervene_date - args.start_date).days + t2 = self.duration + args.forecast + self.intervene = torch.cat([torch.zeros(t1), torch.ones(t2 - t1)]) + def global_model(self): tau_e = self.incubation_time tau_i = self.recovery_time @@ -127,6 +132,9 @@ def initialize(self, params): def transition(self, params, state, t): R0, external_rate, tau_e, tau_i, rho, mu, drift, od, zi = params + # Assume drift is 4x larger during various interventions. + drift = drift * (0.25 + 0.75 * self.intervene[t]) + # Assume effective reproductive number Rt varies in time. sigmoid = torch.distributions.transforms.SigmoidTransform() R_factor = pyro.sample("R_factor_{}".format(t), @@ -373,6 +381,7 @@ def __init__(self): self.add_argument("--county", default=0, type=int, help="which SF Bay Area county, 0-8") self.add_argument("--start-date", default="2/1/20") + self.add_argument("--intervene-date", default="3/1/20") self.add_argument("--forecast", default=14, type=int) self.add_argument("--recovery-time", default=14.0, type=float) self.add_argument("--incubation-time", default=5.5, type=float) @@ -403,10 +412,20 @@ def __init__(self): def parse_args(self, *args, **kwargs): args = super().parse_args(*args, **kwargs) - args.betas = tuple(map(float, args.betas.split(","))) + assert args.forecast > 0 + + args.betas = tuple(map(float, args.betas.split(","))) + + # Parse dates. + for name, value in args.__dict__.items(): + if name.endswith("_date"): + value = datetime.datetime.strptime(value, "%m/%d/%y") + setattr(args, name, value) + if args.warmup_steps is None: args.warmup_steps = args.num_samples + if args.double: if args.cuda: torch.set_default_tensor_type(torch.cuda.DoubleTensor) @@ -414,6 +433,7 @@ def parse_args(self, *args, **kwargs): torch.set_default_dtype(torch.float64) elif args.cuda: torch.set_default_tensor_type(torch.cuda.FloatTensor) + return args From 38680b2699744576acb9a7af72829bf8dca336a7 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 23 Jun 2020 17:58:01 -0700 Subject: [PATCH 23/26] Add more counties; convolve input --- 2020-06-compartmental/uni_real.py | 110 ++++++++++++++++-------------- 1 file changed, 59 insertions(+), 51 deletions(-) diff --git a/2020-06-compartmental/uni_real.py b/2020-06-compartmental/uni_real.py index 1498da9..d4877dc 100644 --- a/2020-06-compartmental/uni_real.py +++ b/2020-06-compartmental/uni_real.py @@ -15,12 +15,14 @@ import numpy as np import pandas as pd import torch +from torch.nn.functional import pad import pyro import pyro.distributions as dist from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist from pyro.contrib.forecast.evaluate import eval_crps, eval_mae, eval_rmse from pyro.infer.mcmc.util import summary +from pyro.ops.tensor_utils import convolve from util import DATA, get_filename fmt = '%(process)d %(message)s' @@ -28,18 +30,31 @@ logging.basicConfig(format=fmt, level=logging.INFO) -# Bay area county populations. +# Misc California county populations as of early 2020. counties = OrderedDict([ - ("Santa Clara", 1763000), - ("Alameda", 1495000), - ("Contra Costa", 1038000), - ("San Francisco", 871000), - ("San Mateo", 712000), - ("Sonoma", 479000), - ("Solano", 412000), - ("Marin", 251000), - ("Napa", 135000), + # Bay Area counties. + ("Santa Clara", 1.763e6), + ("Alameda", 1.495e6), + ("Contra Costa", 1.038e6), + ("San Francisco", 871e3), + ("San Mateo", 712e3), + ("Sonoma", 479e3), + ("Solano", 412e3), + ("Marin", 251e3), + ("Napa", 135e3), + # Misc non Bay Area counties. + ("Los Angeles", 10.04e6), + ("Riverside", 2.471e6), + ("San Diego", 3.338e6), + ("Orange", 3.176e6), + ("San Bernardino", 2.18e6), + ("Imperial", 181e3), + ("Kern", 900e3), + ("Fresno", 999e3), + ("Tulare", 466e3), + ("Santa Barbara", 446e3), ]) +counties = OrderedDict((k, int(v)) for k, v in counties.items()) def load_df(basename): @@ -56,31 +71,34 @@ def load_data(args): cum_deaths_df = load_df("time_series_covid19_deaths_US.csv") # Convert to torch.Tensor. - cum_cases = [] - cum_deaths = [] - for county in counties: - i = list(cum_cases_df["Admin2"]).index(county) - cum_cases.append(cum_cases_df.iloc[i, 11:]) - i = list(cum_deaths_df["Admin2"]).index(county) - cum_deaths.append(cum_deaths_df.iloc[i, 12:]) - cum_cases = torch.tensor(cum_cases, dtype=torch.get_default_dtype()).T.contiguous() - cum_deaths = torch.tensor(cum_deaths, dtype=torch.get_default_dtype()).T.contiguous() + county = list(counties)[args.county] + population = counties[county] + i = list(cum_cases_df["Admin2"]).index(county) + cum_cases = cum_cases_df.iloc[i, 11:] + i = list(cum_deaths_df["Admin2"]).index(county) + cum_deaths = cum_deaths_df.iloc[i, 12:] + cum_cases = torch.tensor(cum_cases, dtype=torch.get_default_dtype()).contiguous() + cum_deaths = torch.tensor(cum_deaths, dtype=torch.get_default_dtype()).contiguous() assert cum_cases.shape == cum_deaths.shape start_date = datetime.datetime.strptime(cum_cases_df.columns[11], "%m/%d/%y") + # Distribute reported cases and deaths among previous few days. + if args.report_lag: + kernel = torch.ones(args.report_lag) / args.report_lag + cum_cases = convolve(cum_cases, kernel, mode="valid").round() + cum_deaths = convolve(cum_deaths, kernel, mode="valid").round() + # Convert from cumulative to difference data, and clamp to ensure positivity. - new_cases = (cum_cases[1:] - cum_cases[:-1]).clamp(min=0) - new_deaths = (cum_deaths[1:] - cum_deaths[:-1]).clamp(min=0) - start_date += datetime.timedelta(days=1) + new_cases = (cum_cases - pad(cum_cases[:-1], (1, 0), value=0)).clamp(min=0) + new_deaths = (cum_deaths - pad(cum_deaths[:-1], (1, 0), value=0)).clamp(min=0) - # Truncate and select a single county. + # Truncate. truncate = (args.start_date - start_date).days assert truncate > 0, "start date is too early" - new_cases = new_cases[truncate:, args.county].contiguous() - new_deaths = new_deaths[truncate:, args.county].contiguous() - population = list(counties.values())[args.county] + new_cases = new_cases[truncate:].contiguous() + new_deaths = new_deaths[truncate:].contiguous() start_date += datetime.timedelta(days=truncate) - logging.info(f"Data shape = {tuple(new_cases.shape)}") + logging.info(f"{county} data shape = {tuple(new_cases.shape)}") return {"population": population, "new_cases": new_cases, @@ -114,25 +132,19 @@ def global_model(self): rho = pyro.sample("rho", dist.Beta(10, 10)) # About 50% response rate. mu = pyro.sample("mu", dist.Beta(2, 100)) # About 2% mortality rate. drift = pyro.sample("drift", dist.LogNormal(-3, 1.)) - - # Assume observations are overdispersed and zero-inflated. - # od = pyro.sample("od", dist.Beta(2, 6)) - od = 0.6 # FIXME fix overdispersion inference. - zi = pyro.sample("zi", dist.Beta(2, 6)) - - return R0, external_rate, tau_e, tau_i, rho, mu, drift, od, zi + od1 = pyro.sample("od1", dist.Uniform(0, 2)) + od2 = pyro.sample("od2", dist.Uniform(0, 2)) + return R0, external_rate, tau_e, tau_i, rho, mu, drift, od1, od2 def initialize(self, params): - R0, external_rate, tau_e, tau_i, rho, mu, drift, od, zi = params - # Start with no local infections and close to basic reproductive number. return {"S": self.population, "E": 0, "I": 0, "R_factor": torch.tensor(0.98)} def transition(self, params, state, t): - R0, external_rate, tau_e, tau_i, rho, mu, drift, od, zi = params + R0, external_rate, tau_e, tau_i, rho, mu, drift, od1, od2 = params - # Assume drift is 4x larger during various interventions. + # Assume drift is 4x larger after various interventions begin. drift = drift * (0.25 + 0.75 * self.intervene[t]) # Assume effective reproductive number Rt varies in time. @@ -150,13 +162,11 @@ def transition(self, params, state, t): num_susceptible=state["S"], num_infectious=state["I"] + I_external, population=self.population, - overdispersion=od)) + overdispersion=od1)) E2I = pyro.sample("E2I_{}".format(t), - binomial_dist(state["E"], 1 / tau_e, - overdispersion=od)) + binomial_dist(state["E"], 1 / tau_e, overdispersion=od1)) I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau_i, - overdispersion=od)) + binomial_dist(state["I"], 1 / tau_i, overdispersion=od1)) # Update compartments and heterogeneous variables. state["S"] = state["S"] - S2E @@ -167,14 +177,10 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration pyro.sample("new_cases_{}".format(t), - dist.ZeroInflatedDistribution( - zi, binomial_dist(S2E, (rho / (1 - zi)).clamp(max=1), - overdispersion=od)), + binomial_dist(S2E, rho, overdispersion=od2), obs=self.new_cases[t] if t_is_observed else None) pyro.sample("new_deaths_{}".format(t), - dist.ZeroInflatedDistribution( - zi, binomial_dist(I2R, (mu / (1 - zi)).clamp(max=1), - overdispersion=od)), + binomial_dist(I2R, mu, overdispersion=od2), obs=self.new_deaths[t] if t_is_observed else None) @@ -232,9 +238,9 @@ def predict(args, model, truth): axes[-1].text(shelter_in_place + datetime.timedelta(days=1), 0.2, "shelter in place") for ax in axes: - ax.axvline(shelter_in_place, color="gray", linestyle=":", lw=1) + ax.axvline(shelter_in_place, color="black", linestyle=":", lw=1, alpha=0.3) ax.axvline(truth["start_date"] + datetime.timedelta(days=model.duration), - color="gray", lw=1) + color="black", lw=1, alpha=0.3) axes[0].set_title("{}, population {}".format( list(counties)[args.county], truth["population"])) time = np.array([truth["start_date"] + datetime.timedelta(days=t) @@ -278,6 +284,7 @@ def predict(args, model, truth): ax = axes[3] ax.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") ax.plot(time, median, "r-", label="median") + ax.axhline(1, color="black", linestyle=":", lw=1, alpha=0.3) ax.set_ylim(0, None) ax.set_ylabel("Rt") ax.legend(loc="best") @@ -382,6 +389,7 @@ def __init__(self): help="which SF Bay Area county, 0-8") self.add_argument("--start-date", default="2/1/20") self.add_argument("--intervene-date", default="3/1/20") + self.add_argument("--report-lag", type=int, default=5) self.add_argument("--forecast", default=14, type=int) self.add_argument("--recovery-time", default=14.0, type=float) self.add_argument("--incubation-time", default=5.5, type=float) From 3d9f4417c5001dac2f89b08ce3dc2e0b58076b44 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 24 Jun 2020 11:47:51 -0700 Subject: [PATCH 24/26] Combine od1=od2=od --- 2020-06-compartmental/uni_real.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/2020-06-compartmental/uni_real.py b/2020-06-compartmental/uni_real.py index d4877dc..a5599a2 100644 --- a/2020-06-compartmental/uni_real.py +++ b/2020-06-compartmental/uni_real.py @@ -132,9 +132,8 @@ def global_model(self): rho = pyro.sample("rho", dist.Beta(10, 10)) # About 50% response rate. mu = pyro.sample("mu", dist.Beta(2, 100)) # About 2% mortality rate. drift = pyro.sample("drift", dist.LogNormal(-3, 1.)) - od1 = pyro.sample("od1", dist.Uniform(0, 2)) - od2 = pyro.sample("od2", dist.Uniform(0, 2)) - return R0, external_rate, tau_e, tau_i, rho, mu, drift, od1, od2 + od = pyro.sample("od", dist.Beta(1, 3)) + return R0, external_rate, tau_e, tau_i, rho, mu, drift, od def initialize(self, params): # Start with no local infections and close to basic reproductive number. @@ -142,7 +141,7 @@ def initialize(self, params): "R_factor": torch.tensor(0.98)} def transition(self, params, state, t): - R0, external_rate, tau_e, tau_i, rho, mu, drift, od1, od2 = params + R0, external_rate, tau_e, tau_i, rho, mu, drift, od = params # Assume drift is 4x larger after various interventions begin. drift = drift * (0.25 + 0.75 * self.intervene[t]) @@ -162,11 +161,11 @@ def transition(self, params, state, t): num_susceptible=state["S"], num_infectious=state["I"] + I_external, population=self.population, - overdispersion=od1)) + overdispersion=od)) E2I = pyro.sample("E2I_{}".format(t), - binomial_dist(state["E"], 1 / tau_e, overdispersion=od1)) + binomial_dist(state["E"], 1 / tau_e, overdispersion=od)) I2R = pyro.sample("I2R_{}".format(t), - binomial_dist(state["I"], 1 / tau_i, overdispersion=od1)) + binomial_dist(state["I"], 1 / tau_i, overdispersion=od)) # Update compartments and heterogeneous variables. state["S"] = state["S"] - S2E @@ -177,13 +176,12 @@ def transition(self, params, state, t): # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration pyro.sample("new_cases_{}".format(t), - binomial_dist(S2E, rho, overdispersion=od2), + binomial_dist(S2E, rho, overdispersion=od), obs=self.new_cases[t] if t_is_observed else None) pyro.sample("new_deaths_{}".format(t), - binomial_dist(I2R, mu, overdispersion=od2), + binomial_dist(I2R, mu, overdispersion=od), obs=self.new_deaths[t] if t_is_observed else None) - def _item(x): if isinstance(x, torch.Tensor): x = x.reshape(-1).median().item() @@ -258,7 +256,7 @@ def predict(args, model, truth): ax.plot(time[model.duration:], median, "r-", label="median") ax.plot(time, truth[name], "k--", label="truth") ax.set_yscale("log") - ax.set_ylim(1, None) + ax.set_ylim(0.5, None) ax.set_ylabel(f"{name.replace('_', ' ')} / day") ax.legend(loc="upper left") @@ -272,7 +270,7 @@ def predict(args, model, truth): ax.fill_between(time, p05, p95, color=color, alpha=0.3) ax.plot(time, median, color=color, label=name) ax.set_yscale("log") - ax.set_ylim(1, None) + ax.set_ylim(0.5, None) ax.set_ylabel("# people") ax.legend(loc="best") From 75dc4c7ad45dd7766d1a9aa8508b76d4e2fad7a4 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 24 Jun 2020 13:48:19 -0700 Subject: [PATCH 25/26] Allow heterogeneous response rate --- 2020-06-compartmental/uni_real.py | 171 ++++++++++++++++-------------- 1 file changed, 92 insertions(+), 79 deletions(-) diff --git a/2020-06-compartmental/uni_real.py b/2020-06-compartmental/uni_real.py index a5599a2..d1d949d 100644 --- a/2020-06-compartmental/uni_real.py +++ b/2020-06-compartmental/uni_real.py @@ -4,6 +4,7 @@ import argparse import datetime import logging +import math import os import pickle import resource @@ -29,7 +30,6 @@ logging.getLogger("pyro").handlers[0].setFormatter(logging.Formatter(fmt)) logging.basicConfig(format=fmt, level=logging.INFO) - # Misc California county populations as of early 2020. counties = OrderedDict([ # Bay Area counties. @@ -106,6 +106,9 @@ def load_data(args): "start_date": start_date} +sigmoid = torch.distributions.transforms.SigmoidTransform() + + class Model(CompartmentalModel): def __init__(self, args, population, new_cases, new_deaths): assert new_cases.dim() == 1 @@ -129,30 +132,37 @@ def global_model(self): tau_i = self.recovery_time R0 = pyro.sample("R0", dist.LogNormal(1., 0.5)) # Weak prior. external_rate = pyro.sample("external_rate", dist.LogNormal(-2, 2)) - rho = pyro.sample("rho", dist.Beta(10, 10)) # About 50% response rate. + rho0 = pyro.sample("rho0", dist.Beta(10, 10)) # About 50% response rate. mu = pyro.sample("mu", dist.Beta(2, 100)) # About 2% mortality rate. - drift = pyro.sample("drift", dist.LogNormal(-3, 1.)) - od = pyro.sample("od", dist.Beta(1, 3)) - return R0, external_rate, tau_e, tau_i, rho, mu, drift, od + R_drift = pyro.sample("R_drift", dist.LogNormal(math.log(0.1), 1.)) + rho_drift = pyro.sample("rho_drift", dist.LogNormal(math.log(0.01), 1.)) + od = pyro.sample("od", dist.Beta(2, 6)) + return R0, external_rate, tau_e, tau_i, rho0, mu, R_drift, rho_drift, od def initialize(self, params): - # Start with no local infections and close to basic reproductive number. + R0, external_rate, tau_e, tau_i, rho0, mu, R_drift, rho_drift, od = params + + # Start with no local infections and initial Brownian motion. return {"S": self.population, "E": 0, "I": 0, - "R_factor": torch.tensor(0.98)} + "R_motion": sigmoid.inv(torch.tensor(0.98)), + "rho_motion": torch.tensor(0.)} def transition(self, params, state, t): - R0, external_rate, tau_e, tau_i, rho, mu, drift, od = params + R0, external_rate, tau_e, tau_i, rho0, mu, R_drift, rho_drift, od = params # Assume drift is 4x larger after various interventions begin. - drift = drift * (0.25 + 0.75 * self.intervene[t]) - - # Assume effective reproductive number Rt varies in time. - sigmoid = torch.distributions.transforms.SigmoidTransform() - R_factor = pyro.sample("R_factor_{}".format(t), - dist.TransformedDistribution( - dist.Normal(sigmoid.inv(state["R_factor"]), drift), - sigmoid)) - Rt = pyro.deterministic("Rt_{}".format(t), R0 * R_factor, event_dim=0) + R_drift = R_drift * (0.25 + 0.75 * self.intervene[t]) + rho_drift = rho_drift * (0.25 + 0.75 * self.intervene[t]) + + # Assume reproductive number Rt and response rate rho vary in time. + R_motion = pyro.sample("R_motion_{}".format(t), + dist.Normal(state["R_motion"], R_drift)) + rho_motion = pyro.sample("rho_motion_{}".format(t), + dist.Normal(state["rho_motion"], rho_drift)) + Rt = pyro.deterministic("Rt_{}".format(t), + R0 * sigmoid(R_motion), event_dim=0) + rho = pyro.deterministic("rho_{}".format(t), + sigmoid(sigmoid.inv(rho0) + rho_motion), event_dim=0) I_external = external_rate * tau_i / Rt # Sample flows between compartments. @@ -171,7 +181,8 @@ def transition(self, params, state, t): state["S"] = state["S"] - S2E state["E"] = state["E"] + S2E - E2I state["I"] = state["I"] + E2I - I2R - state["R_factor"] = R_factor + state["R_motion"] = R_motion + state["rho_motion"] = rho_motion # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration @@ -227,73 +238,75 @@ def infer_svi(args, model): def predict(args, model, truth): samples = model.predict(forecast=args.forecast) - - if args.plot: - import matplotlib.pyplot as plt - import matplotlib.dates as mdates - fig, axes = plt.subplots(4, 1, figsize=(8, 8), sharex=True) - shelter_in_place = datetime.datetime.strptime("3/16/20", "%m/%d/%y") - axes[-1].text(shelter_in_place + datetime.timedelta(days=1), 0.2, - "shelter in place") - for ax in axes: - ax.axvline(shelter_in_place, color="black", linestyle=":", lw=1, alpha=0.3) - ax.axvline(truth["start_date"] + datetime.timedelta(days=model.duration), - color="black", lw=1, alpha=0.3) - axes[0].set_title("{}, population {}".format( - list(counties)[args.county], truth["population"])) - time = np.array([truth["start_date"] + datetime.timedelta(days=t) - for t in range(model.duration + args.forecast)]) - - # Plot forecasted series. - num_samples = samples["R0"].size(0) - for name, ax in zip(["new_cases", "new_deaths"], axes): - pred = samples[name][..., model.duration:] - median = pred.median(dim=0).values - p05 = pred.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values - p95 = pred.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values - ax.fill_between(time[model.duration:], p05, p95, color="red", alpha=0.3, - label="90% CI") - ax.plot(time[model.duration:], median, "r-", label="median") - ax.plot(time, truth[name], "k--", label="truth") - ax.set_yscale("log") - ax.set_ylim(0.5, None) - ax.set_ylabel(f"{name.replace('_', ' ')} / day") - ax.legend(loc="upper left") - - # Plot the latent time series. - ax = axes[2] - for name, color in zip(["E", "I"], ["red", "blue"]): - value = samples[name] - median = value.median(dim=0).values - p05 = value.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values - p95 = value.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values - ax.fill_between(time, p05, p95, color=color, alpha=0.3) - ax.plot(time, median, color=color, label=name) + if not args.plot: + return samples + + import matplotlib.pyplot as plt + import matplotlib.dates as mdates + fig, axes = plt.subplots(5, 1, figsize=(8, 8), sharex=True) + shelter_in_place = datetime.datetime.strptime("3/16/20", "%m/%d/%y") + axes[-1].text(shelter_in_place + datetime.timedelta(days=1), -0.25, + "shelter in place", horizontalalignment='center') + for ax in axes: + ax.axvline(shelter_in_place, color="black", linestyle=":", lw=1, alpha=0.3) + ax.axvline(truth["start_date"] + datetime.timedelta(days=model.duration), + color="black", lw=1, alpha=0.3) + axes[0].set_title("{}, population {}".format( + list(counties)[args.county], truth["population"])) + time = np.array([truth["start_date"] + datetime.timedelta(days=t) + for t in range(model.duration + args.forecast)]) + + # Plot forecasted series. + num_samples = samples["R0"].size(0) + for name, ax in zip(["new_cases", "new_deaths"], axes): + pred = samples[name][..., model.duration:] + median = pred.median(dim=0).values + p05 = pred.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values + p95 = pred.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values + ax.fill_between(time[model.duration:], p05, p95, color="red", alpha=0.3, + label="90% CI") + ax.plot(time[model.duration:], median, "r-", label="median") + ax.plot(time, truth[name], "k--", label="truth") ax.set_yscale("log") ax.set_ylim(0.5, None) - ax.set_ylabel("# people") - ax.legend(loc="best") - - # Plot Rt time series. - Rt = samples["Rt"] - median = Rt.median(dim=0).values - p05 = Rt.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values - p95 = Rt.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values - ax = axes[3] + ax.set_ylabel(f"{name.replace('_', ' ')} / day") + ax.legend(loc="upper left") + + # Plot the latent time series. + ax = axes[2] + for name, color in zip(["E", "I"], ["red", "blue"]): + value = samples[name] + median = value.median(dim=0).values + p05 = value.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values + p95 = value.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values + ax.fill_between(time, p05, p95, color=color, alpha=0.3) + ax.plot(time, median, color=color, label=name) + ax.set_yscale("log") + ax.set_ylim(0.5, None) + ax.set_ylabel("# people") + ax.legend(loc="best") + + # Plot parameter time series. + for name, ax in zip(["Rt", "rho"], axes[3:]): + value = samples[name] + median = value.median(dim=0).values + p05 = value.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values + p95 = value.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values ax.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") ax.plot(time, median, "r-", label="median") ax.axhline(1, color="black", linestyle=":", lw=1, alpha=0.3) - ax.set_ylim(0, None) - ax.set_ylabel("Rt") + ax.set_ylabel(name) ax.legend(loc="best") + axes[3].set_ylim(0, None) + axes[4].set_ylim(0, 1) - ax.set_xlim(time[0], time[-1]) - locator = mdates.AutoDateLocator(minticks=5, maxticks=15) - formatter = mdates.ConciseDateFormatter(locator) - ax.xaxis.set_major_locator(locator) - ax.xaxis.set_major_formatter(formatter) - plt.tight_layout() - plt.subplots_adjust(hspace=0) + ax.set_xlim(time[0], time[-1]) + locator = mdates.AutoDateLocator(minticks=5, maxticks=15) + formatter = mdates.ConciseDateFormatter(locator) + ax.xaxis.set_major_locator(locator) + ax.xaxis.set_major_formatter(formatter) + plt.tight_layout() + plt.subplots_adjust(hspace=0) return samples @@ -430,7 +443,7 @@ def parse_args(self, *args, **kwargs): setattr(args, name, value) if args.warmup_steps is None: - args.warmup_steps = args.num_samples + args.warmup_steps = int(round(0.4 * args.num_samples)) if args.double: if args.cuda: From 6dec724168548b3faba9d1a43662bd09bff39fb8 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 24 Jun 2020 14:18:09 -0700 Subject: [PATCH 26/26] Support synthetic data in uni_real.py --- 2020-06-compartmental/uni_real.py | 34 +++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/2020-06-compartmental/uni_real.py b/2020-06-compartmental/uni_real.py index d1d949d..622ddca 100644 --- a/2020-06-compartmental/uni_real.py +++ b/2020-06-compartmental/uni_real.py @@ -111,8 +111,7 @@ def load_data(args): class Model(CompartmentalModel): def __init__(self, args, population, new_cases, new_deaths): - assert new_cases.dim() == 1 - assert new_cases.shape == new_deaths.shape + assert len(new_cases) == len(new_deaths) duration = len(new_cases) compartments = ("S", "E", "I") # R is implicit. super().__init__(compartments, duration, population) @@ -144,7 +143,7 @@ def initialize(self, params): # Start with no local infections and initial Brownian motion. return {"S": self.population, "E": 0, "I": 0, - "R_motion": sigmoid.inv(torch.tensor(0.98)), + "R_motion": sigmoid.inv(torch.tensor(0.95)), "rho_motion": torch.tensor(0.)} def transition(self, params, state, t): @@ -193,6 +192,7 @@ def transition(self, params, state, t): binomial_dist(I2R, mu, overdispersion=od), obs=self.new_deaths[t] if t_is_observed else None) + def _item(x): if isinstance(x, torch.Tensor): x = x.reshape(-1).median().item() @@ -280,7 +280,9 @@ def predict(args, model, truth): p05 = value.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values p95 = value.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values ax.fill_between(time, p05, p95, color=color, alpha=0.3) - ax.plot(time, median, color=color, label=name) + ax.plot(time, median, color=color, label=name, lw=1) + if name in truth: + ax.plot(time, truth[name], color=color, linestyle="--") ax.set_yscale("log") ax.set_ylim(0.5, None) ax.set_ylabel("# people") @@ -294,6 +296,8 @@ def predict(args, model, truth): p95 = value.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values ax.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI") ax.plot(time, median, "r-", label="median") + if name in truth: + ax.plot(time, truth[name], "k--", label="truth") ax.axhline(1, color="black", linestyle=":", lw=1, alpha=0.3) ax.set_ylabel(name) ax.legend(loc="best") @@ -322,7 +326,7 @@ def evaluate(args, truth, model, samples): result[key]["mean"] = pred.mean().item() result[key]["std"] = pred.std(dim=0).mean().item() - if key in truth: + if key in ("new_cases", "new_deaths"): true = truth[key][..., model.duration:] for metric, fn in metrics: result[key][metric] = fn(pred, true) @@ -334,7 +338,11 @@ def evaluate(args, truth, model, samples): for name, value in covariates: mean = value.mean().item() std = value.std().item() - logging.info(f"{name} = {mean:0.3g} \u00B1 {std:0.3g}") + if name in truth: + true = truth[name] + logging.info(f"{name}: true = {true:0.3g}, pred = {mean:0.3g} \u00B1 {std:0.3g}") + else: + logging.info(f"{name} = {mean:0.3g} \u00B1 {std:0.3g}") if args.plot and args.infer == "mcmc": # Plot pairwise joint distributions for selected variables. @@ -364,6 +372,16 @@ def main(args): result = {"file": __file__, "args": args, "argv": sys.argv} truth = load_data(args) + if args.generate: + # Generate data with same population and duration as real data. + model = Model(args, truth["population"], + [None] * len(truth["new_cases"]), + [None] * len(truth["new_cases"])) + truth.update(model.generate()) + logging.info("Synthetic: {}".format(", ".join( + f"{name}={value:0.3g}" + for name, value in sorted(truth.items()) + if isinstance(value, torch.Tensor) and value.numel() == 1))) result["data"] = { "population": truth["population"], "total_cases": truth["new_cases"].sum().item(), @@ -396,8 +414,8 @@ def main(args): class Parser(argparse.ArgumentParser): def __init__(self): super().__init__(description="CompartmentalModel experiments") - self.add_argument("--county", default=0, type=int, - help="which SF Bay Area county, 0-8") + self.add_argument("--county", default=0, type=int) + self.add_argument("--generate", action="store_true") self.add_argument("--start-date", default="2/1/20") self.add_argument("--intervene-date", default="3/1/20") self.add_argument("--report-lag", type=int, default=5)