Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added demography.py module, added jointplots #25

Merged
merged 3 commits into from
Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions demography.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import msprime
import math


def onepop_constant(args):
    """Simulate a single population whose size stays fixed at Ne.

    ``args`` is a 6-tuple ``(genob, params, randomize, i, proposals, seed)``.
    ``params`` must contain exactly the keys mu, r and Ne. Each parameter is
    resolved to a number as follows:

    * ``proposals`` truthy: ``prop(i)`` if the parameter is flagged
      ``inferable``, otherwise its ``val``;
    * ``proposals`` falsy: ``rand()`` if ``randomize`` is truthy, otherwise
      its ``val``.

    Returns the result of ``msprime.simulate`` for ``genob.num_samples``
    samples over a sequence of length ``genob.seq_len``, seeded with ``seed``.

    Raises AssertionError when ``params`` does not hold exactly the
    expected keys.
    """
    genob, params, randomize, i, proposals, seed = args

    necessary_params = ["mu", "r", "Ne"]
    assert sorted(necessary_params) == sorted(
        params.keys()
    ), "Invalid combination of parameters. Needed: mu | r | Ne"

    use_proposals = bool(proposals)

    def resolve(name):
        # Turn one Parameter object into the concrete value to simulate with.
        if use_proposals:
            return params[name].prop(i) if params[name].inferable else params[name].val
        return params[name].rand() if randomize else params[name].val

    # Resolved in the order mu, r, Ne (dict comprehension preserves order).
    values = {name: resolve(name) for name in necessary_params}

    return msprime.simulate(
        sample_size=genob.num_samples,
        Ne=values["Ne"],
        length=genob.seq_len,
        mutation_rate=values["mu"],
        recombination_rate=values["r"],
        random_seed=seed,
    )


def onepop_exp(args):
    """Simulate a single population with a size step and recent exponential growth.

    Going backwards in time from the present (time 0), the events built here
    are:

    * from 0 to T2: exponential size change at per-generation rate ``growth``,
      starting from a present-day size N0 chosen so the size equals N2 at T2;
    * from T2 to T1: constant size N2;
    * older than T1: size N1 (growth rate stays at 0).

    ``args`` is a 6-tuple ``(genob, params, randomize, i, proposals, seed)``.
    ``params`` must contain exactly the keys mu, r, T1, N1, T2, N2 and growth.
    Each parameter is resolved as ``prop(i)``/``val`` (when ``proposals`` is
    truthy, depending on ``inferable``) or ``rand()``/``val`` (otherwise,
    depending on ``randomize``).

    Returns the result of ``msprime.simulate`` for ``genob.num_samples``
    samples over a sequence of length ``genob.seq_len``, seeded with ``seed``.

    Raises AssertionError when ``params`` does not hold exactly the
    expected keys.
    """
    genob, params, randomize, i, proposals, seed = args
    necessary_params = ["mu", "r", "T1", "N1", "T2", "N2", "growth"]
    assert sorted(necessary_params) == sorted(
        list(params.keys())
    ), "Invalid combination of parameters. Needed: mu | r | T1 | N1 | T2 | N2 | growth"

    if proposals:
        mu, r, T1, N1, T2, N2, growth = [
            params[p].prop(i) if params[p].inferable else params[p].val
            for p in necessary_params
        ]
    else:
        mu, r, T1, N1, T2, N2, growth = [
            params[p].rand() if randomize else params[p].val
            for p in necessary_params
        ]

    # msprime models the size t generations in the past as
    # initial_size * exp(-growth_rate * t). For the size to equal N2 at T2 we
    # therefore need N0 * exp(-growth * T2) == N2, i.e. N0 = N2 * exp(growth * T2).
    # (The previous formula divided by exp(growth * T2), which made the
    # trajectory jump discontinuously at T2.)
    N0 = N2 * math.exp(growth * T2)

    # Time is given in generations unit (t/25)
    # NOTE(review): events are listed as 0, T2, T1 — this presumably relies on
    # T2 <= T1; confirm against the parameter bounds used by the caller.
    demographic_events = [
        msprime.PopulationParametersChange(time=0, initial_size=N0, growth_rate=growth),
        msprime.PopulationParametersChange(time=T2, initial_size=N2, growth_rate=0),
        msprime.PopulationParametersChange(time=T1, initial_size=N1),
    ]

    ts = msprime.simulate(
        sample_size=genob.num_samples,
        demographic_events=demographic_events,
        length=genob.seq_len,
        mutation_rate=mu,
        recombination_rate=r,
        random_seed=seed,
    )

    return ts


def onepop_migration(args):
    """Simulate two populations joined by one mass-migration event at T1.

    Population 0 has size N1 and contributes all ``genob.num_samples``
    samples; population 1 has size N2 and contributes none. A single
    ``msprime.MassMigration`` with source 1, destination 0 and proportion
    ``mig`` is scheduled at time T1.

    ``args`` is a 6-tuple ``(genob, params, randomize, i, proposals, seed)``.
    ``params`` must contain exactly the keys mu, r, T1, N1, N2 and mig. Each
    parameter is resolved as ``prop(i)``/``val`` (when ``proposals`` is
    truthy, depending on ``inferable``) or ``rand()``/``val`` (otherwise,
    depending on ``randomize``).

    Returns the result of ``msprime.simulate`` over a sequence of length
    ``genob.seq_len``, seeded with ``seed``.

    Raises AssertionError when ``params`` does not hold exactly the
    expected keys; the message lists the keys that were received.
    """
    genob, params, randomize, i, proposals, seed = args

    necessary_params = ["mu", "r", "T1", "N1", "N2", "mig"]
    assert sorted(necessary_params) == sorted(params.keys()), (
        "Invalid combination of parameters. Needed: mu | r | T1 | N1 | N2 | mig \n"
        f"Obtained: {list(params.keys())}"
    )

    use_proposals = bool(proposals)

    def resolve(name):
        # Turn one Parameter object into the concrete value to simulate with.
        if use_proposals:
            return params[name].prop(i) if params[name].inferable else params[name].val
        return params[name].rand() if randomize else params[name].val

    # Resolved in the order of necessary_params (dict comprehension keeps it).
    values = {name: resolve(name) for name in necessary_params}

    sampled_pop = msprime.PopulationConfiguration(
        sample_size=genob.num_samples, initial_size=values["N1"]
    )
    unsampled_pop = msprime.PopulationConfiguration(
        sample_size=0, initial_size=values["N2"]
    )

    # migration from pop 1 into pop 0 (back in time)
    # NOTE(review): msprime describes MassMigration as moving lineages from
    # `source` to `dest` backwards in time. All samples start in pop 0 and no
    # migration rate is configured here, so it looks like no lineage is ever
    # in pop 1 for this event to move — confirm the intended direction.
    mig_event = msprime.MassMigration(
        time=values["T1"], source=1, destination=0, proportion=values["mig"]
    )

    return msprime.simulate(
        population_configurations=[sampled_pop, unsampled_pop],
        demographic_events=[mig_event],
        length=genob.seq_len,
        mutation_rate=values["mu"],
        recombination_rate=values["r"],
        random_seed=seed,
    )
47 changes: 23 additions & 24 deletions genobuilder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from collections import OrderedDict
import concurrent.futures
import msprime
import pickle
import argparse
import stdpopsim
Expand All @@ -11,6 +10,7 @@
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import demography as dm
from parameter import Parameter

_ex = None
Expand All @@ -24,27 +24,11 @@ def executor(p):


def do_sim(args):
genob, params, randomize, i, proposals, seed = args
rng = random.Random(seed)

if proposals:
Ne, mu, r = [
params[p].prop(i) if params[p].inferable else params[p].val
for p in ("Ne", "mu", "r")
]
else:
Ne, mu, r = [
params[p].rand() if randomize else params[p].val for p in ("Ne", "mu", "r")
]

ts = msprime.simulate(
sample_size=genob.num_samples,
Ne=Ne,
length=genob.seq_len,
mutation_rate=mu,
recombination_rate=r,
random_seed=seed,
)
seed = args[5]
genob = args[0]
rng = random.Random(seed)
ts = dm.onepop_exp(args)

return genob._resize_from_ts(ts, rng)

Expand Down Expand Up @@ -765,9 +749,24 @@ def vcf2zarr(vcf_files, pop_file, zarr_path):
args = parser.parse_args()
params_dict = OrderedDict()

params_dict["r"] = Parameter("r", 1.25e-9, 1e-10, (1e-11, 1e-7), inferable=True)
params_dict["mu"] = Parameter("mu", 1.25e-8, 1e-9, (1e-10, 1e-7), inferable=False)
params_dict["Ne"] = Parameter("Ne", 10000, 14000, (5000, 15000), inferable=True)
params_dict["r"] = Parameter("r", 2e-8, 1e-10, (1e-11, 1e-7), inferable=False)
params_dict["mu"] = Parameter("mu", 1.29e-8, 1e-9, (1e-10, 1e-7), inferable=False)
# params_dict["Ne"] = Parameter("Ne", 10000, 14000, (5000, 15000), inferable=True)

# For onepop_exp model:
params_dict["T1"] = Parameter("T1", 3000, 4000, (1500, 5000), inferable=True)
params_dict["N1"] = Parameter("N1", 10000, 20000, (1000, 30000), inferable=True)
params_dict["T2"] = Parameter("T2", 500, 1000, (100, 1500), inferable=False)
params_dict["N2"] = Parameter("N2", 5000, 20000, (1000, 20000), inferable=True)
params_dict["growth"] = Parameter("growth", 0.01, 0.02, (0, 0.05), inferable=True)

# For onepop_migration model:
"""
params_dict["T1"] = Parameter("T1", 1000, 4000, (500, 5000), inferable=False)
params_dict["N1"] = Parameter("N1", 5000, 18000, (1000, 20000), inferable=False)
params_dict["N2"] = Parameter("N2", 8000, 15000, (1000, 20000), inferable=False)
params_dict["mig"] = Parameter("mig", 0.9, 0.2, (0, 0.3), inferable=True)
"""

genob = Genobuilder(
source=args.source,
Expand Down
15 changes: 6 additions & 9 deletions genomcmcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,7 @@ def run_genomcmcgan(
"""
xtest = genob.generate_fakedata(num_reps=1000)
test_data = tf.data.Dataset.from_tensor_slices((xtest.astype("float32")))
test_data = (
test_data
.cache()
.batch(batch_size)
.prefetch(2)
)
test_data = test_data.cache().batch(batch_size).prefetch(2)
"""

print("Data simulation finished")
Expand Down Expand Up @@ -125,7 +120,7 @@ def run_genomcmcgan(

while not convergence and max_num_iters != it:

print("Starting the MCMC sampling chain")
print(f"Starting the MCMC sampling chain for iteration {it}")
start_t = time.time()

is_accepted, log_acc_rate = mcmcgan.run_chain()
Expand All @@ -134,7 +129,9 @@ def run_genomcmcgan(
# Draw traceplot and histogram of collected samples
mcmcgan.traceplot_samples(inferable_params, it)
mcmcgan.hist_samples(inferable_params, it)
mcmcgan.jointplot(it)
print(mcmcgan.samples.shape)
if mcmcgan.samples.shape[1] == 2:
mcmcgan.jointplot_samples(inferable_params, it)

for i, p in enumerate(inferable_params):
p.proposals = mcmcgan.samples[:, i].numpy()
Expand Down Expand Up @@ -176,7 +173,7 @@ def run_genomcmcgan(
print(f"A single iteration of the MCMC-GAN took {t} seconds")

print(f"The estimates obtained after {it} iterations are:")
print(means)
print(float(means))


if __name__ == "__main__":
Expand Down
23 changes: 16 additions & 7 deletions mcmcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def D(self, x, num_reps=64):
self.genob.num_reps = num_reps

return tf.reduce_mean(
self.discriminator.predict(self.genob.simulate_msprime(x).astype("float32"))
self.discriminator.predict_on_batch(
self.genob.simulate_msprime(x).astype("float32")
)
)

# Where `D(x)` is the average discriminator output from n independent
Expand Down Expand Up @@ -300,7 +302,7 @@ def setup_mcmc(
self.mcmc_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
mcmc,
num_adaptation_steps=int(self.num_burnin_steps * 0.8),
target_accept_prob=0.75,
target_accept_prob=0.3,
step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
step_size=new_step_size
),
Expand Down Expand Up @@ -356,7 +358,7 @@ def run_chain(self):

def hist_samples(self, params, it, bins=10):

colors = ["b", "g", "r"]
colors = ["red", "blue", "green", "black", "gold", "chocolate", "teal"]
for i, p in enumerate(params):
sns.distplot(self.samples[:, i], color=colors[i])
ymax = plt.ylim()[1]
Expand All @@ -375,8 +377,10 @@ def hist_samples(self, params, it, bins=10):
def traceplot_samples(self, params, it):

# EXPAND COLORS FOR MORE PARAMETERS
colors = ["b", "g", "r"]
colors = ["red", "blue", "green", "black", "gold", "chocolate", "teal"]
sns.set_style("darkgrid")
print(len(params))
print(self.samples.shape)
for i, p in enumerate(params):
plt.plot(self.samples[:, i], c=colors[i], alpha=0.3)
plt.hlines(
Expand All @@ -400,13 +404,18 @@ def traceplot_samples(self, params, it):
)
plt.clf()

def jointplot(self, it):
def jointplot_samples(self, params, it):

p1 = list(params.values())[0]
p2 = list(params.values())[1]

g = sns.jointplot(self.samples[:, 0], self.samples[:, 1], kind="kde")
g.plot_joint(sns.kdeplot, color="b", zorder=0, levels=6)
g.plot_marginals(sns.rugplot, color="r", height=-0.15, clip_on=False)
plt.xlabel("P1")
plt.ylabel("P2")
plt.xlim(p1.bounds)
plt.ylim(p2.bounds)
plt.xlabel(p1.name)
plt.ylabel(p2.name)
plt.title(f"Jointplot at iteration {it}")
plt.savefig(f"./results/jointplot_{it}.png")
plt.clf()