diff --git a/netam/dasm.py b/netam/dasm.py
index 9f844a88..46c010ad 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -96,6 +96,8 @@ def to(self, device):
         self.log_neutral_aa_probs = self.log_neutral_aa_probs.to(device)
         self.all_rates = self.all_rates.to(device)
         self.all_subs_probs = self.all_subs_probs.to(device)
+        if self.multihit_model is not None:
+            self.multihit_model = self.multihit_model.to(device)
 
 
 def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
@@ -203,8 +205,8 @@ def build_selection_matrix_from_parent(self, parent: str):
         # matrix directly. Note that selection_factors_of_aa_str does the exponentiation
         # so this indeed gives us the selection factors, not the log selection factors.
         parent = translate_sequence(parent)
-        selection_factors = self.model.selection_factors_of_aa_str(parent)
+        per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)
         parent_idxs = sequences.aa_idx_array_of_str(parent)
-        selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
+        per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
 
-        return selection_factors
+        return per_aa_selection_factors
diff --git a/netam/dnsm.py b/netam/dnsm.py
index 76b263bf..f6a7a84e 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -235,6 +235,8 @@ def update_neutral_probs(self):
         the DNSM (in which case it's neutral probabilities of any nonsynonymous
         mutation) and the DASM (in which case it's the neutral probabilities of
         mutation to the various amino acids).
+
+        This is the DNSM case; the DASM will override this method.
         """
 
         neutral_aa_mut_prob_l = []
@@ -333,8 +335,8 @@ def load_branch_lengths(self, in_csv_prefix):
         self.val_dataset.load_branch_lengths(in_csv_prefix + ".val_branch_lengths.csv")
 
     def prediction_pair_of_batch(self, batch):
-        """Get log neutral mutation probabilities and log selection factors for a batch
-        of data."""
+        """Get log neutral amino acid substitution probabilities and log selection
+        factors for a batch of data."""
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
         mask = batch["mask"].to(self.device)
         log_neutral_aa_mut_probs = batch["log_neutral_aa_mut_probs"].to(self.device)
@@ -346,7 +348,8 @@ def prediction_pair_of_batch(self, batch):
         return log_neutral_aa_mut_probs, log_selection_factors
 
     def predictions_of_pair(self, log_neutral_aa_mut_probs, log_selection_factors):
-        # Take the product of the neutral mutation probabilities and the selection factors.
+        """Obtain the predictions for a pair consisting of the log neutral amino acid
+        mutation probabilities and the log selection factors."""
         predictions = torch.exp(log_neutral_aa_mut_probs + log_selection_factors)
         assert torch.isfinite(predictions).all()
         predictions = clamp_probability(predictions)
@@ -443,7 +446,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
         # # The following can be used when one wants a better traceback.
-        # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
+        # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
         # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
         our_optimize_branch_length = partial(
             worker_optimize_branch_length,
diff --git a/netam/framework.py b/netam/framework.py
index 37fa6785..55694be1 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -461,7 +461,7 @@ def multi_train(self, epochs, max_tries=3):
         and resume training.
         """
         for i in range(max_tries):
-            train_history = self.train(epochs)
+            train_history = self.simple_train(epochs)
             if self.optimizer.param_groups[0]["lr"] < self.min_learning_rate:
                 return train_history
             else:
@@ -571,7 +571,7 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None
             self.write_loss("Validation loss", average_loss, self.global_epoch)
         return loss_reduction(average_loss)
 
-    def train(self, epochs, out_prefix=None):
+    def simple_train(self, epochs, out_prefix=None):
         """Train the model for the given number of epochs.
 
         If out_prefix is provided, then a crepe will be saved to that location.
@@ -760,7 +760,7 @@ def joint_train(
                     weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate)
                 )
                 self.reset_optimization(new_lr)
-            loss_history_l.append(self.train(epochs, out_prefix=out_prefix))
+            loss_history_l.append(self.simple_train(epochs, out_prefix=out_prefix))
             # We standardize and optimize the branch lengths after each cycle, even the last one.
             optimize_branch_lengths()
             self.mark_branch_lengths_optimized(cycle + 1)
diff --git a/netam/molevol.py b/netam/molevol.py
index a82a085d..63ed7871 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -186,8 +186,10 @@ def aaprob_of_mut_and_sub(
     parent_codon_idxs: Tensor, mut_probs: Tensor, sub_probs: Tensor
 ) -> Tensor:
     """For a sequence of parent codons and given nucleotide mutability and substitution
-    probabilities, compute the amino acid substitution probabilities for each codon
-    along the sequence.
+    probabilities, compute the probability of a substitution to each amino acid for each
+    codon along the sequence.
+
+    Stop codons don't appear as part of this calculation.
 
     Args:
         parent_codon_idxs (torch.Tensor): A 2D tensor where each row contains indices representing
@@ -288,6 +290,8 @@ def build_codon_mutsel(
     """Build a sequence of codon mutation-selection matrices for codons along a
     sequence.
 
+    These will assign zero for the probability of mutating to a stop codon.
+
     Args:
         parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
         codon_mut_probs (torch.Tensor): The mutation probabilities for each site in each codon. Shape: (codon_count, 3)
@@ -310,6 +314,7 @@ def build_codon_mutsel(
     # So, for each site (s) and codon (c), sum over amino acids (a):
     # codon_sel_matrices[s, c] = sum_a(CODON_AA_INDICATOR_MATRIX[c, a] * aa_sel_matrices[s, a])
     # Resulting shape is (S, C) where S is the number of sites and C is the number of codons.
+    # Stop codons don't appear in this sum, so columns for stop codons will be zero.
     codon_sel_matrices = torch.einsum(
         "ca,sa->sc", CODON_AA_INDICATOR_MATRIX, aa_sel_matrices
     )
@@ -322,7 +327,8 @@ def build_codon_mutsel(
 
     # Now we need to recalculate the probability of staying in the same codon.
     # In our setup, this is the probability of nothing happening.
-    # To calculate this, we zero out the previously calculated probabilities...
+    # To calculate this, we zero out the probabilities of mutating to the parent
+    # codon...
     codon_count = parent_codon_idxs.shape[0]
     codon_mutsel[(torch.arange(codon_count), *parent_codon_idxs.T)] = 0.0
     # sum together their probabilities...
diff --git a/tests/test_netam.py b/tests/test_netam.py
index ffb525ae..78d49804 100644
--- a/tests/test_netam.py
+++ b/tests/test_netam.py
@@ -31,7 +31,7 @@ def tiny_model():
 
 
 @pytest.fixture
 def tiny_burrito(tiny_dataset, tiny_val_dataset, tiny_model):
     burrito = SHMBurrito(tiny_dataset, tiny_val_dataset, tiny_model)
-    burrito.train(epochs=5)
+    burrito.simple_train(epochs=5)
     return burrito
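
Note on the dasm.py rename: a quick illustration of what build_selection_matrix_from_parent computes after this change. This is a sketch with made-up shapes (5 sites, 20 amino acids, random values), not the repo's code; in netam the factors come from model.selection_factors_of_aa_str(parent). The key move is pinning the parent amino acid at each site to 1.0, so "staying put" carries no selection effect.

    import torch

    site_count, aa_count = 5, 20
    per_aa_selection_factors = torch.rand(site_count, aa_count)
    parent_idxs = torch.randint(aa_count, (site_count,))

    # Same indexing trick as in the diff: one entry per row, chosen by parent_idxs.
    per_aa_selection_factors[torch.arange(site_count), parent_idxs] = 1.0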
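Note on the dnsm.py debugging line: switching the commented-out helper from a hard-coded DNSMBurrito to self.__class__ means subclasses construct the right burrito type when the line is uncommented. A minimal sketch of the difference, using toy classes rather than netam's:

    class Base:
        def debug_copy(self):
            return self.__class__()  # resolves to the runtime class, not Base

    class Sub(Base):
        pass

    assert type(Sub().debug_copy()) is Sub  # a hard-coded Base() would return Base here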
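Note on the molevol.py stop-codon comments: they follow from stop codons having all-zero rows in CODON_AA_INDICATOR_MATRIX (they map to no amino acid), so the einsum leaves their columns at zero. A toy version of that contraction, assuming a 4-codon/3-amino-acid indicator in place of the real one:

    import torch

    codon_count, aa_count, site_count = 4, 3, 2
    indicator = torch.zeros(codon_count, aa_count)
    indicator[0, 0] = indicator[1, 1] = indicator[2, 2] = 1.0  # codon 3 acts as a "stop": all-zero row

    aa_sel_matrices = torch.rand(site_count, aa_count)
    # Same contraction as in build_codon_mutsel: sum over amino acids.
    codon_sel_matrices = torch.einsum("ca,sa->sc", indicator, aa_sel_matrices)

    assert torch.all(codon_sel_matrices[:, 3] == 0.0)  # the stop-codon column stays zero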