Skip to content

Commit

Permalink
Merge pull request RacimoLab#25 from pabloswfly/bottleneck
Browse files Browse the repository at this point in the history
Bottleneck
  • Loading branch information
pabloswfly authored Jun 15, 2021
2 parents 5331a9d + 7a950de commit be7a2ff
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 7 deletions.
40 changes: 40 additions & 0 deletions demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,46 @@ def zigzag(args):
)


def bottleneck(args):

genob, params, randomize, i, proposals = args
necessary_params = ["mu", "r", "N0", "T1", "N1", "T2", "N2"]

for p in necessary_params:
if p not in list(params.keys()):
print(
"Invalid combination of parameters. Needed: "
"mu | r | N0 | T1 | N1 | T2 | N2"
)

if proposals:
mu, r, N0, T1, N1, T2, N2 = [
params[p].prop(i) if params[p].inferable else params[p].val
for p in necessary_params
]
else:
mu, r, N0, T1, N1, T2, N2 = [
params[p].rand() if randomize else params[p].val for p in necessary_params
]

# Infer the 3 pop sizes, where N0 = N2
demographic_events = [
msprime.PopulationParametersChange(time=0, initial_size=N0),
msprime.PopulationParametersChange(time=T1, initial_size=N1),
msprime.PopulationParametersChange(time=T2, initial_size=N2),
]

return msprime.simulate(
sample_size=genob.num_samples,
demographic_events=demographic_events,
length=genob.seq_len,
mutation_rate=mu,
recombination_rate=r,
# random_seed=genob.seed,
)



def ghost_migration(args):
"""Mass migration at time T1 from population 1 with pop size N2 to population
0 with pop size N1. Samples are collected only from population 0."""
Expand Down
30 changes: 26 additions & 4 deletions genobuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import zarr
import random
import bisect
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import demography as dm
from parameter import Parameter
import demography as dm
import numpy as np

_ex = None

Expand All @@ -37,6 +37,8 @@ def do_sim(args):
ts = dm.zigzag(args)
elif genob.demo_model == "ghost_migration":
ts = dm.ghost_migration(args)
elif genob.demo_model == "bottleneck":
ts = dm.bottleneck(args)

if params["seqerr"].inferable:
# Return resized matrix
Expand Down Expand Up @@ -704,7 +706,7 @@ def locate(sorted_idx, start=None, stop=None):
parser.add_argument(
"demographic_model",
help="One population demographic model to use for simulations in msprime.",
choices=["constant", "exponential", "zigzag", "ghost_migration"],
choices=["constant", "exponential", "zigzag", "ghost_migration", "bottleneck"],
)

parser.add_argument(
Expand Down Expand Up @@ -808,14 +810,15 @@ def locate(sorted_idx, start=None, stop=None):
params_dict["seqerr"] = Parameter("seqerr", None, (0.00001, 0.01), inferable=False)

if args.demographic_model == "constant":
# Parameters for exponential model: FIX BOUNDS INFERABLE
params_dict["Ne"] = Parameter("Ne", 10000, (5000, 30000), inferable=True)

elif args.demographic_model == "exponential":
# Parameters for exponential model: FIX BOUNDS INFERABLE
params_dict["T1"] = Parameter("T1", 500, (100, 1500), inferable=True)
params_dict["N1"] = Parameter("N1", 10000, (1000, 30000), inferable=True)
params_dict["N1"] = Parameter("N1", 10000, (1000, 30000), inferable=False)
params_dict["T2"] = Parameter("T2", 3000, (1500, 5000), inferable=True)
params_dict["N2"] = Parameter("N2", 5000, (1000, 30000), inferable=True)
params_dict["N2"] = Parameter("N2", 5000, (1000, 30000), inferable=False)
params_dict["growth"] = Parameter("growth", 0.01, (0, 0.05), inferable=True)

elif args.demographic_model == "zigzag":
Expand All @@ -831,7 +834,26 @@ def locate(sorted_idx, start=None, stop=None):
params_dict["T5"] = Parameter("T5", 8533, (5001, 10000), inferable=False)
params_dict["N5"] = Parameter("N5", 71560, (1000, 100000), inferable=True)

elif args.demographic_model == "bottleneck":
# Parameters for exponential model: FIX BOUNDS INFERABLE
params_dict["N0"] = Parameter("N0", 10000, (100, 30000), inferable=True)
params_dict["T1"] = Parameter("T1", 1000, (100, 1500), inferable=False)
params_dict["N1"] = Parameter("N1", 1000, (100, 30000), inferable=True)
params_dict["T2"] = Parameter("T2", 2000, (1500, 5000), inferable=False)
params_dict["N2"] = Parameter("N2", 10000, (100, 30000), inferable=True)

import msprime
demographic_events = [
msprime.PopulationParametersChange(time=0, initial_size=params_dict["N0"].val),
msprime.PopulationParametersChange(time=params_dict["T1"].val, initial_size=params_dict["N1"].val),
msprime.PopulationParametersChange(time=params_dict["T2"].val, initial_size=params_dict["N2"].val),
]

debugger = msprime.DemographyDebugger(Ne=10000, demographic_events=demographic_events)
debugger.print_history()

elif args.demographic_model == "ghost_migration":
# Parameters for exponential model: FIX BOUNDS INFERABLE
params_dict["T1"] = Parameter("T1", 1000, (500, 5000), inferable=False)
params_dict["N1"] = Parameter("N1", 5000, (1000, 20000), inferable=False)
params_dict["N2"] = Parameter("N2", 8000, (1000, 20000), inferable=False)
Expand Down
5 changes: 2 additions & 3 deletions genomcmcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ def run_genomcmcgan(

# Use GPUs for Discriminator operations if possible
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
print("Using", torch.cuda.device_count(), "GPUs")
mcmcgan.discriminator = nn.DataParallel(mcmcgan.discriminator)
print("Using", torch.cuda.device_count(), "GPUs")
mcmcgan.discriminator = nn.DataParallel(mcmcgan.discriminator)
mcmcgan.discriminator.to(device)

print("Initializing weights of the model")
Expand Down
1 change: 1 addition & 0 deletions training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def plot_pair_evolution(params, mcmc_kernel):
files.append(file)
files = sorted(files, key=lambda x: int(x[9:-4]))
arvzs, cs = [], []

for i, f in enumerate(files):
with open(f"./results/{f}", "rb") as obj:
i += 1
Expand Down

0 comments on commit be7a2ff

Please sign in to comment.