Skip to content

Commit

Permalink
Add mosek into the default solver list. Update multinomial.py in simu…
Browse files Browse the repository at this point in the history
…lator
  • Loading branch information
Uyen Mai committed Mar 6, 2024
1 parent d39d498 commit fdd24b9
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 48 deletions.
150 changes: 106 additions & 44 deletions emd/emd_normal_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,25 +67,23 @@ def EM_date_random_init(tree,smpl_times,init_rate_distr,s=1000,nrep=100,maxIter=
print("Random seed: " + str(rseeds[r]))
new_tree = read_tree_newick(tree.newick())
#try:
tau,omega,phi,llh,Q = EM_date(new_tree,smpl_times,init_rate_distr,s=s,maxIter=maxIter,refTree=refTree,init_Q=init_Q,fixed_tau=fixed_tau,verbose=verbose,mu_avg=mu_avg,fixed_omega=fixed_omega,pseudo=pseudo,CI_options=None)
ans,constr = EM_date(new_tree,smpl_times,init_rate_distr,s=s,maxIter=maxIter,refTree=refTree,init_Q=init_Q,fixed_tau=fixed_tau,verbose=verbose,mu_avg=mu_avg,fixed_omega=fixed_omega,pseudo=pseudo)
tau,omega,phi,llh,Q = ans['tau'],ans['omega'],ans['phi'],ans['llh'],ans['Q']
convert_to_time(new_tree,tau,omega,phi,Q)
new_ref = new_tree
new_tree = read_tree_newick(tree.newick())
omega_adjusted = [o for o,p in zip(omega,phi) if p > 1e-6]
phi_adjusted = [p for p in phi if p > 1e-6]
sum_phi = sum(phi_adjusted)
phi_adjusted = [p/sum_phi for p in phi_adjusted]
tau,omega,phi,llh,Q = EM_date(new_tree,smpl_times,s=s,init_rate_distr=multinomial(omega_adjusted,phi_adjusted),maxIter=maxIter,refTree=new_ref,init_Q=None,fixed_tau=fixed_tau,verbose=verbose,mu_avg=None,fixed_omega=fixed_omega,pseudo=pseudo,CI_options=CI_options)
ans,constr = EM_date(new_tree,smpl_times,s=s,init_rate_distr=multinomial(omega_adjusted,phi_adjusted),maxIter=maxIter,refTree=new_ref,init_Q=None,fixed_tau=fixed_tau,verbose=verbose,mu_avg=None,fixed_omega=fixed_omega,pseudo=pseudo)
tau,omega,phi,llh,Q = ans['tau'],ans['omega'],ans['phi'],ans['llh'],ans['Q']
#convert branch length to time unit and compute mu for each branch
convert_to_time(new_tree,tau,omega,phi,Q)
# compute divergence times
compute_divergence_time(new_tree,smpl_times,place_mu=place_mu,place_q=place_q,as_date=as_date,bw_time=bw_time)
# place confidence intervals
for node in new_tree.traverse_preorder():
if not node.is_root():
_,tau_lower,_,tau_upper = node.tau_CI
node.edge_length = str(node.edge_length) + "[" + str(tau_lower) + "," + str(tau_upper) + "]"
#print(node.edge_length,tau_lower,tau_upper)
compute_divergence_time(new_tree,smpl_times)
#annotate_divergence_time(new_tree,place_mu=place_mu,place_q=place_q,as_date=as_date,bw_time=bw_time)

# output
if verbose:
print("New llh: " + str(llh))
Expand All @@ -94,15 +92,28 @@ def EM_date_random_init(tree,smpl_times,init_rate_distr,s=1000,nrep=100,maxIter=
if llh > best_llh:
best_llh = llh
best_tree = new_tree
best_tau = tau
best_phi = phi
best_omega = omega
best_Q = Q
#except mosek.Error:
# raise Exception("Mosek license not found!")
#except:
# print("Failed to optimize using this init point!")
# place confidence intervals
if CI_options is not None:
b,M,dt = constr['b'],constr['M'],constr['dt']
get_confidence_interval(best_tree,smpl_times,best_tau,best_omega,best_Q,np.array(b),s,M,dt,CI_options,eps_tau=EPS_tau)
convert_to_time(best_tree,best_tau,best_omega,best_phi,best_Q)
compute_divergence_time(best_tree,smpl_times)
annotate_divergence_time(best_tree,place_mu=place_mu,place_q=place_q,as_date=as_date,bw_time=bw_time)
for node in best_tree.traverse_preorder():
if not node.is_root():
_,tau_lower,_,tau_upper = node.tau_CI
node.edge_length = str(node.edge_length) + "[" + str(tau_lower) + "," + str(tau_upper) + "]"
return best_tree,best_llh,best_phi,best_omega

def EM_date(tree,smpl_times,init_rate_distr,refTree=None,s=1000,df=5e-4,maxIter=100,eps_tau=EPS_tau,fixed_tau=False,verbose=False,mu_avg=None,fixed_omega=False,pseudo=0,init_Q=None,CI_options=None):
def EM_date(tree,smpl_times,init_rate_distr,refTree=None,s=1000,df=5e-4,maxIter=100,eps_tau=EPS_tau,fixed_tau=False,verbose=False,mu_avg=None,fixed_omega=False,pseudo=0,init_Q=None):
M, dt, b = setup_constr(tree,smpl_times,s,eps_tau=eps_tau,pseudo=pseudo)
Q, tau, phi, omega = init_EM(tree,b,init_rate_distr,s=s,refTree=refTree,init_Q=init_Q)
if verbose:
Expand Down Expand Up @@ -133,10 +144,13 @@ def EM_date(tree,smpl_times,init_rate_distr,refTree=None,s=1000,df=5e-4,maxIter=
print("Estep ...")
Q = run_Estep(b,s,omega,tau,phi,var_apprx=True)

if CI_options is not None:
tau_boots = get_confidence_interval(tree,smpl_times,omega,Q,np.array(b),s,M,dt,CI_options,eps_tau=EPS_tau)
#if CI_options is not None:
# get_confidence_interval(tree,smpl_times,omega,Q,np.array(b),s,M,dt,CI_options,eps_tau=EPS_tau)

return tau,omega,phi,llh,Q
#return tau,omega,phi,llh,Q
ans = {'tau':tau,'omega':omega,'phi':phi,'llh':llh,'Q':Q}
constr = {'M':M,'b':b,'dt':dt}
return ans,constr

def convert_to_time(tree,tau,omega,phi,Q):
# convert branch length to time unit and compute mu for each branch
Expand All @@ -149,7 +163,7 @@ def convert_to_time(tree,tau,omega,phi,Q):
node.mu = round(sum(o*p for (o,p) in zip(omega,phi)),nDIGITS)
node.q = None

def compute_divergence_time(tree,sampling_time,bw_time=False,as_date=False,place_mu=True,place_q=False):
def compute_divergence_time(tree,sampling_time):
# compute and place the divergence time onto the node label of the tree
# must have at least one sampling time. Assuming the tree branches have been
# converted to time unit and are consistent with the given sampling_time
Expand Down Expand Up @@ -195,19 +209,35 @@ def compute_divergence_time(tree,sampling_time,bw_time=False,as_date=False,place
stk.append(c)
node.time = t

def convert_divTime(t,bw_time=False,as_date=False):
    """Format a divergence time for placement on a node label.

    If `as_date` is true, the result is whatever `years_to_date(t)` returns
    and `bw_time` is ignored. Otherwise `t` is rounded with
    `round(t, nDIGITS)`, negated when `bw_time` is true, and returned as a
    string.
    """
    if as_date:
        return years_to_date(t)
    rounded = round(t,nDIGITS)
    if bw_time:
        # backward time is reported as the negated rounded value
        rounded = -rounded
    return str(rounded)

def annotate_divergence_time(tree,bw_time=False,as_date=False,place_mu=True,place_q=False):
    """Append a "[t=...]" tag with time (and optionally rate) info to each node label.

    For every node visited by `tree.traverse_postorder()` a tag of the form
    "[t=<time>,...]" is built and appended to the node's existing label
    (or used as the label when the node has none).

    Fields, in order:
      t                 node.time formatted by convert_divTime
      t_lower, t_upper  only if the node has a `divTime_CI` attribute
      mu                only if `place_mu` is true
      mu_lower, mu_upper  only if `place_mu` is true and the node has `mu_CI`
      q                 only if `place_q` is true and node.q is not None

    Parameters
    ----------
    tree : object exposing traverse_postorder(); nodes expose get_label(),
        set_label(), and the attributes `time`, `mu` and `q`.
    bw_time, as_date : forwarded unchanged to convert_divTime.
    place_mu : include the per-node mutation rate.
    place_q : include the per-node q vector.

    Raises
    ------
    AssertionError if a node's `time` is None.
    """
    for node in tree.traverse_postorder():
        lb = node.get_label()
        # str(lb): the label may be None for unlabeled nodes, and the message
        # must not itself raise while reporting the real failure.
        assert node.time is not None, "Failed to compute divergence time for node " + str(lb)

        fields = ["t=" + convert_divTime(node.time,bw_time=bw_time,as_date=as_date)]
        if hasattr(node,'divTime_CI'):
            # divTime_CI is a 4-tuple; only the 2nd and 4th entries (bounds) are used here
            _,t_lower,_,t_upper = node.divTime_CI
            fields.append("t_lower=" + str(convert_divTime(t_lower,bw_time=bw_time,as_date=as_date)))
            fields.append("t_upper=" + str(convert_divTime(t_upper,bw_time=bw_time,as_date=as_date)))
        if place_mu:
            fields.append("mu=" + str(node.mu))
            # NOTE(review): the original indentation was lost; the mu bounds are
            # assumed to be gated by place_mu together with mu itself - confirm.
            if hasattr(node,'mu_CI'):
                _,mu_lower,_,mu_upper = node.mu_CI
                fields.append("mu_lower=" + str(mu_lower))
                fields.append("mu_upper=" + str(mu_upper))
        if place_q and node.q is not None:
            # emitted exactly once
            fields.append("q=(" + ",".join(str(x) for x in node.q) + ")")

        tag = "[" + ",".join(fields) + "]"
        node.set_label(lb + tag if lb else tag)
Expand Down Expand Up @@ -710,7 +740,7 @@ def compute_f_MM(tau,omega,Q,b,s,var_apprx=True):
F += s*Q[i][j]*(b[i]-omega[j]*tau[i])**2/w_ij + Q[i][j]*log(w_ij)
return F

def compute_tau_star_cvxpy(tau,omega,Q,b,s,M,dt,eps_tau=EPS_tau,var_apprx=False,solvers=['osqp','cvxopt','ecos']):
def compute_tau_star_cvxpy(tau,omega,Q,b,s,M,dt,eps_tau=EPS_tau,var_apprx=False,solvers=['mosek','osqp','cvxopt','ecos']):
N = len(b)
k = len(omega)
Pd = np.zeros(N)
Expand Down Expand Up @@ -760,7 +790,7 @@ def compute_CI(a_list,p_lower=0.025,p_upper=0.975):
idx_higher = ceil(p_upper*N)-1
return s_list[idx_lower],s_list[idx_higher]

def get_confidence_interval(tree,smpl_times,omega,Q,b,s,M,dt,CI_options,eps_tau=EPS_tau):
def get_confidence_interval(tree,smpl_times,tau,omega,Q,b,s,M,dt,CI_options,eps_tau=EPS_tau):
nboots = CI_options['nboots']
p_lower = CI_options['p_lower']
p_upper = CI_options['p_upper']
Expand All @@ -769,32 +799,64 @@ def get_confidence_interval(tree,smpl_times,omega,Q,b,s,M,dt,CI_options,eps_tau=
mu_boots = [np.zeros(N) for i in range(nboots)]
b_boots = [np.zeros(N) for i in range(nboots)]
tau_boots = [np.zeros(N) for i in range(nboots)]
mu_avg = np.zeros(N)
for node in tree.traverse_postorder():
if node.is_root():
continue
phi = Q[node.idx]
#phi = [1/k]*k
R = multinomial(omega,phi)
for i in range(nboots):
mu_boots[i][node.idx] = R.randomize()
b_boots[i][node.idx] = norm.rvs(b[node.idx],sqrt(b[node.idx]/s))
tau_boots[i][node.idx] = max(EPS_tau,b_boots[i][node.idx]/mu_boots[i][node.idx])

#tau_boots = [[]]*nboots
#for i in range(nboots):
#mu = mu_boots[i]
#bb = b_boots[i]
#var_tau = cp.Variable(N)
#objective = cp.Minimize((s/b).T @ (bb-mu.T @ var_tau)**2)
#constraints = [np.zeros(N)+eps_tau <= var_tau, csr_matrix(M)@var_tau == np.array(dt)]
#prob = cp.Problem(objective,constraints)
#f_star = prob.solve(verbose=False,solver=cp.MOSEK)
#tau_boots[i] = var_tau.value
R = multinomial(omega,phi)
omega_lower = R.get_quantize(p_lower)
omega_upper = R.get_quantize(p_upper)
node.mu_CI = (p_lower,omega_lower,p_upper,omega_upper)
mu_avg[node.idx] = sum(o*p for o,p in zip(omega,phi))
#for i in range(nboots):
# mu_boots[i][node.idx] = R.randomize()
# b_boots[i][node.idx] = b[node.idx]
#b_boots[i][node.idx] = norm.rvs(b[node.idx],sqrt(b[node.idx]/s))
#tau_boots[i][node.idx] = max(EPS_tau,b_boots[i][node.idx]/mu_boots[i][node.idx])

divTime_boots = [np.zeros(N+1) for i in range(nboots)]
i = 0
while i < nboots:
#try:
for node in tree.traverse_postorder():
if node.is_root():
continue
#phi = Q[node.idx]
phi = [1/k]*k
R = multinomial(omega,phi)
mu_boots[i][node.idx] = R.randomize()
#b_boots[i][node.idx] = max(EPS_tau*EPS_omg,norm.rvs(mu_avg[node.idx]*tau[node.idx],sqrt(mu_avg[node.idx]*tau[node.idx]/s)))
#b_boots[i][node.idx] = norm.rvs(b[node.idx],sqrt(b[node.idx]/s))
b_boots[i][node.idx] = b[node.idx]
mu = mu_boots[i]
bb = b_boots[i]
var_tau = cp.Variable(N)
#W = np.diag([sqrt(s/(m*t)) for m,t in zip(mu_avg,tau)])
W = np.diag([sqrt(s/x) for x in bb])
objective = cp.Minimize(cp.sum_squares( W @ (bb-np.diag(mu) @ var_tau)))
#objective = cp.Minimize(cp.sum_squares((bb-np.diag(mu) @ var_tau)))
constraints = [np.zeros(N)+eps_tau <= var_tau, csr_matrix(M)@var_tau == np.array(dt)]
prob = cp.Problem(objective,constraints)
f_star = prob.solve(verbose=False,solver=cp.MOSEK)
tau_boots[i] = var_tau.value
# compute divergence time
for node in tree.traverse_postorder():
if not node.is_root():
node.edge_length = tau_boots[i][node.idx]
compute_divergence_time(tree,smpl_times)
for node in tree.traverse_postorder():
divTime_boots[i][node.idx] = node.time
i = i+1
#except:
# print("Failed to get estimate on this bootstrap. Retrying...")

for node in tree.traverse_postorder():
if node.is_root():
continue
tau_list = [tau_boots[i][node.idx] for i in range(nboots)]
tau_lower,tau_upper = compute_CI(tau_list,p_lower=p_lower,p_upper=p_upper)
node.tau_CI = (p_lower,tau_lower,p_upper,tau_upper)
return tau_boots
divTime_list = [divTime_boots[i][node.idx] for i in range(nboots)]
divTime_lower,divTime_upper = compute_CI(divTime_list,p_lower=p_lower,p_upper=p_upper)
node.divTime_CI = (p_lower,divTime_lower,p_upper,divTime_upper)
if not node.is_root():
tau_list = [tau_boots[i][node.idx] for i in range(nboots)]
tau_lower,tau_upper = compute_CI(tau_list,p_lower=p_lower,p_upper=p_upper)
node.tau_CI = (p_lower,tau_lower,p_upper,tau_upper)
#return tau_boots
11 changes: 7 additions & 4 deletions simulator/multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
def cdf_from_pdf(p):
    """Return the cumulative sums of the probability sequence `p` as a list.

    Element i of the result is p[0] + p[1] + ... + p[i], accumulated left
    to right. An empty `p` yields an empty list instead of raising
    IndexError.
    """
    # Local import so this block does not depend on the module's import header.
    from itertools import accumulate
    return list(accumulate(p))

def binary_search(arr,v,start=0,end=None):
Expand All @@ -29,9 +27,14 @@ def binary_search(arr,v,start=0,end=None):
class multinomial:
def __init__(self,omega,phi):
    """Build a discrete distribution over the values `omega` with probabilities `phi`.

    omega stores the values, phi stores the probability of each value.
    The (value, probability) pairs are sorted by value, and the cumulative
    table `self.acc` is built from the sorted probabilities so that
    self.acc[i] lines up with self.omega[i].
    """
    pairs = sorted(zip(omega,phi))
    self.omega = [o for o,_ in pairs]
    self.phi = [p for _,p in pairs]
    # Use the sorted self.phi here: building the table from the caller's
    # unsorted phi would misalign it with the sorted self.omega.
    self.acc = cdf_from_pdf(self.phi) # cumulative distribution

def get_quantize(self,q):
    """Return the stored value at the index binary_search finds for `q` in self.acc."""
    return self.omega[binary_search(self.acc,q)]

def randomize(self):
r = random()
Expand Down

0 comments on commit fdd24b9

Please sign in to comment.