From d55fed09b882966e7fedd18a6fe32935eb662810 Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Mon, 23 Sep 2024 13:47:40 -0700
Subject: [PATCH 1/8] in-person drafting

---
 netam/molevol.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/netam/molevol.py b/netam/molevol.py
index ead81b5c..674b25a4 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -283,6 +283,7 @@ def build_codon_mutsel(
     codon_mut_probs: Tensor,
     codon_sub_probs: Tensor,
     aa_sel_matrices: Tensor,
+    multihit_model=None,
 ) -> Tensor:
     """Build a sequence of codon mutation-selection matrices for codons along a
     sequence.
@@ -301,6 +302,9 @@ def build_codon_mutsel(
     )
     codon_probs = codon_probs_of_mutation_matrices(mut_matrices)
 
+    if multihit_model is not None:
+        codon_probs = multihit_model(parent_codon_idxs, codon_probs)
+
     # Calculate the codon selection matrix for each sequence via Einstein
     # summation, in which we sum over the repeated indices.
     # So, for each site (s) and codon (c), sum over amino acids (a):
@@ -419,7 +423,7 @@ def neutral_aa_mut_probs(
     return mut_probs
 
 
-def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs):
+def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs, multihit_model=None):
     """Constructs the log_pcp_probability function specific to given rates and
     sub_probs.
@@ -442,6 +446,7 @@ def log_pcp_probability(log_branch_length: torch.Tensor):
             mut_probs.reshape(-1, 3),
             sub_probs.reshape(-1, 3, 4),
             sel_matrix,
+            multihit_model=multihit_model,
         )
 
         # This is a diagnostic generating data for netam issue #7.
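Patch 1 establishes the pattern the rest of this series threads through the codebase: an optional multihit_model argument which, when present, rewrites the neutral codon probabilities before selection is applied, and which otherwise leaves them untouched. Below is a minimal, self-contained sketch of that optional-hook pattern; toy_codon_probs, corrected_codon_probs, and the identity stand-in model are illustrative assumptions, not netam code.

    import torch

    def toy_codon_probs(site_count: int) -> torch.Tensor:
        # Fake per-site codon probabilities, shape (site_count, 4, 4, 4),
        # normalized so the 64 entries at each site sum to one.
        probs = torch.rand(site_count, 4, 4, 4)
        return probs / probs.sum(dim=(1, 2, 3), keepdim=True)

    def corrected_codon_probs(codon_probs, parent_codon_idxs, multihit_model=None):
        # The hook pattern from the patch: an optional model rewrites the
        # neutral codon probabilities; None means "leave them uncorrected".
        if multihit_model is not None:
            codon_probs = multihit_model(parent_codon_idxs, codon_probs)
        return codon_probs

    probs = toy_codon_probs(3)
    parents = torch.tensor([[0, 1, 2], [3, 0, 1], [2, 2, 2]])
    identity_model = lambda parent_idxs, p: p  # stand-in multihit model
    assert torch.equal(corrected_codon_probs(probs, parents),
                       corrected_codon_probs(probs, parents, identity_model))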
From cc65c0f1f7077fa1cc2003a5a2372ef31c7c258e Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Mon, 23 Sep 2024 16:48:22 -0700
Subject: [PATCH 2/8] initial modification of functions

Add multihit model in a few more places
make format
fix shape issue
fix rebase issue, but tests still fail
force_spawn in tests
reformat
switch to serial branch length optimization
multihit works with threading
WIP-- nothing works now
comment test+fake multihit working
re-enable multihit
perhaps working, fixed device mismatch and switched to nonlog correction application
I think this might have worked
verified working here
cleanup for PR
---
 netam/dnsm.py          | 36 ++++++++++++++++++++++++++++++------
 netam/hit_class.py     | 21 +++++++++++++--------
 netam/models.py        |  5 +++--
 netam/molevol.py       | 38 ++++++++++++++++++++++++++++++------
 netam/multihit.py      | 26 +++++++++++++-------------
 tests/test_dnsm.py     | 12 ++++++++++++
 tests/test_multihit.py |  4 ++--
 7 files changed, 105 insertions(+), 37 deletions(-)

diff --git a/netam/dnsm.py b/netam/dnsm.py
index 3cd42be3..0bde8205 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -43,11 +43,17 @@ def __init__(
         all_rates: torch.Tensor,
         all_subs_probs: torch.Tensor,
         branch_lengths: torch.Tensor,
+        multihit_model=None,
     ):
         self.nt_parents = nt_parents
         self.nt_children = nt_children
         self.all_rates = all_rates
         self.all_subs_probs = all_subs_probs
+        self.multihit_model = copy.deepcopy(multihit_model)
+        if multihit_model is not None:
+            # We want these parameters to act like fixed data. This is essential
+            # for multithreaded branch length optimization to work.
+            self.multihit_model.values.requires_grad_(False)
 
         assert len(self.nt_parents) == len(self.nt_children)
         pcp_count = len(self.nt_parents)
@@ -95,6 +101,7 @@ def of_seriess(
         all_rates_series: pd.Series,
         all_subs_probs_series: pd.Series,
         branch_length_multiplier=5.0,
+        multihit_model=None,
     ):
         """Alternative constructor that takes the raw data and calculates the initial
         branch lengths.
@@ -115,10 +122,11 @@ def of_seriess(
             stack_heterogeneous(all_rates_series.reset_index(drop=True)),
             stack_heterogeneous(all_subs_probs_series.reset_index(drop=True)),
             initial_branch_lengths,
+            multihit_model=multihit_model,
         )
 
     @classmethod
-    def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
+    def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
         """Alternative constructor that takes in a pcp_df and calculates the initial
         branch lengths."""
         assert "rates" in pcp_df.columns, "pcp_df must have a neutral rates column"
@@ -128,10 +136,11 @@ def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
             pcp_df["rates"],
             pcp_df["subs_probs"],
             branch_length_multiplier=branch_length_multiplier,
+            multihit_model=multihit_model,
         )
 
     @classmethod
-    def train_val_datasets_of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
+    def train_val_datasets_of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
         """Perform a train-val split based on the 'in_train' column.
 
         This is a class method so it works for subclasses.
@@ -140,14 +149,18 @@
         val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
 
         val_dataset = cls.of_pcp_df(
-            val_df, branch_length_multiplier=branch_length_multiplier
+            val_df,
+            branch_length_multiplier=branch_length_multiplier,
+            multihit_model=multihit_model,
         )
 
         if len(train_df) == 0:
             return None, val_dataset
         # else:
         train_dataset = cls.of_pcp_df(
-            train_df, branch_length_multiplier=branch_length_multiplier
+            train_df,
+            branch_length_multiplier=branch_length_multiplier,
+            multihit_model=multihit_model,
         )
 
         return train_dataset, val_dataset
@@ -160,6 +173,7 @@ def clone(self):
             self.all_rates.copy(),
             self.all_subs_probs.copy(),
             self._branch_lengths.copy(),
+            multihit_model=copy.deepcopy(self.multihit_model),
         )
         return new_dataset
 
@@ -176,6 +190,7 @@ def subset_via_indices(self, indices):
             self.all_rates[indices],
             self.all_subs_probs[indices],
             self._branch_lengths[indices],
+            multihit_model=copy.deepcopy(self.multihit_model),
        )
         return new_dataset
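The DNSMDataset constructor above deep-copies the multihit model and calls requires_grad_(False) on its values parameter so that the copy behaves like fixed data during multithreaded branch length optimization. A minimal sketch of that copy-and-freeze pattern, with a toy module standing in for netam's hit class model:

    import copy
    import torch
    from torch import nn

    class ToyHitClassModel(nn.Module):
        # Stand-in for a model with one learnable tensor of log hit-class factors.
        def __init__(self):
            super().__init__()
            self.values = nn.Parameter(torch.zeros(3))

    shared = ToyHitClassModel()
    # Each dataset gets its own copy, detached from the training loop:
    frozen = copy.deepcopy(shared)
    frozen.values.requires_grad_(False)

    # The original stays trainable; the dataset's copy acts like fixed data.
    assert shared.values.requires_grad and not frozen.values.requires_grad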
@@ -231,6 +246,10 @@ def update_neutral_probs(self):
             mask = mask.to("cpu")
             rates = rates.to("cpu")
             subs_probs = subs_probs.to("cpu")
+            if self.multihit_model is not None:
+                multihit_model = self.multihit_model.to("cpu")
+            else:
+                multihit_model = None
 
             # Note we are replacing all Ns with As, which means that we need to be careful
             # with masking out these positions later. We do this below.
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
@@ -245,6 +264,7 @@
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
                 normed_subs_probs.reshape(-1, 3, 4),
+                multihit_model=multihit_model,
             )
 
             if not torch.isfinite(neutral_aa_mut_prob).all():
@@ -295,6 +315,8 @@ def to(self, device):
         self.log_neutral_aa_mut_probs = self.log_neutral_aa_mut_probs.to(device)
         self.all_rates = self.all_rates.to(device)
         self.all_subs_probs = self.all_subs_probs.to(device)
+        if self.multihit_model is not None:
+            self.multihit_model = self.multihit_model.to(device)
 
 
 class DNSMBurrito(framework.Burrito):
@@ -366,13 +388,14 @@ def _find_optimal_branch_length(
         rates,
         subs_probs,
         starting_branch_length,
+        multihit_model,
         **optimization_kwargs,
     ):
         if parent == child:
             return 0.0
         sel_matrix = self.build_selection_matrix_from_parent(parent)
         log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
-            sel_matrix, parent, child, rates, subs_probs
+            sel_matrix, parent, child, rates, subs_probs, multihit_model
         )
         if type(starting_branch_length) == torch.Tensor:
             starting_branch_length = starting_branch_length.detach().item()
@@ -401,6 +424,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
                 rates[: len(parent)],
                 subs_probs[: len(parent), :],
                 starting_length,
+                dataset.multihit_model,
                 **optimization_kwargs,
             )
 
@@ -416,7 +440,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
-        # The following can be used when one wants a better traceback.
+        # # The following can be used when one wants a better traceback.
         # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
         # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
         our_optimize_branch_length = partial(
diff --git a/netam/hit_class.py b/netam/hit_class.py
index c42506d7..e9462187 100644
--- a/netam/hit_class.py
+++ b/netam/hit_class.py
@@ -41,9 +41,10 @@ def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor
 ]
 
+
 def apply_multihit_correction(
     parent_codon_idxs: torch.Tensor,
-    codon_logprobs: torch.Tensor,
+    codon_probs: torch.Tensor,
     hit_class_factors: torch.Tensor,
 ) -> torch.Tensor:
     """Multiply codon probabilities by their hit class factors, and renormalize.
 
@@ -53,23 +54,27 @@ def apply_multihit_correction(
     Args:
         parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the
             indices of the parent codon's nucleotides.
-        codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities
+        codon_probs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the probabilities
             of mutating to each possible target codon, for each of the N parent codons.
         hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
             factor for hit class 0 is assumed to be 1 (that is, 0 in log-space).
 
     Returns:
-        torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible
+        torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the probabilities of mutating to each possible
             target codon, for each of the N parent codons, after applying the hit class factors.
     """
     per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
-    corrections = torch.cat([torch.tensor([0.0]), hit_class_factors])
+    corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]).exp()
     reshaped_corrections = corrections[per_parent_hit_class]
-    unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections
-    normalizations = torch.logsumexp(
-        unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True
+    unnormalized_corrected_probs = codon_probs * reshaped_corrections
+    normalizations = torch.sum(
+        unnormalized_corrected_probs, dim=[1, 2, 3], keepdim=True
     )
-    return unnormalized_corrected_logprobs - normalizations
+    result = unnormalized_corrected_probs / normalizations
+    if torch.any(torch.isnan(result)):
+        print("NAN found in multihit correction application")
+        assert False
+    return result
 
 
 def hit_class_probs_tensor(
diff --git a/netam/models.py b/netam/models.py
index 17008795..d34b6ea6 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -704,14 +704,15 @@ def __init__(self):
         return {}
 
+    # TODO changed to nonlog version for testing, need to update all calls to reflect this
     def forward(
-        self, parent_codon_idxs: torch.Tensor, uncorrected_log_codon_probs: torch.Tensor
+        self, parent_codon_idxs: torch.Tensor, uncorrected_codon_probs: torch.Tensor
     ):
         """Forward function takes a tensor of target codon distributions, for each
         observed parent codon, and adjusts the distributions according to the hit class
         corrections."""
         return apply_multihit_correction(
-            parent_codon_idxs, uncorrected_log_codon_probs, self.values
+            parent_codon_idxs, uncorrected_codon_probs, self.values
         )
 
     def reinitialize_weights(self):
diff --git a/netam/molevol.py b/netam/molevol.py
index 674b25a4..ca1ae6e7 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -10,6 +10,7 @@
 """
 
 import numpy as np
+import warnings
 
 import torch
 from torch import Tensor, optim
@@ -17,6 +18,7 @@
 from netam.sequences import CODON_AA_INDICATOR_MATRIX
 import netam.sequences as sequences
 
+# torch.autograd.set_detect_anomaly(True)
 
 def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor:
@@ -304,6 +306,8 @@ def build_codon_mutsel(
 
     if multihit_model is not None:
         codon_probs = multihit_model(parent_codon_idxs, codon_probs)
+    else:
+        warnings.warn("No multihit model provided. Using uncorrected probabilities.")
 
     # Calculate the codon selection matrix for each sequence via Einstein
     # summation, in which we sum over the repeated indices.
@@ -343,6 +347,7 @@ def neutral_aa_probs(
     parent_codon_idxs: Tensor,
     codon_mut_probs: Tensor,
     codon_sub_probs: Tensor,
+    multihit_model=None,
 ) -> Tensor:
     """For every site, what is the probability that the amino acid will mutate to
     every amino acid?
@@ -360,10 +365,15 @@
     mut_matrices = build_mutation_matrices(
         parent_codon_idxs, codon_mut_probs, codon_sub_probs
     )
-    codon_probs = codon_probs_of_mutation_matrices(mut_matrices).view(-1, 64)
+    codon_probs = codon_probs_of_mutation_matrices(mut_matrices)
+
+    if multihit_model is not None:
+        codon_probs = multihit_model(parent_codon_idxs, codon_probs)
+    else:
+        warnings.warn("No multihit model provided. Using uncorrected probabilities.")
 
     # Get the probability of mutating to each amino acid.
-    aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX
+    aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX
 
     return aa_probs
@@ -400,6 +410,7 @@ def neutral_aa_mut_probs(
     parent_codon_idxs: Tensor,
     codon_mut_probs: Tensor,
     codon_sub_probs: Tensor,
+    multihit_model=None,
 ) -> Tensor:
     """For every site, what is the probability that the amino acid will have a
     substitution or mutate to a stop under neutral evolution?
@@ -418,12 +429,19 @@
         Shape: (codon_count,)
     """
-    aa_probs = neutral_aa_probs(parent_codon_idxs, codon_mut_probs, codon_sub_probs)
+    aa_probs = neutral_aa_probs(
+        parent_codon_idxs,
+        codon_mut_probs,
+        codon_sub_probs,
+        multihit_model=multihit_model,
+    )
     mut_probs = mut_probs_of_aa_probs(parent_codon_idxs, aa_probs)
     return mut_probs
 
 
-def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs, multihit_model=None):
+def mutsel_log_pcp_probability_of(
+    sel_matrix, parent, child, rates, sub_probs, multihit_model=None
+):
     """Constructs the log_pcp_probability function specific to given rates and
     sub_probs.
@@ -487,6 +505,7 @@
 
     step_idx = 0
 
+    nan_issue = False
     for step_idx in range(max_optimization_steps):
         # For some PCPs, the optimizer works very hard optimizing very tiny branch lengths.
         if log_branch_length < log_branch_length_lower_threshold:
@@ -501,7 +520,14 @@
         loss.backward()
         torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0)
         optimizer.step()
-        assert not torch.isnan(log_branch_length)
+        if torch.isnan(log_branch_length):
+            print("branch length optimization resulted in NAN, previous log branch length:", prev_log_branch_length)
+            if np.isclose(prev_log_branch_length.detach().numpy(), 0):
+                log_branch_length = prev_log_branch_length
+                nan_issue = True
+                break
+            else:
+                assert False
 
         change_in_log_branch_length = torch.abs(
             log_branch_length - prev_log_branch_length
@@ -511,7 +537,7 @@
 
         prev_log_branch_length = log_branch_length.clone()
 
-    if step_idx == max_optimization_steps - 1:
+    if step_idx == max_optimization_steps - 1 or nan_issue:
         print(
             f"Warning: optimization did not converge after {max_optimization_steps} steps; log branch length is {log_branch_length.detach().item()}"
         )
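optimize_branch_length works in log space so the branch length stays positive, takes Adam steps on the negative log probability with gradient clipping, and stops once successive log branch lengths change by less than a tolerance. Here is a stripped-down sketch of that loop with a toy objective; the learning rate and tolerance are illustrative, and the NaN guard uses the ValueError form that patch 7 below eventually adopts:

    import torch

    def optimize_log_branch_length(log_prob_fn, start=0.1, steps=100, tol=1e-4):
        # Optimize in log space so the branch length stays positive; this mirrors
        # the loop structure only (Adam, gradient clipping, convergence check).
        log_bl = torch.tensor(float(start)).log().requires_grad_()
        optimizer = torch.optim.Adam([log_bl], lr=0.1)
        prev = log_bl.clone()
        for _ in range(steps):
            optimizer.zero_grad()
            loss = -log_prob_fn(log_bl)
            loss.backward()
            torch.nn.utils.clip_grad_norm_([log_bl], max_norm=5.0)
            optimizer.step()
            if torch.isnan(log_bl):
                raise ValueError("branch length optimization resulted in NaN")
            if torch.abs(log_bl - prev) < tol:
                break
            prev = log_bl.clone()
        return log_bl.detach().exp().item()

    # Toy concave objective with optimum at branch length 0.25:
    toy = lambda log_bl: -((log_bl.exp() - 0.25) ** 2)
    print(optimize_log_branch_length(toy))  # ~0.25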
diff --git a/netam/multihit.py b/netam/multihit.py
index 192d0b03..9544f0fe 100644
--- a/netam/multihit.py
+++ b/netam/multihit.py
@@ -270,28 +270,28 @@ def child_codon_probs_from_per_parent_probs(per_parent_probs, child_codon_idxs):
     ]
 
 
-def child_codon_logprobs_corrected(
-    uncorrected_per_parent_logprobs, parent_codon_idxs, child_codon_idxs, model
+def child_codon_probs_corrected(
+    uncorrected_per_parent_probs, parent_codon_idxs, child_codon_idxs, model
 ):
     """Calculate the probability of each child codon given the parent codon
     probabilities, corrected by hit class factors.
 
     Args:
-        uncorrected_per_parent_logprobs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the log probabilities
+        uncorrected_per_parent_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities
            of each possible target codon, for each parent codon.
         parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each parent codon
         child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon
         model: A HitClassModel implementing the desired correction.
 
     Returns:
-        torch.Tensor: A (codon_count,) shaped tensor containing the corrected log probabilities of each child codon.
+        torch.Tensor: A (codon_count,) shaped tensor containing the corrected probabilities of each child codon.
     """
-    corrected_per_parent_logprobs = model(
-        parent_codon_idxs, uncorrected_per_parent_logprobs
+    corrected_per_parent_probs = model(
+        parent_codon_idxs, uncorrected_per_parent_probs
     )
     return child_codon_probs_from_per_parent_probs(
-        corrected_per_parent_logprobs, child_codon_idxs
+        corrected_per_parent_probs, child_codon_idxs
     )
@@ -340,12 +340,12 @@ def loss_of_batch(self, batch):
             codon_probs, codon_mask=codon_mask
         )
 
-        child_codon_logprobs = child_codon_logprobs_corrected(
-            flat_masked_codon_probs.log(),
+        child_codon_logprobs = child_codon_probs_corrected(
+            flat_masked_codon_probs,
             parent_codons_flat,
             child_codons_flat,
             self.model,
-        )
+        ).log()
         return -child_codon_logprobs.sum()
 
     def _find_optimal_branch_length(
@@ -374,9 +374,9 @@ def log_pcp_probability(log_branch_length):
             child_codon_idxs = reshape_for_codons(child_idxs)[codon_mask]
             parent_codon_idxs = reshape_for_codons(parent_idxs)[codon_mask]
 
-            return child_codon_logprobs_corrected(
-                codon_probs.log(), parent_codon_idxs, child_codon_idxs, self.model
-            ).sum()
+            return child_codon_probs_corrected(
+                codon_probs, parent_codon_idxs, child_codon_idxs, self.model
+            ).log().sum()
 
         return optimize_branch_length(
             log_pcp_probability,
diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py
index 0bf45464..82ae520d 100644
--- a/tests/test_dnsm.py
+++ b/tests/test_dnsm.py
@@ -1,3 +1,4 @@
+import multiprocessing as mp
 import os
 
 import torch
@@ -12,6 +13,15 @@
 from netam.dnsm import DNSMBurrito, DNSMDataset
 
 
+def force_spawn():
+    """Force the spawn start method for multiprocessing.
+
+    This is necessary to avoid conflicts with the internal OpenMP-based thread pool in
+    PyTorch.
+    """
+    mp.set_start_method("spawn", force=True)
+
+
 def test_aa_idx_tensor_of_str_ambig():
     input_seq = "ACX"
     expected_output = torch.tensor([0, 1, MAX_AMBIG_AA_IDX], dtype=torch.int)
@@ -22,6 +32,7 @@ def test_aa_idx_tensor_of_str_ambig():
 @pytest.fixture(scope="module")
 def dnsm_burrito(pcp_df):
     """Fixture that returns the DNSM Burrito object."""
+    force_spawn()
     pcp_df["in_train"] = True
     pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
     train_dataset, val_dataset = DNSMDataset.train_val_datasets_of_pcp_df(pcp_df)
@@ -43,6 +54,7 @@ def dnsm_burrito(pcp_df):
 
 
 def test_parallel_branch_length_optimization(dnsm_burrito):
+    force_spawn()
     dataset = dnsm_burrito.val_dataset
     parallel_branch_lengths = dnsm_burrito.find_optimal_branch_lengths(dataset)
     branch_lengths = dnsm_burrito.serial_find_optimal_branch_lengths(dataset)
diff --git a/tests/test_multihit.py b/tests/test_multihit.py
index 58d3ad35..f9f887dd 100644
--- a/tests/test_multihit.py
+++ b/tests/test_multihit.py
@@ -105,8 +105,8 @@ def test_multihit_correction():
     # We'll verify that aggregating by hit class then adjusting is the same as adjusting then aggregating by hit class.
     codon_idxs = reshape_for_codons(ex_parent_codon_idxs)
     adjusted_codon_probs = hit_class.apply_multihit_correction(
-        codon_idxs, ex_codon_probs.log(), hit_class_factors
-    ).exp()
+        codon_idxs, ex_codon_probs, hit_class_factors
+    )
     aggregate_last = hit_class.hit_class_probs_tensor(codon_idxs, adjusted_codon_probs)
 
     uncorrected_hc_log_probs = hit_class.hit_class_probs_tensor(
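child_codon_probs_from_per_parent_probs reduces the (codon_count, 4, 4, 4) tensor to one probability per codon site by selecting each child codon's entry. Its body is not shown in this diff, so the version below is an assumption reconstructed from the name and shapes using the standard advanced-indexing idiom:

    import torch

    def child_codon_probs_from_per_parent_probs(per_parent_probs, child_codon_idxs):
        # Pick, for row i, the entry at the child codon's three nucleotide indices.
        return per_parent_probs[
            torch.arange(child_codon_idxs.shape[0]),
            child_codon_idxs[:, 0],
            child_codon_idxs[:, 1],
            child_codon_idxs[:, 2],
        ]

    probs = torch.rand(2, 4, 4, 4)
    children = torch.tensor([[0, 1, 2], [3, 3, 3]])
    picked = child_codon_probs_from_per_parent_probs(probs, children)
    assert torch.equal(picked, torch.stack([probs[0, 0, 1, 2], probs[1, 3, 3, 3]]))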
From f1684fa59d18f6629a9390c3a079e2b924d2ef7b Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Tue, 22 Oct 2024 17:05:13 -0700
Subject: [PATCH 3/8] format and fix dasm test

---
 netam/dnsm.py      |  4 +++-
 netam/hit_class.py |  1 -
 netam/molevol.py   |  6 +++++-
 netam/multihit.py  | 14 ++++++++------
 tests/test_dasm.py | 11 +++++++++++
 5 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/netam/dnsm.py b/netam/dnsm.py
index 0bde8205..72a3db78 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -140,7 +140,9 @@ def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
         )
 
     @classmethod
-    def train_val_datasets_of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
+    def train_val_datasets_of_pcp_df(
+        cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None
+    ):
         """Perform a train-val split based on the 'in_train' column.
 
         This is a class method so it works for subclasses.
diff --git a/netam/hit_class.py b/netam/hit_class.py
index e9462187..9f048697 100644
--- a/netam/hit_class.py
+++ b/netam/hit_class.py
@@ -41,7 +41,6 @@ def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor
 ]
 
 
-
 def apply_multihit_correction(
     parent_codon_idxs: torch.Tensor,
     codon_probs: torch.Tensor,
diff --git a/netam/molevol.py b/netam/molevol.py
index ca1ae6e7..940795f3 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -18,6 +18,7 @@
 from netam.sequences import CODON_AA_INDICATOR_MATRIX
 import netam.sequences as sequences
 
+
 # torch.autograd.set_detect_anomaly(True)
 
 
@@ -521,7 +522,10 @@ def optimize_branch_length(
         torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0)
         optimizer.step()
         if torch.isnan(log_branch_length):
-            print("branch length optimization resulted in NAN, previous log branch length:", prev_log_branch_length)
+            print(
+                "branch length optimization resulted in NAN, previous log branch length:",
+                prev_log_branch_length,
+            )
             if np.isclose(prev_log_branch_length.detach().numpy(), 0):
                 log_branch_length = prev_log_branch_length
                 nan_issue = True
diff --git a/netam/multihit.py b/netam/multihit.py
index 9544f0fe..a50d5dce 100644
--- a/netam/multihit.py
+++ b/netam/multihit.py
@@ -287,9 +287,7 @@ def child_codon_probs_corrected(
         torch.Tensor: A (codon_count,) shaped tensor containing the corrected probabilities of each child codon.
     """
-    corrected_per_parent_probs = model(
-        parent_codon_idxs, uncorrected_per_parent_probs
-    )
+    corrected_per_parent_probs = model(parent_codon_idxs, uncorrected_per_parent_probs)
     return child_codon_probs_from_per_parent_probs(
         corrected_per_parent_probs, child_codon_idxs
     )
@@ -374,9 +372,13 @@ def log_pcp_probability(log_branch_length):
             child_codon_idxs = reshape_for_codons(child_idxs)[codon_mask]
             parent_codon_idxs = reshape_for_codons(parent_idxs)[codon_mask]
 
-            return child_codon_probs_corrected(
-                codon_probs, parent_codon_idxs, child_codon_idxs, self.model
-            ).log().sum()
+            return (
+                child_codon_probs_corrected(
+                    codon_probs, parent_codon_idxs, child_codon_idxs, self.model
+                )
+                .log()
+                .sum()
+            )
 
         return optimize_branch_length(
             log_pcp_probability,
diff --git a/tests/test_dasm.py b/tests/test_dasm.py
index bdc827b8..01d7f79f 100644
--- a/tests/test_dasm.py
+++ b/tests/test_dasm.py
@@ -14,10 +14,21 @@
     DASMDataset,
     zap_predictions_along_diagonal,
 )
+import multiprocessing as mp
+
+
+def force_spawn():
+    """Force the spawn start method for multiprocessing.
+
+    This is necessary to avoid conflicts with the internal OpenMP-based thread pool in
+    PyTorch.
+    """
+    mp.set_start_method("spawn", force=True)
 
 
 @pytest.fixture(scope="module")
 def dasm_burrito(pcp_df):
+    force_spawn()
     """Fixture that returns the DNSM Burrito object."""
     pcp_df["in_train"] = True
     pcp_df.loc[pcp_df.index[-15:], "in_train"] = False

From 3c46cd334fec6acb254e750d540933cdf92db16e Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Wed, 23 Oct 2024 09:06:51 -0700
Subject: [PATCH 4/8] respond to Erick's comments

---
 netam/common.py    | 10 ++++++++++
 netam/dnsm.py      |  4 ++--
 tests/test_dasm.py | 11 +----------
 tests/test_dnsm.py | 11 +----------
 4 files changed, 14 insertions(+), 22 deletions(-)

diff --git a/netam/common.py b/netam/common.py
index 8f2f0c88..0317b3c6 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -8,6 +8,7 @@
 import torch
 import torch.optim as optim
 from torch import nn, Tensor
+import multiprocessing as mp
 
 BIG = 1e9
 SMALL_PROB = 1e-6
@@ -31,6 +32,15 @@
 )
 
 
+def force_spawn():
+    """Force the spawn start method for multiprocessing.
+
+    This is necessary to avoid conflicts with the internal OpenMP-based thread pool in
+    PyTorch.
+    """
+    mp.set_start_method("spawn", force=True)
+
+
 def generate_kmers(kmer_length):
     # Our strategy for kmers is to have a single representation for any kmer that isn't in ACGT.
     # This is the first one, which is simply "N", and so this placeholder value is 0.
diff --git a/netam/dnsm.py b/netam/dnsm.py
index 72a3db78..68cbe2e4 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -175,7 +175,7 @@ def clone(self):
             self.all_rates.copy(),
             self.all_subs_probs.copy(),
             self._branch_lengths.copy(),
-            multihit_model=copy.deepcopy(self.multihit_model),
+            multihit_model=self.multihit_model,
         )
         return new_dataset
 
@@ -192,7 +192,7 @@ def subset_via_indices(self, indices):
             self.all_rates[indices],
             self.all_subs_probs[indices],
             self._branch_lengths[indices],
-            multihit_model=copy.deepcopy(self.multihit_model),
+            multihit_model=self.multihit_model,
         )
         return new_dataset
diff --git a/tests/test_dasm.py b/tests/test_dasm.py
index 01d7f79f..cde73120 100644
--- a/tests/test_dasm.py
+++ b/tests/test_dasm.py
@@ -3,7 +3,7 @@
 import torch
 import pytest
 
-from netam.common import BIG
+from netam.common import BIG, force_spawn
 from netam.framework import (
     crepe_exists,
     load_crepe,
@@ -17,15 +17,6 @@
 import multiprocessing as mp
 
 
-def force_spawn():
-    """Force the spawn start method for multiprocessing.
-
-    This is necessary to avoid conflicts with the internal OpenMP-based thread pool in
-    PyTorch.
-    """
-    mp.set_start_method("spawn", force=True)
-
-
 @pytest.fixture(scope="module")
 def dasm_burrito(pcp_df):
     force_spawn()
diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py
index 82ae520d..237f9794 100644
--- a/tests/test_dnsm.py
+++ b/tests/test_dnsm.py
@@ -8,20 +8,11 @@
     crepe_exists,
     load_crepe,
 )
-from netam.common import aa_idx_tensor_of_str_ambig, MAX_AMBIG_AA_IDX
+from netam.common import aa_idx_tensor_of_str_ambig, MAX_AMBIG_AA_IDX, force_spawn
 from netam.models import TransformerBinarySelectionModelWiggleAct
 from netam.dnsm import DNSMBurrito, DNSMDataset
 
 
-def force_spawn():
-    """Force the spawn start method for multiprocessing.
-
-    This is necessary to avoid conflicts with the internal OpenMP-based thread pool in
-    PyTorch.
-    """
-    mp.set_start_method("spawn", force=True)
-
-
 def test_aa_idx_tensor_of_str_ambig():
     input_seq = "ACX"
From 4982d4a6c37c424c623971120cb9ca82f57291fb Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Wed, 23 Oct 2024 09:09:35 -0700
Subject: [PATCH 5/8] remove multihit warning

---
 netam/molevol.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/netam/molevol.py b/netam/molevol.py
index 940795f3..56892a80 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -307,8 +307,6 @@ def build_codon_mutsel(
 
     if multihit_model is not None:
         codon_probs = multihit_model(parent_codon_idxs, codon_probs)
-    else:
-        warnings.warn("No multihit model provided. Using uncorrected probabilities.")
 
     # Calculate the codon selection matrix for each sequence via Einstein
     # summation, in which we sum over the repeated indices.
@@ -370,8 +368,6 @@ def neutral_aa_probs(
 
     if multihit_model is not None:
         codon_probs = multihit_model(parent_codon_idxs, codon_probs)
-    else:
-        warnings.warn("No multihit model provided. Using uncorrected probabilities.")
 
     # Get the probability of mutating to each amino acid.
     aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX
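Patch 4 above moved force_spawn into netam.common, and the branch length dispatch binds its fixed arguments with functools.partial before mapping over a spawned worker pool. A standalone sketch of that pattern; optimize_one and its toy arithmetic are placeholders, not the burrito's actual objective:

    import multiprocessing as mp
    from functools import partial

    def optimize_one(multihit_model, starting_length):
        # Toy stand-in for per-PCP branch length optimization; multihit_model
        # plays the role of the frozen model bound in by partial().
        scale = 1.0 if multihit_model is None else multihit_model
        return starting_length * scale

    if __name__ == "__main__":
        # What netam.common.force_spawn() does: fresh worker interpreters avoid
        # inheriting PyTorch's OpenMP thread state through fork().
        mp.set_start_method("spawn", force=True)
        optimize = partial(optimize_one, 2.0)  # bind fixed arguments once
        with mp.Pool(2) as pool:
            print(pool.map(optimize, [0.1, 0.2, 0.3]))  # ~[0.2, 0.4, 0.6]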
From b26de874b76f3257e53afffaaf221e60856a5f79 Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Wed, 23 Oct 2024 11:01:51 -0700
Subject: [PATCH 6/8] remove unused imports, reformat

---
 netam/dasm.py      | 4 ----
 netam/dnsm.py      | 2 +-
 netam/hit_class.py | 2 --
 netam/models.py    | 9 ++++-----
 netam/molevol.py   | 1 -
 tests/test_dasm.py | 1 -
 tests/test_dnsm.py | 1 -
 7 files changed, 5 insertions(+), 15 deletions(-)

diff --git a/netam/dasm.py b/netam/dasm.py
index bac48135..03d3ef35 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -7,11 +7,7 @@
 # optimization on our server.
 torch.set_num_threads(1)
 
-import numpy as np
-import pandas as pd
-
 from netam.common import (
-    clamp_log_probability,
     clamp_probability,
     BIG,
 )
diff --git a/netam/dnsm.py b/netam/dnsm.py
index 68cbe2e4..d3d17b86 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -399,7 +399,7 @@
         log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
             sel_matrix, parent, child, rates, subs_probs, multihit_model
         )
-        if type(starting_branch_length) == torch.Tensor:
+        if isinstance(starting_branch_length, torch.Tensor):
             starting_branch_length = starting_branch_length.detach().item()
diff --git a/netam/hit_class.py b/netam/hit_class.py
index 9f048697..6ba81236 100644
--- a/netam/hit_class.py
+++ b/netam/hit_class.py
@@ -1,8 +1,6 @@
 import torch
 import numpy as np
 
-from netam.common import BASES
-
 # Define the number of bases (e.g., 4 for DNA/RNA)
 _num_bases = 4
diff --git a/netam/models.py b/netam/models.py
index d34b6ea6..581716b4 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -2,10 +2,6 @@
 import math
 import warnings
 
-warnings.filterwarnings(
-    "ignore", category=UserWarning, module="torch.nn.modules.transformer"
-)
-
 import pandas as pd
 import torch
@@ -22,6 +18,10 @@
     aa_mask_tensor_of,
 )
 
+warnings.filterwarnings(
+    "ignore", category=UserWarning, module="torch.nn.modules.transformer"
+)
+
 
 class ModelBase(nn.Module):
     def reinitialize_weights(self):
@@ -704,7 +704,6 @@ def __init__(self):
     def hyperparameters(self):
         return {}
 
-    # TODO changed to nonlog version for testing, need to update all calls to reflect this
     def forward(
         self, parent_codon_idxs: torch.Tensor, uncorrected_codon_probs: torch.Tensor
     ):
diff --git a/netam/molevol.py b/netam/molevol.py
index 56892a80..5649f753 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -10,7 +10,6 @@
 """
 
 import numpy as np
-import warnings
 
 import torch
 from torch import Tensor, optim
diff --git a/tests/test_dasm.py b/tests/test_dasm.py
index cde73120..6bae92ee 100644
--- a/tests/test_dasm.py
+++ b/tests/test_dasm.py
@@ -14,7 +14,6 @@
     DASMDataset,
     zap_predictions_along_diagonal,
 )
-import multiprocessing as mp
diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py
index 237f9794..e0bca099 100644
--- a/tests/test_dnsm.py
+++ b/tests/test_dnsm.py
@@ -1,4 +1,3 @@
-import multiprocessing as mp
 import os
 
 import torch
From da2079723ca5e9c1356e93ef7d1ec49e8f91e3e7 Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Wed, 23 Oct 2024 11:18:55 -0700
Subject: [PATCH 7/8] respond to Erick's new comments

---
 netam/hit_class.py |  6 +++---
 netam/molevol.py   | 14 ++------------
 2 files changed, 5 insertions(+), 15 deletions(-)

diff --git a/netam/hit_class.py b/netam/hit_class.py
index 6ba81236..6a661632 100644
--- a/netam/hit_class.py
+++ b/netam/hit_class.py
@@ -42,7 +42,7 @@
 def apply_multihit_correction(
     parent_codon_idxs: torch.Tensor,
     codon_probs: torch.Tensor,
-    hit_class_factors: torch.Tensor,
+    log_hit_class_factors: torch.Tensor,
 ) -> torch.Tensor:
     """Multiply codon probabilities by their hit class factors, and renormalize.
 
@@ -53,7 +53,7 @@ def apply_multihit_correction(
             indices of the parent codon's nucleotides.
         codon_probs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the probabilities
             of mutating to each possible target codon, for each of the N parent codons.
-        hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
+        log_hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The
             factor for hit class 0 is assumed to be 1 (that is, 0 in log-space).
 
     Returns:
@@ -61,7 +61,7 @@
     """
     per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
-    corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]).exp()
+    corrections = torch.cat([torch.tensor([0.0]), log_hit_class_factors]).exp()
     reshaped_corrections = corrections[per_parent_hit_class]
     unnormalized_corrected_probs = codon_probs * reshaped_corrections
     normalizations = torch.sum(
diff --git a/netam/molevol.py b/netam/molevol.py
index 5649f753..762c9800 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -501,7 +501,6 @@
 
     step_idx = 0
 
-    nan_issue = False
     for step_idx in range(max_optimization_steps):
         # For some PCPs, the optimizer works very hard optimizing very tiny branch lengths.
         if log_branch_length < log_branch_length_lower_threshold:
@@ -517,16 +516,7 @@
         torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0)
         optimizer.step()
         if torch.isnan(log_branch_length):
-            print(
-                "branch length optimization resulted in NAN, previous log branch length:",
-                prev_log_branch_length,
-            )
-            if np.isclose(prev_log_branch_length.detach().numpy(), 0):
-                log_branch_length = prev_log_branch_length
-                nan_issue = True
-                break
-            else:
-                assert False
+            raise ValueError("branch length optimization resulted in NAN")
 
         change_in_log_branch_length = torch.abs(
             log_branch_length - prev_log_branch_length
@@ -536,7 +526,7 @@
 
         prev_log_branch_length = log_branch_length.clone()
 
-    if step_idx == max_optimization_steps - 1 or nan_issue:
+    if step_idx == max_optimization_steps - 1:
         print(
             f"Warning: optimization did not converge after {max_optimization_steps} steps; log branch length is {log_branch_length.detach().item()}"
         )

From 62d34793689ad8a5ad70d052a5b8108e207993dc Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Wed, 23 Oct 2024 13:48:58 -0700
Subject: [PATCH 8/8] Respond to Erick's comments

---
 netam/dnsm.py    | 2 +-
 netam/molevol.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/netam/dnsm.py b/netam/dnsm.py
index d3d17b86..76b263bf 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -249,7 +249,7 @@ def update_neutral_probs(self):
             rates = rates.to("cpu")
             subs_probs = subs_probs.to("cpu")
             if self.multihit_model is not None:
-                multihit_model = self.multihit_model.to("cpu")
+                multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
             else:
                 multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be careful
diff --git a/netam/molevol.py b/netam/molevol.py
index 762c9800..a82a085d 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -18,8 +18,6 @@
 import netam.sequences as sequences
 
 
-# torch.autograd.set_detect_anomaly(True)
-
 
 def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor:
     """Normalize substitution probabilities.
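The final patch deep-copies the multihit model before moving it to the CPU. Since nn.Module.to moves parameters in place and returns the same module, copying first keeps the dataset's own model on its original device. A small sketch of the distinction, with a toy module standing in for the multihit model:

    import copy
    import torch
    from torch import nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.values = nn.Parameter(torch.zeros(3))

    model = ToyModel()
    # .to() is in-place for modules: calling model.to("cpu") returns the same
    # object. Deep-copying first leaves the original untouched:
    worker_copy = copy.deepcopy(model).to("cpu")
    assert worker_copy is not model
    assert model.to("cpu") is model  # modules are moved in place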