Skip to content

Commit

Permalink
Merge pull request #25 from pabloswfly/master
Browse files Browse the repository at this point in the history
Added demography.py module, added jointplots
  • Loading branch information
pabloswfly authored Dec 14, 2020
2 parents e4ea0aa + 3c2b123 commit c906c05
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 40 deletions.
115 changes: 115 additions & 0 deletions demography.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import msprime
import math


def onepop_constant(args):
"""Single population model with pop size Ne and constant growth"""

genob, params, randomize, i, proposals, seed = args
necessary_params = ["mu", "r", "Ne"]
assert sorted(necessary_params) == sorted(
list(params.keys())
), "Invalid combination of parameters. Needed: mu | r | Ne"

if proposals:
mu, r, Ne = [
params[p].prop(i) if params[p].inferable else params[p].val
for p in necessary_params
]
else:
mu, r, Ne = [
params[p].rand() if randomize else params[p].val for p in necessary_params
]

ts = msprime.simulate(
sample_size=genob.num_samples,
Ne=Ne,
length=genob.seq_len,
mutation_rate=mu,
recombination_rate=r,
random_seed=seed,
)

return ts


def onepop_exp(args):
"""Single population model with sudden population size increase from N1 to N2
at time T1 and exponential growth at time T2"""

genob, params, randomize, i, proposals, seed = args
necessary_params = ["mu", "r", "T1", "N1", "T2", "N2", "growth"]
assert sorted(necessary_params) == sorted(
list(params.keys())
), "Invalid combination of parameters. Needed: mu | r | T1 | N1 | T2 | N2 | growth"

if proposals:
mu, r, T1, N1, T2, N2, growth = [
params[p].prop(i) if params[p].inferable else params[p].val
for p in necessary_params
]
else:
mu, r, T1, N1, T2, N2, growth = [
params[p].rand() if randomize else params[p].val for p in necessary_params
]

N0 = N2 / math.exp(growth * T2)

# Time is given in generations unit (t/25)
demographic_events = [
msprime.PopulationParametersChange(time=0, initial_size=N0, growth_rate=growth),
msprime.PopulationParametersChange(time=T2, initial_size=N2, growth_rate=0),
msprime.PopulationParametersChange(time=T1, initial_size=N1),
]

ts = msprime.simulate(
sample_size=genob.num_samples,
demographic_events=demographic_events,
length=genob.seq_len,
mutation_rate=mu,
recombination_rate=r,
random_seed=seed,
)

return ts


def onepop_migration(args):
"""Mass migration at time T1 from population 1 with pop size N2 to population
0 with pop size N1. Samples are collected only from population 0."""

genob, params, randomize, i, proposals, seed = args
necessary_params = ["mu", "r", "T1", "N1", "N2", "mig"]
assert sorted(necessary_params) == sorted(list(params.keys())), (
"Invalid combination of parameters. Needed: mu | r | T1 | N1 | N2 | mig \n"
f"Obtained: {list(params.keys())}"
)

if proposals:
mu, r, T1, N1, N2, mig = [
params[p].prop(i) if params[p].inferable else params[p].val
for p in necessary_params
]
else:
mu, r, T1, N1, N2, mig = [
params[p].rand() if randomize else params[p].val for p in necessary_params
]

population_configurations = [
msprime.PopulationConfiguration(sample_size=genob.num_samples, initial_size=N1),
msprime.PopulationConfiguration(sample_size=0, initial_size=N2),
]

# migration from pop 1 into pop 0 (back in time)
mig_event = msprime.MassMigration(time=T1, source=1, destination=0, proportion=mig)

ts = msprime.simulate(
population_configurations=population_configurations,
demographic_events=[mig_event],
length=genob.seq_len,
mutation_rate=mu,
recombination_rate=r,
random_seed=seed,
)

return ts
47 changes: 23 additions & 24 deletions genobuilder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from collections import OrderedDict
import concurrent.futures
import msprime
import pickle
import argparse
import stdpopsim
Expand All @@ -11,6 +10,7 @@
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import demography as dm
from parameter import Parameter

_ex = None
Expand All @@ -24,27 +24,11 @@ def executor(p):


def do_sim(args):
genob, params, randomize, i, proposals, seed = args
rng = random.Random(seed)

if proposals:
Ne, mu, r = [
params[p].prop(i) if params[p].inferable else params[p].val
for p in ("Ne", "mu", "r")
]
else:
Ne, mu, r = [
params[p].rand() if randomize else params[p].val for p in ("Ne", "mu", "r")
]

ts = msprime.simulate(
sample_size=genob.num_samples,
Ne=Ne,
length=genob.seq_len,
mutation_rate=mu,
recombination_rate=r,
random_seed=seed,
)
seed = args[5]
genob = args[0]
rng = random.Random(seed)
ts = dm.onepop_exp(args)

return genob._resize_from_ts(ts, rng)

Expand Down Expand Up @@ -765,9 +749,24 @@ def vcf2zarr(vcf_files, pop_file, zarr_path):
args = parser.parse_args()
params_dict = OrderedDict()

params_dict["r"] = Parameter("r", 1.25e-9, 1e-10, (1e-11, 1e-7), inferable=True)
params_dict["mu"] = Parameter("mu", 1.25e-8, 1e-9, (1e-10, 1e-7), inferable=False)
params_dict["Ne"] = Parameter("Ne", 10000, 14000, (5000, 15000), inferable=True)
params_dict["r"] = Parameter("r", 2e-8, 1e-10, (1e-11, 1e-7), inferable=False)
params_dict["mu"] = Parameter("mu", 1.29e-8, 1e-9, (1e-10, 1e-7), inferable=False)
# params_dict["Ne"] = Parameter("Ne", 10000, 14000, (5000, 15000), inferable=True)

# For onepop_exp model:
params_dict["T1"] = Parameter("T1", 3000, 4000, (1500, 5000), inferable=True)
params_dict["N1"] = Parameter("N1", 10000, 20000, (1000, 30000), inferable=True)
params_dict["T2"] = Parameter("T2", 500, 1000, (100, 1500), inferable=False)
params_dict["N2"] = Parameter("N2", 5000, 20000, (1000, 20000), inferable=True)
params_dict["growth"] = Parameter("growth", 0.01, 0.02, (0, 0.05), inferable=True)

# For onepop_migration model:
"""
params_dict["T1"] = Parameter("T1", 1000, 4000, (500, 5000), inferable=False)
params_dict["N1"] = Parameter("N1", 5000, 18000, (1000, 20000), inferable=False)
params_dict["N2"] = Parameter("N2", 8000, 15000, (1000, 20000), inferable=False)
params_dict["mig"] = Parameter("mig", 0.9, 0.2, (0, 0.3), inferable=True)
"""

genob = Genobuilder(
source=args.source,
Expand Down
15 changes: 6 additions & 9 deletions genomcmcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,7 @@ def run_genomcmcgan(
"""
xtest = genob.generate_fakedata(num_reps=1000)
test_data = tf.data.Dataset.from_tensor_slices((xtest.astype("float32")))
test_data = (
test_data
.cache()
.batch(batch_size)
.prefetch(2)
)
test_data = test_data.cache().batch(batch_size).prefetch(2)
"""

print("Data simulation finished")
Expand Down Expand Up @@ -125,7 +120,7 @@ def run_genomcmcgan(

while not convergence and max_num_iters != it:

print("Starting the MCMC sampling chain")
print(f"Starting the MCMC sampling chain for iteration {it}")
start_t = time.time()

is_accepted, log_acc_rate = mcmcgan.run_chain()
Expand All @@ -134,7 +129,9 @@ def run_genomcmcgan(
# Draw traceplot and histogram of collected samples
mcmcgan.traceplot_samples(inferable_params, it)
mcmcgan.hist_samples(inferable_params, it)
mcmcgan.jointplot(it)
print(mcmcgan.samples.shape)
if mcmcgan.samples.shape[1] == 2:
mcmcgan.jointplot_samples(inferable_params, it)

for i, p in enumerate(inferable_params):
p.proposals = mcmcgan.samples[:, i].numpy()
Expand Down Expand Up @@ -176,7 +173,7 @@ def run_genomcmcgan(
print(f"A single iteration of the MCMC-GAN took {t} seconds")

print(f"The estimates obtained after {it} iterations are:")
print(means)
print(float(means))


if __name__ == "__main__":
Expand Down
23 changes: 16 additions & 7 deletions mcmcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def D(self, x, num_reps=64):
self.genob.num_reps = num_reps

return tf.reduce_mean(
self.discriminator.predict(self.genob.simulate_msprime(x).astype("float32"))
self.discriminator.predict_on_batch(
self.genob.simulate_msprime(x).astype("float32")
)
)

# Where `D(x)` is the average discriminator output from n independent
Expand Down Expand Up @@ -300,7 +302,7 @@ def setup_mcmc(
self.mcmc_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
mcmc,
num_adaptation_steps=int(self.num_burnin_steps * 0.8),
target_accept_prob=0.75,
target_accept_prob=0.3,
step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
step_size=new_step_size
),
Expand Down Expand Up @@ -356,7 +358,7 @@ def run_chain(self):

def hist_samples(self, params, it, bins=10):

colors = ["b", "g", "r"]
colors = ["red", "blue", "green", "black", "gold", "chocolate", "teal"]
for i, p in enumerate(params):
sns.distplot(self.samples[:, i], color=colors[i])
ymax = plt.ylim()[1]
Expand All @@ -375,8 +377,10 @@ def hist_samples(self, params, it, bins=10):
def traceplot_samples(self, params, it):

# EXPAND COLORS FOR MORE PARAMETERS
colors = ["b", "g", "r"]
colors = ["red", "blue", "green", "black", "gold", "chocolate", "teal"]
sns.set_style("darkgrid")
print(len(params))
print(self.samples.shape)
for i, p in enumerate(params):
plt.plot(self.samples[:, i], c=colors[i], alpha=0.3)
plt.hlines(
Expand All @@ -400,13 +404,18 @@ def traceplot_samples(self, params, it):
)
plt.clf()

def jointplot(self, it):
def jointplot_samples(self, params, it):

p1 = list(params.values())[0]
p2 = list(params.values())[1]

g = sns.jointplot(self.samples[:, 0], self.samples[:, 1], kind="kde")
g.plot_joint(sns.kdeplot, color="b", zorder=0, levels=6)
g.plot_marginals(sns.rugplot, color="r", height=-0.15, clip_on=False)
plt.xlabel("P1")
plt.ylabel("P2")
plt.xlim(p1.bounds)
plt.ylim(p2.bounds)
plt.xlabel(p1.name)
plt.ylabel(p2.name)
plt.title(f"Jointplot at iteration {it}")
plt.savefig(f"./results/jointplot_{it}.png")
plt.clf()

0 comments on commit c906c05

Please sign in to comment.