diff --git a/netam/dcsm.py b/netam/dcsm.py
new file mode 100644
index 00000000..c766da0c
--- /dev/null
+++ b/netam/dcsm.py
@@ -0,0 +1,405 @@
+"""Defining the deep codon selection model (DCSM)."""
+
+import copy
+
+import pandas as pd
+import torch
+import torch.nn.functional as F
+
+from netam.common import (
+    assert_pcp_valid,
+    clamp_probability,
+    codon_mask_tensor_of,
+)
+from netam.dxsm import DXSMDataset, DXSMBurrito
+from netam.hyper_burrito import HyperBurrito
+import netam.molevol as molevol
+
+# TODO strange to have this in common-- I think that it has been moved already upstream?
+from netam.common import aa_idx_tensor_of_str_ambig
+from netam.sequences import (
+    aa_idx_array_of_str,
+    aa_subs_indicator_tensor_of,
+    nt_idx_tensor_of_str,
+    token_mask_of_aa_idxs,
+    translate_sequence,
+    translate_sequences,
+    codon_idx_tensor_of_str_ambig,
+    AMBIGUOUS_CODON_IDX,
+    CODON_AA_INDICATOR_MATRIX,
+    MAX_AA_TOKEN_IDX,
+    RESERVED_TOKEN_REGEX,
+)
+
+
+def masked_log_subtract(log_probs, mask, parent_indices, eps=1e-8):
+    """Calculate log(1 - sum(exp(log_probs_children))) in a numerically stable way.
+
+    Args:
+        log_probs: Tensor of shape [B, L, C] (batch, sequence length, classes).
+        mask: Boolean tensor of shape [B, L] indicating valid positions.
+        parent_indices: Tensor of shape [B, L] indicating parent indices.
+        eps: Small value for numerical stability.
+
+    Returns:
+        Tensor of shape [B, L, C] with parent log probabilities updated.
+    """
+
+    # 1. Mask out parent positions in the log_probs.
+    # Use a very negative value (-1e9 is typical, but be mindful of your logit range).
+    masked_log_probs = log_probs.masked_fill(
+        torch.zeros_like(log_probs, dtype=torch.bool)
+        .scatter_(-1, parent_indices.unsqueeze(-1), 1)
+        .to(device=log_probs.device, dtype=torch.bool),
+        -1e9,
+    )
+
+    # 2. Calculate log(sum(exp(log_probs_children))) using the log-sum-exp trick.
+    log_sum_exp_children = torch.logsumexp(masked_log_probs, dim=-1, keepdim=True)
+
+    # 3. Calculate log(1 - sum(exp(log_probs_children))). Writing
+    # a = log_sum_exp_children (with a <= 0), we want log(1 - exp(a)). When a is
+    # close to 0, exp(a) is close to 1 and 1 - exp(a) suffers catastrophic
+    # cancellation, so we add eps before taking the log; otherwise
+    # torch.log1p(-exp(a)) is accurate. We use torch.where to select between the two.
+
+    log_parent_probs = torch.where(
+        log_sum_exp_children > torch.log(torch.tensor(0.5)),  # i.e. exp(a) > 0.5
+        torch.log(eps + 1.0 - torch.exp(log_sum_exp_children)),
+        torch.log1p(-torch.exp(log_sum_exp_children)),
+    )
+
+    # 4. Scatter the log_parent_probs back into the original log_probs.
+    log_probs = log_probs.scatter(-1, parent_indices.unsqueeze(-1), log_parent_probs)
+
+    # 5. Apply the mask for valid positions.
+    log_probs = torch.where(
+        mask.unsqueeze(-1), log_probs, torch.tensor(0.0, device=log_probs.device)
+    )
+
+    return log_probs
+
+
+class DCSMDataset(DXSMDataset):
+
+    def __init__(
+        self,
+        nt_parents: pd.Series,
+        nt_children: pd.Series,
+        nt_ratess: torch.Tensor,
+        nt_cspss: torch.Tensor,
+        branch_lengths: torch.Tensor,
+        multihit_model=None,
+    ):
+        self.nt_parents = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True)
+        # We replace reserved tokens with Ns here (and for nt_children below), but
+        # keep the unmodified originals around for the amino acid translation so that
+        # special tokens are preserved there.
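+        # Any codon that ends up containing an N is mapped to AMBIGUOUS_CODON_IDX by
+        # codon_idx_tensor_of_str_ambig below, so it is treated as ambiguous downstream.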
+ self.nt_children = nt_children.str.replace( + RESERVED_TOKEN_REGEX, "N", regex=True + ) + self.nt_ratess = nt_ratess + self.nt_cspss = nt_cspss + self.multihit_model = copy.deepcopy(multihit_model) + if multihit_model is not None: + # We want these parameters to act like fixed data. This is essential + # for multithreaded branch length optimization to work. + self.multihit_model.values.requires_grad_(False) + + assert len(self.nt_parents) == len(self.nt_children) + pcp_count = len(self.nt_parents) + + # Important to use the unmodified versions of nt_parents and + # nt_children so they still contain special tokens. + aa_parents = translate_sequences(nt_parents) + aa_children = translate_sequences(nt_children) + + self.max_codon_seq_len = max(len(seq) for seq in aa_parents) + # We have sequences of varying length, so we start with all tensors set + # to the ambiguous amino acid, and then will fill in the actual values + # below. + self.codon_parents_idxss = torch.full( + (pcp_count, self.max_codon_seq_len), AMBIGUOUS_CODON_IDX + ) + self.codon_children_idxss = self.codon_parents_idxss.clone() + # TODO: want to use ambig token once that's changed. + self.aa_parents_idxss = torch.full( + (pcp_count, self.max_codon_seq_len), MAX_AA_TOKEN_IDX + ) + self.aa_children_idxss = torch.full( + (pcp_count, self.max_codon_seq_len), MAX_AA_TOKEN_IDX + ) + # TODO here we are computing the subs indicators. This is handy for OE plots. + self.aa_subs_indicators = torch.zeros((pcp_count, self.max_codon_seq_len)) + + self.masks = torch.ones((pcp_count, self.max_codon_seq_len), dtype=torch.bool) + + # We are using the modified nt_parents and nt_children here because we + # don't want any funky symbols in our codon indices. + for i, (nt_parent, nt_child, aa_parent, aa_child) in enumerate( + zip(self.nt_parents, self.nt_children, aa_parents, aa_children) + ): + self.masks[i, :] = codon_mask_tensor_of( + nt_parent, nt_child, aa_length=self.max_codon_seq_len + ) + assert len(nt_parent) % 3 == 0 + codon_seq_len = len(nt_parent) // 3 + + assert_pcp_valid(nt_parent, nt_child, aa_mask=self.masks[i][:codon_seq_len]) + + self.codon_parents_idxss[i, :codon_seq_len] = codon_idx_tensor_of_str_ambig( + nt_parent + ) + self.codon_children_idxss[i, :codon_seq_len] = ( + codon_idx_tensor_of_str_ambig(nt_child) + ) + self.aa_parents_idxss[i, :codon_seq_len] = aa_idx_tensor_of_str_ambig( + aa_parent + ) + self.aa_children_idxss[i, :codon_seq_len] = aa_idx_tensor_of_str_ambig( + aa_child + ) + self.aa_subs_indicators[i, :codon_seq_len] = aa_subs_indicator_tensor_of( + aa_parent, aa_child + ) + + assert torch.all(self.masks.sum(dim=1) > 0) + assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX + assert torch.max(self.aa_children_idxss) <= MAX_AA_TOKEN_IDX + assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX + + self._branch_lengths = branch_lengths + self.update_neutral_probs() + + def update_neutral_probs(self): + """Update the neutral mutation probabilities for the dataset. + + This is a somewhat vague name, but that's because it includes all of the various + types of neutral mutation probabilities that we might want to compute. + + In this case it's the neutral codon probabilities. 
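+
+        Concretely, for each parent sequence we compute, for every codon position, the
+        probability under the neutral (no-selection) model of ending up in each of the
+        64 codons after evolving along the branch.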
+ """ + neutral_codon_probs_l = [] + + for nt_parent, mask, nt_rates, nt_csps, branch_length in zip( + self.nt_parents, + self.masks, + self.nt_ratess, + self.nt_cspss, + self._branch_lengths, + ): + mask = mask.to("cpu") + nt_rates = nt_rates.to("cpu") + nt_csps = nt_csps.to("cpu") + if self.multihit_model is not None: + multihit_model = copy.deepcopy(self.multihit_model).to("cpu") + else: + multihit_model = None + # Note we are replacing all Ns with As, which means that we need to be careful + # with masking out these positions later. We do this below. + parent_idxs = nt_idx_tensor_of_str(nt_parent.replace("N", "A")) + parent_len = len(nt_parent) + + mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len]) + nt_csps = nt_csps[:parent_len, :] + nt_mask = mask.repeat_interleave(3)[: len(nt_parent)] + molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask]) + + neutral_codon_probs = molevol.neutral_codon_probs( + parent_idxs.reshape(-1, 3), + mut_probs.reshape(-1, 3), + nt_csps.reshape(-1, 3, 4), + multihit_model=multihit_model, + ) + + if not torch.isfinite(neutral_codon_probs).all(): + print(f"Found a non-finite neutral_codon_prob") + print(f"nt_parent: {nt_parent}") + print(f"mask: {mask}") + print(f"nt_rates: {nt_rates}") + print(f"nt_csps: {nt_csps}") + print(f"branch_length: {branch_length}") + raise ValueError( + f"neutral_codon_probs is not finite: {neutral_codon_probs}" + ) + + # Ensure that all values are positive before taking the log later + neutral_codon_probs = clamp_probability(neutral_codon_probs) + + pad_len = self.max_codon_seq_len - neutral_codon_probs.shape[0] + if pad_len > 0: + neutral_codon_probs = F.pad( + neutral_codon_probs, (0, 0, 0, pad_len), value=1e-8 + ) + # Here we zero out masked positions. + neutral_codon_probs *= mask[:, None] + + neutral_codon_probs_l.append(neutral_codon_probs) + + # Note that our masked out positions will have a nan log probability, + # which will require us to handle them correctly downstream. + self.log_neutral_codon_probss = torch.log(torch.stack(neutral_codon_probs_l)) + + def __getitem__(self, idx): + return { + "codon_parents_idxs": self.codon_parents_idxss[idx], + "codon_children_idxs": self.codon_children_idxss[idx], + "aa_parents_idxs": self.aa_parents_idxss[idx], + "aa_children_idxs": self.aa_children_idxss[idx], + "subs_indicator": self.aa_subs_indicators[idx], + "mask": self.masks[idx], + "log_neutral_codon_probs": self.log_neutral_codon_probss[idx], + "nt_rates": self.nt_ratess[idx], + "nt_csps": self.nt_cspss[idx], + } + + def to(self, device): + self.codon_parents_idxss = self.codon_parents_idxss.to(device) + self.codon_children_idxss = self.codon_children_idxss.to(device) + self.aa_parents_idxss = self.aa_parents_idxss.to(device) + self.aa_children_idxss = self.aa_children_idxss.to(device) + self.aa_subs_indicators = self.aa_subs_indicators.to(device) + self.masks = self.masks.to(device) + self.log_neutral_codon_probss = self.log_neutral_codon_probss.to(device) + self.nt_ratess = self.nt_ratess.to(device) + self.nt_cspss = self.nt_cspss.to(device) + if self.multihit_model is not None: + self.multihit_model = self.multihit_model.to(device) + + +class DCSMBurrito(DXSMBurrito): + + model_type = "dcsm" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.xent_loss = torch.nn.CrossEntropyLoss() + + def prediction_pair_of_batch(self, batch): + """Get log neutral codon substitution probabilities and log selection factors + for a batch of data. 
+
+        We don't mask on the output, which will thus contain junk in all of the masked
+        sites.
+        """
+        aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
+        mask = batch["mask"].to(self.device)
+        log_neutral_codon_probs = batch["log_neutral_codon_probs"].to(self.device)
+        if not torch.isfinite(log_neutral_codon_probs[mask]).all():
+            raise ValueError(
+                f"log_neutral_codon_probs has non-finite values at relevant positions: {log_neutral_codon_probs[mask]}"
+            )
+        # We need the model to see special tokens here. For every other purpose
+        # they are masked out.
+        keep_token_mask = mask | token_mask_of_aa_idxs(aa_parents_idxs)
+        log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
+        return log_neutral_codon_probs, log_selection_factors
+
+    def predictions_of_pair(self, log_neutral_codon_probs, log_selection_factors):
+        """Obtain the predictions for a pair consisting of the log neutral codon
+        substitution probabilities and the log selection factors."""
+        # This indicator lifts things up from aa land to codon land.
+        indicator = CODON_AA_INDICATOR_MATRIX.to(self.device).T
+        # Multiply (in the matrix sense) log_selection_factors by the indicator.
+        predictions = log_neutral_codon_probs + log_selection_factors @ indicator
+        assert torch.isnan(predictions).sum() == 0
+        return predictions
+
+    def predictions_of_batch(self, batch):
+        """Make predictions for a batch of data.
+
+        In this case they are log probabilities of codons, normalized per site by
+        setting the parent codon probability to 1 - sum(children).
+
+        After all this, we clip the probabilities below to avoid log(0) issues.
+        So, in cases where the sum of the children is > 1, we don't give a
+        normalized probability distribution, but that's OK for the loss calculation
+        because the cross-entropy loss applies a softmax.
+        TODO think about this more.
+
+        Note that we make all ambiguous codons nan in the output, ensuring that
+        they must get properly masked downstream.
+        """
+        log_neutral_codon_probs, log_selection_factors = self.prediction_pair_of_batch(
+            batch
+        )
+        log_preds = self.predictions_of_pair(
+            log_neutral_codon_probs, log_selection_factors
+        )
+
+        parent_indices = batch["codon_parents_idxs"].to(self.device)  # Shape: [B, L]
+
+        # Create mask for valid (non-ambiguous) codons.
+        valid_mask = parent_indices != AMBIGUOUS_CODON_IDX  # Shape: [B, L]
+
+        # Convert to linear space.
+        preds = torch.exp(log_preds)
+
+        # The below uses the out-of-place version of
+        # https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html
+        # Note that everywhere, `dim` is -1, which in this case corresponds to
+        # this setting from the docs:
+        # self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
+        # Because we have just unsqueezed the last dimension, this is the
+        # same as
+        # self[i][j][index[i][j]] = src[i][j]
+        # if we hadn't unsqueezed. However, we have to unsqueeze because scatter
+        # requires the `index` tensor to have the same number of dimensions as `self`.
+
+        # Ambiguous positions carry the out-of-range index AMBIGUOUS_CODON_IDX, so we
+        # point them at codon 0 for the scatter; their outputs are overwritten with
+        # nan below anyway.
+        scatter_indices = parent_indices.masked_fill(~valid_mask, 0).unsqueeze(
+            -1
+        )  # Shape: [B, L, 1]
+
+        # Zero out the parent codon entries.
+        preds = preds.scatter(-1, scatter_indices, 0.0)
+
+        # Sum the non-parent (child) probabilities at every position.
+        non_parent_sum = preds.sum(dim=-1, keepdim=True)
+
+        # Set the parent codon probability to 1 - sum(children).
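+        # For example, if the non-parent codons at a position carry a total
+        # probability of 0.3, the parent codon is assigned probability 0.7 here.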
+        preds = preds.scatter(-1, scatter_indices, 1.0 - non_parent_sum)
+
+        # We have to clamp the predictions to avoid log(0) issues.
+        preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps)
+
+        log_preds = torch.log(preds)
+
+        # Set ambiguous codons to nan to make sure that we handle them correctly downstream.
+        log_preds[~valid_mask, :] = float("nan")
+
+        return log_preds
+
+    def loss_of_batch(self, batch):
+        codon_children_idxs = batch["codon_children_idxs"].to(self.device)
+        mask = batch["mask"].to(self.device)
+
+        predictions = self.predictions_of_batch(batch)[mask]
+        assert torch.isnan(predictions).sum() == 0
+        codon_children_idxs = codon_children_idxs[mask]
+
+        return self.xent_loss(predictions, codon_children_idxs)
+
+    # TODO copied from dasm.py
+    def build_selection_matrix_from_parent(self, parent: str):
+        """Build a selection matrix from a parent amino acid sequence.
+
+        Values at ambiguous sites are meaningless.
+        """
+        # This is simpler than the equivalent in dnsm.py because we get the selection
+        # matrix directly. Note that selection_factors_of_aa_str does the exponentiation,
+        # so this indeed gives us the selection factors, not the log selection factors.
+        parent = translate_sequence(parent)
+        per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)
+
+        parent = parent.replace("X", "A")
+        parent_idxs = aa_idx_array_of_str(parent)
+        per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
+
+        return per_aa_selection_factors
diff --git a/netam/dnsm.py b/netam/dnsm.py
index 6efeda85..ba159cf6 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -54,6 +54,7 @@ def update_neutral_probs(self):
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
 
+            # TODO singular/plural mismatch
             neutral_aa_mut_prob = molevol.neutral_aa_mut_probs(
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
diff --git a/netam/framework.py b/netam/framework.py
index b4db2d7e..548ad59c 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -612,6 +612,9 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None
                     self.optimizer.zero_grad()
                     scalar_loss.backward()
 
+                    if torch.isnan(scalar_loss):
+                        raise ValueError(f"NaN in loss: {scalar_loss.item()}")
+
                     nan_in_gradients = False
                     for name, param in self.model.named_parameters():
                         if torch.isnan(param).any():
diff --git a/netam/molevol.py b/netam/molevol.py
index 2aef1c10..8c03909b 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -339,14 +339,14 @@ def build_codon_mutsel(
     return codon_mutsel, sums_too_big
 
 
-def neutral_aa_probs(
+def neutral_codon_probs(
     parent_codon_idxs: Tensor,
     codon_mut_probs: Tensor,
     codon_csps: Tensor,
     multihit_model=None,
 ) -> Tensor:
-    """For every site, what is the probability that the amino acid will mutate to every
-    amino acid?
+    """For every site, what is the probability that the site will mutate to every
+    alternate codon?
 
     Args:
         parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
@@ -354,8 +354,8 @@ def neutral_aa_probs(
         codon_csps (torch.Tensor): The substitution probabilities for each site in each codon. Shape: (codon_count, 3, 4)
 
     Returns:
-        torch.Tensor: The probability that each site will change to each amino acid.
-        Shape: (codon_count, 20)
+        torch.Tensor: The probability that each site will change to each codon.
+        Shape: (codon_count, 64)
     """
 
     mut_matrices = build_mutation_matrices(
@@ -366,8 +366,36 @@ def neutral_aa_probs(
     if multihit_model is not None:
         codon_probs = multihit_model(parent_codon_idxs, codon_probs)
 
+    return codon_probs.view(-1, 64)
+
+
+def neutral_aa_probs(
+    parent_codon_idxs: Tensor,
+    codon_mut_probs: Tensor,
+    codon_csps: Tensor,
+    multihit_model=None,
+) -> Tensor:
+    """For every site, what is the probability that the site will mutate to every
+    alternate amino acid?
+
+    Args:
+        parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
+        codon_mut_probs (torch.Tensor): The mutation probabilities for each site in each codon. Shape: (codon_count, 3)
+        codon_csps (torch.Tensor): The substitution probabilities for each site in each codon. Shape: (codon_count, 3, 4)
+
+    Returns:
+        torch.Tensor: The probability that each site will change to each amino acid.
+        Shape: (codon_count, 20)
+    """
+    codon_probs = neutral_codon_probs(
+        parent_codon_idxs,
+        codon_mut_probs,
+        codon_csps,
+        multihit_model=multihit_model,
+    )
+
     # Get the probability of mutating to each amino acid.
-    aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX
+    aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX
 
     return aa_probs
diff --git a/netam/sequences.py b/netam/sequences.py
index 3ddf6d88..726e1b92 100644
--- a/netam/sequences.py
+++ b/netam/sequences.py
@@ -72,6 +72,30 @@ def aa_idx_tensor_of_str(aa_str):
         raise
 
 
+# TODO isolating all this stuff here
+
+AMBIGUOUS_CODON_IDX = len(CODONS)
+
+
+def idx_of_codon_allowing_ambiguous(codon):
+    # If the codon contains an N, it is ambiguous.
+    if "N" in codon:
+        return AMBIGUOUS_CODON_IDX
+    else:
+        return CODONS.index(codon)
+
+
+def codon_idx_tensor_of_str_ambig(nt_str):
+    """Return the indices of the codons in a string."""
+    assert len(nt_str) % 3 == 0
+    return torch.tensor(
+        [idx_of_codon_allowing_ambiguous(codon) for codon in iter_codons(nt_str)]
+    )
+
+
+# TODO end isolating new stuff
+
+
 def aa_onehot_tensor_of_str(aa_str):
     aa_onehot = torch.zeros((len(aa_str), 20))
     aa_indices_parent = aa_idx_array_of_str(aa_str)
diff --git a/tests/test_dcsm.py b/tests/test_dcsm.py
new file mode 100644
index 00000000..d78fa5b2
--- /dev/null
+++ b/tests/test_dcsm.py
@@ -0,0 +1,53 @@
+import os
+
+import torch
+import pytest
+
+from netam.common import BIG, force_spawn
+from netam.framework import (
+    crepe_exists,
+    load_crepe,
+)
+from netam.sequences import MAX_AA_TOKEN_IDX
+from netam.models import TransformerBinarySelectionModelWiggleAct
+from netam.dcsm import (
+    DCSMBurrito,
+    DCSMDataset,
+)
+
+
+@pytest.fixture(scope="module")
+def dcsm_burrito(pcp_df):
+    """Fixture that returns the DCSM Burrito object."""
+    force_spawn()
+    pcp_df["in_train"] = True
+    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
+    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(pcp_df)
+
+    model = TransformerBinarySelectionModelWiggleAct(
+        nhead=2,
+        d_model_per_head=4,
+        dim_feedforward=256,
+        layer_count=2,
+        output_dim=MAX_AA_TOKEN_IDX + 1,
+    )
+
+    burrito = DCSMBurrito(
+        train_dataset,
+        val_dataset,
+        model,
+        batch_size=32,
+        learning_rate=0.001,
+        min_learning_rate=0.0001,
+    )
+    burrito.joint_train(
+        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
+    )
+    return burrito
+
+
+def test_parallel_branch_length_optimization(dcsm_burrito):
+    dataset = dcsm_burrito.val_dataset
+    parallel_branch_lengths = dcsm_burrito.find_optimal_branch_lengths(dataset)
+    branch_lengths = dcsm_burrito.serial_find_optimal_branch_lengths(dataset)
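+    # The parallel and serial code paths should agree on the optimized branch lengths.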
+    assert torch.allclose(branch_lengths, parallel_branch_lengths)
diff --git a/tests/test_sequences.py b/tests/test_sequences.py
index f1a0c8d3..1630948c 100644
--- a/tests/test_sequences.py
+++ b/tests/test_sequences.py
@@ -11,7 +11,9 @@
     TOKEN_STR_SORTED,
     CODONS,
     CODON_AA_INDICATOR_MATRIX,
+    AMBIGUOUS_CODON_IDX,
     aa_onehot_tensor_of_str,
+    codon_idx_tensor_of_str_ambig,
     nt_idx_array_of_str,
     nt_subs_indicator_tensor_of,
     translate_sequences,
@@ -51,6 +53,13 @@ def test_nucleotide_indices_of_codon():
     assert nt_idx_array_of_str("GCG").tolist() == [2, 1, 2]
 
 
+def test_codon_idx_tensor_of_str():
+    nt_str = "AAAAACTTGTTTNTT"
+    expected_output = torch.tensor([0, 1, 62, 63, AMBIGUOUS_CODON_IDX])
+    output = codon_idx_tensor_of_str_ambig(nt_str)
+    assert torch.equal(output, expected_output)
+
+
 def test_aa_onehot_tensor_of_str():
     aa_str = "QY"