Integrate the multihit model into the DNSM framework #60

Closed · wants to merge 19 commits
netam/dnsm.py: 38 changes (32 additions & 6 deletions)
@@ -49,11 +49,17 @@ def __init__(
         all_rates: torch.Tensor,
         all_subs_probs: torch.Tensor,
         branch_lengths: torch.Tensor,
+        multihit_model=None,
     ):
         self.nt_parents = nt_parents
         self.nt_children = nt_children
         self.all_rates = all_rates
         self.all_subs_probs = all_subs_probs
+        self.multihit_model = copy.deepcopy(multihit_model)
+        if multihit_model is not None:
+            # We want these parameters to act like fixed data. This is essential
+            # for multithreaded branch length optimization to work.
+            self.multihit_model.values.requires_grad_(False)
 
         assert len(self.nt_parents) == len(self.nt_children)
         pcp_count = len(self.nt_parents)
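
Note: a tiny self-contained sketch of what the freeze above buys. With `requires_grad` off, no autograd graph is recorded through the multihit factors, so worker processes doing branch length optimization treat them as plain data. (Toy values; `values` mirrors the model attribute touched above.)

```python
import torch

values = torch.nn.Parameter(torch.tensor([-0.3, -1.0, -2.0]))
values.requires_grad_(False)

# Anything computed from `values` now carries no gradient history,
# so it behaves like fixed data inside the optimization workers.
loss = 2.0 * values.exp().sum()
assert not loss.requires_grad
```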
@@ -99,6 +105,7 @@ def of_seriess(
         all_rates_series: pd.Series,
         all_subs_probs_series: pd.Series,
         branch_length_multiplier=5.0,
+        multihit_model=None,
     ):
         """Alternative constructor that takes the raw data and calculates the initial
         branch lengths.
@@ -119,10 +126,11 @@ def of_seriess(
             stack_heterogeneous(all_rates_series.reset_index(drop=True)),
             stack_heterogeneous(all_subs_probs_series.reset_index(drop=True)),
             initial_branch_lengths,
+            multihit_model=multihit_model,
         )
 
     @classmethod
-    def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
+    def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
         """Alternative constructor that takes in a pcp_df and calculates the initial
         branch lengths."""
         assert "rates" in pcp_df.columns, "pcp_df must have a neutral rates column"
@@ -132,6 +140,7 @@ def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
             pcp_df["rates"],
             pcp_df["subs_probs"],
             branch_length_multiplier=branch_length_multiplier,
+            multihit_model=multihit_model,
         )
 
     def clone(self):
@@ -142,6 +151,7 @@ def clone(self):
             self.all_rates.copy(),
             self.all_subs_probs.copy(),
             self._branch_lengths.copy(),
+            multihit_model=copy.deepcopy(self.multihit_model),
         )
         return new_dataset

Review comment (contributor author), on the deepcopy line above: "This deepcopy and the one below are no longer necessary now that you are copying in the init, right?"
@@ -158,6 +168,7 @@ def subset_via_indices(self, indices):
             self.all_rates[indices],
             self.all_subs_probs[indices],
             self._branch_lengths[indices],
+            multihit_model=copy.deepcopy(self.multihit_model),
         )
         return new_dataset
 
@@ -206,6 +217,10 @@ def update_neutral_aa_mut_probs(self):
             mask = mask.to("cpu")
             rates = rates.to("cpu")
             subs_probs = subs_probs.to("cpu")
+            if self.multihit_model is not None:
+                multihit_model = self.multihit_model.to("cpu")
+            else:
+                multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be careful
             # with masking out these positions later. We do this below.
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
@@ -220,6 +235,7 @@ def update_neutral_aa_mut_probs(self):
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
                 normed_subs_probs.reshape(-1, 3, 4),
+                multihit_model=multihit_model,
             )
 
             if not torch.isfinite(neutral_aa_mut_prob).all():
@@ -270,23 +286,31 @@ def to(self, device):
         self.log_neutral_aa_mut_probs = self.log_neutral_aa_mut_probs.to(device)
         self.all_rates = self.all_rates.to(device)
         self.all_subs_probs = self.all_subs_probs.to(device)
+        if self.multihit_model is not None:
+            self.multihit_model = self.multihit_model.to(device)
 
 
-def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
+def train_val_datasets_of_pcp_df(
+    pcp_df, branch_length_multiplier=5.0, multihit_model=None
+):
     """Perform a train-val split based on an "in_train" column.
 
     Stays here so it can be used in tests.
     """
     train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
     val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
     val_dataset = DNSMDataset.of_pcp_df(
-        val_df, branch_length_multiplier=branch_length_multiplier
+        val_df,
+        branch_length_multiplier=branch_length_multiplier,
+        multihit_model=multihit_model,
     )
     if len(train_df) == 0:
         return None, val_dataset
     # else:
     train_dataset = DNSMDataset.of_pcp_df(
-        train_df, branch_length_multiplier=branch_length_multiplier
+        train_df,
+        branch_length_multiplier=branch_length_multiplier,
+        multihit_model=multihit_model,
     )
     return train_dataset, val_dataset

@@ -360,13 +384,14 @@ def _find_optimal_branch_length(
         rates,
         subs_probs,
         starting_branch_length,
+        multihit_model,
         **optimization_kwargs,
     ):
         if parent == child:
             return 0.0
         sel_matrix = self.build_selection_matrix_from_parent(parent)
         log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
-            sel_matrix, parent, child, rates, subs_probs
+            sel_matrix, parent, child, rates, subs_probs, multihit_model
         )
         if type(starting_branch_length) == torch.Tensor:
             starting_branch_length = starting_branch_length.detach().item()
@@ -395,6 +420,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
                 rates[: len(parent)],
                 subs_probs[: len(parent), :],
                 starting_length,
+                dataset.multihit_model,
                 **optimization_kwargs,
             )
 
@@ -410,7 +436,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
-        # The following can be used when one wants a better traceback.
+        # # The following can be used when one wants a better traceback.
         # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
         # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
         with mp.Pool(worker_count) as pool:
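
Note: a usage sketch of how the new `multihit_model` argument threads through dataset construction. This is hypothetical: `pcp_df` and the fitted log factors are assumed inputs, and the model class name `HitClassModel` is an assumption (its definition is outside this diff); only the `.values` attribute is implied by the `__init__` change above.

```python
import torch
from netam.dnsm import train_val_datasets_of_pcp_df
from netam.models import HitClassModel  # assumed name; not shown in this diff

# Assumed: pcp_df has "rates", "subs_probs", and "in_train" columns, and the
# log hit-class factors below were fit elsewhere.
multihit_model = HitClassModel()
multihit_model.values = torch.nn.Parameter(torch.tensor([-0.3, -1.0, -2.0]))

# The datasets deep-copy the model and freeze .values (see __init__ above), so
# it rides along through branch length optimization as fixed data.
train_dataset, val_dataset = train_val_datasets_of_pcp_df(
    pcp_df, multihit_model=multihit_model
)
```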
netam/hit_class.py: 21 changes (13 additions & 8 deletions)
@@ -41,9 +41,10 @@ def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor
     ]
 
 
+
 def apply_multihit_correction(
     parent_codon_idxs: torch.Tensor,
-    codon_logprobs: torch.Tensor,
+    codon_probs: torch.Tensor,
     hit_class_factors: torch.Tensor,
 ) -> torch.Tensor:
     """Multiply codon probabilities by their hit class factors, and renormalize.
@@ -53,23 +54,27 @@ def apply_multihit_correction(
     Args:
         parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the
             indices of the parent codon's nucleotides.
-        codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities
+        codon_probs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the probabilities
             of mutating to each possible target codon, for each of the N parent codons.
         hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
             factor for hit class 0 is assumed to be 1 (that is, 0 in log-space).
 
     Returns:
-        torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible
+        torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the probabilities of mutating to each possible
             target codon, for each of the N parent codons, after applying the hit class factors.
     """
     per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
-    corrections = torch.cat([torch.tensor([0.0]), hit_class_factors])
+    corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]).exp()
     reshaped_corrections = corrections[per_parent_hit_class]
-    unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections
-    normalizations = torch.logsumexp(
-        unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True
+    unnormalized_corrected_probs = codon_probs * reshaped_corrections
+    normalizations = torch.sum(
+        unnormalized_corrected_probs, dim=[1, 2, 3], keepdim=True
     )
-    return unnormalized_corrected_logprobs - normalizations
+    result = unnormalized_corrected_probs / normalizations
+    if torch.any(torch.isnan(result)):
+        print("NAN found in multihit correction application")
+        assert False
+    return result
 
 
 def hit_class_probs_tensor(
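
Note: a self-contained numeric sketch of the probability-space correction above, for a single parent codon. Here the hit class is computed directly as the Hamming distance to the parent (the real code uses `parent_specific_hit_classes`), and all numbers are made up.

```python
import torch

parent = torch.tensor([0, 2, 3])  # parent codon as nucleotide indices
# Hit class of each of the 64 candidate codons = number of differing positions.
codons = torch.cartesian_prod(*[torch.arange(4)] * 3)  # (64, 3)
hit_class = (codons != parent).sum(dim=1)              # (64,) values in 0..3

codon_probs = torch.softmax(torch.randn(64), dim=0)    # toy uncorrected probs
log_factors = torch.tensor([-0.3, -1.0, -2.0])         # classes 1-3; class 0 is 0
corrections = torch.cat([torch.tensor([0.0]), log_factors]).exp()

corrected = codon_probs * corrections[hit_class]
corrected = corrected / corrected.sum()                # renormalize
assert torch.isclose(corrected.sum(), torch.tensor(1.0))
```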
netam/models.py: 5 changes (3 additions & 2 deletions)
@@ -702,14 +702,15 @@ def __init__(self):
     def hyperparameters(self):
         return {}
 
+    # TODO changed to nonlog version for testing, need to update all calls to reflect this
     def forward(
-        self, parent_codon_idxs: torch.Tensor, uncorrected_log_codon_probs: torch.Tensor
+        self, parent_codon_idxs: torch.Tensor, uncorrected_codon_probs: torch.Tensor
     ):
         """Forward function takes a tensor of target codon distributions, for each
         observed parent codon, and adjusts the distributions according to the hit class
         corrections."""
         return apply_multihit_correction(
-            parent_codon_idxs, uncorrected_log_codon_probs, self.values
+            parent_codon_idxs, uncorrected_codon_probs, self.values
         )
 
     def reinitialize_weights(self):
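
Note: the rename above changes the calling convention from log space to probability space. A minimal sketch of the contract, assuming `multihit_model` is an instance of this model class and shapes follow the docstrings in hit_class.py:

```python
# parent_codon_idxs: (N, 3) nucleotide indices; codon_probs: (N, 4, 4, 4),
# already normalized per parent codon.
corrected = multihit_model(parent_codon_idxs, codon_probs)
# Output is still a distribution over the 64 target codons for each parent:
assert torch.allclose(corrected.sum(dim=(1, 2, 3)), torch.ones(corrected.shape[0]))
```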
netam/molevol.py: 43 changes (37 additions & 6 deletions)
@@ -10,13 +10,15 @@
 """
 
 import numpy as np
+import warnings
 
 import torch
 from torch import Tensor, optim
 
 from netam.sequences import CODON_AA_INDICATOR_MATRIX
 
 import netam.sequences as sequences
+# torch.autograd.set_detect_anomaly(True)
 
 
 def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor:
@@ -283,6 +285,7 @@ def build_codon_mutsel(
     codon_mut_probs: Tensor,
     codon_sub_probs: Tensor,
     aa_sel_matrices: Tensor,
+    multihit_model=None,
 ) -> Tensor:
     """Build a sequence of codon mutation-selection matrices for codons along a
     sequence.
@@ -301,6 +304,11 @@ def build_codon_mutsel(
     )
     codon_probs = codon_probs_of_mutation_matrices(mut_matrices)
 
+    if multihit_model is not None:
+        codon_probs = multihit_model(parent_codon_idxs, codon_probs)
+    else:
+        warnings.warn("No multihit model provided. Using uncorrected probabilities.")
+
     # Calculate the codon selection matrix for each sequence via Einstein
     # summation, in which we sum over the repeated indices.
     # So, for each site (s) and codon (c), sum over amino acids (a):
@@ -339,6 +347,7 @@ def neutral_aa_probs(
     parent_codon_idxs: Tensor,
     codon_mut_probs: Tensor,
     codon_sub_probs: Tensor,
+    multihit_model=None,
 ) -> Tensor:
     """For every site, what is the probability that the amino acid will mutate to every
     amino acid?
@@ -356,10 +365,15 @@ def neutral_aa_probs(
     mut_matrices = build_mutation_matrices(
         parent_codon_idxs, codon_mut_probs, codon_sub_probs
     )
-    codon_probs = codon_probs_of_mutation_matrices(mut_matrices).view(-1, 64)
+    codon_probs = codon_probs_of_mutation_matrices(mut_matrices)
+
+    if multihit_model is not None:
+        codon_probs = multihit_model(parent_codon_idxs, codon_probs)
+    else:
+        warnings.warn("No multihit model provided. Using uncorrected probabilities.")
 
     # Get the probability of mutating to each amino acid.
-    aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX
+    aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX
 
     return aa_probs
 
@@ -396,6 +410,7 @@ def neutral_aa_mut_probs(
     parent_codon_idxs: Tensor,
     codon_mut_probs: Tensor,
     codon_sub_probs: Tensor,
+    multihit_model=None,
 ) -> Tensor:
     """For every site, what is the probability that the amino acid will have a
     substitution or mutate to a stop under neutral evolution?
@@ -414,12 +429,19 @@ def neutral_aa_mut_probs(
             Shape: (codon_count,)
     """
 
-    aa_probs = neutral_aa_probs(parent_codon_idxs, codon_mut_probs, codon_sub_probs)
+    aa_probs = neutral_aa_probs(
+        parent_codon_idxs,
+        codon_mut_probs,
+        codon_sub_probs,
+        multihit_model=multihit_model,
+    )
     mut_probs = mut_probs_of_aa_probs(parent_codon_idxs, aa_probs)
     return mut_probs
 
 
-def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs):
+def mutsel_log_pcp_probability_of(
+    sel_matrix, parent, child, rates, sub_probs, multihit_model=None
+):
     """Constructs the log_pcp_probability function specific to given rates and
     sub_probs.
 
@@ -442,6 +464,7 @@ def log_pcp_probability(log_branch_length: torch.Tensor):
             mut_probs.reshape(-1, 3),
             sub_probs.reshape(-1, 3, 4),
             sel_matrix,
+            multihit_model=multihit_model,
         )
 
         # This is a diagnostic generating data for netam issue #7.
@@ -482,6 +505,7 @@ def optimize_branch_length(
 
     step_idx = 0
 
+    nan_issue = False
     for step_idx in range(max_optimization_steps):
         # For some PCPs, the optimizer works very hard optimizing very tiny branch lengths.
         if log_branch_length < log_branch_length_lower_threshold:
@@ -496,7 +520,14 @@ def optimize_branch_length(
         loss.backward()
         torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0)
         optimizer.step()
-        assert not torch.isnan(log_branch_length)
+        if torch.isnan(log_branch_length):
+            print("branch length optimization resulted in NAN, previous log branch length:", prev_log_branch_length)
+            if np.isclose(prev_log_branch_length.detach().numpy(), 0):
+                log_branch_length = prev_log_branch_length
+                nan_issue = True
+                break
+            else:
+                assert False
 
         change_in_log_branch_length = torch.abs(
             log_branch_length - prev_log_branch_length

prev_log_branch_length = log_branch_length.clone()

if step_idx == max_optimization_steps - 1:
if step_idx == max_optimization_steps - 1 or nan_issue:
print(
f"Warning: optimization did not converge after {max_optimization_steps} steps; log branch length is {log_branch_length.detach().item()}"
)
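
Note: a toy sketch of the order of operations now used in `neutral_aa_probs` above: the multihit correction applies to the (N, 4, 4, 4) codon probabilities before flattening to (N, 64) and projecting onto amino acids. The indicator matrix here is fake and its width is an assumption; the real `CODON_AA_INDICATOR_MATRIX` lives in netam.sequences and is not shown in this diff.

```python
import torch

N, n_aa = 5, 21  # 21 columns is an assumption (20 amino acids + stop)
codon_probs = torch.rand(N, 4, 4, 4)
codon_probs = codon_probs / codon_probs.sum(dim=(1, 2, 3), keepdim=True)

# Fake codon -> amino acid indicator: each of the 64 codons maps to one column.
indicator = torch.zeros(64, n_aa)
indicator[torch.arange(64), torch.randint(0, n_aa, (64,))] = 1.0

aa_probs = codon_probs.view(-1, 64) @ indicator  # (N, n_aa)
# The projection preserves total probability per site:
assert torch.allclose(aa_probs.sum(dim=1), torch.ones(N))
```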