From 1aac535a291ef91f4abfc80034b29499c8d1daa0 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 23 Sep 2024 13:47:40 -0700 Subject: [PATCH 01/18] in-person drafting --- netam/molevol.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/netam/molevol.py b/netam/molevol.py index ead81b5c..674b25a4 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -283,6 +283,7 @@ def build_codon_mutsel( codon_mut_probs: Tensor, codon_sub_probs: Tensor, aa_sel_matrices: Tensor, + multihit_model=None, ) -> Tensor: """Build a sequence of codon mutation-selection matrices for codons along a sequence. @@ -301,6 +302,9 @@ def build_codon_mutsel( ) codon_probs = codon_probs_of_mutation_matrices(mut_matrices) + if multihit_model is not None: + codon_probs = multihit_model(parent_codon_idxs, codon_probs) + # Calculate the codon selection matrix for each sequence via Einstein # summation, in which we sum over the repeated indices. # So, for each site (s) and codon (c), sum over amino acids (a): @@ -419,7 +423,7 @@ def neutral_aa_mut_probs( return mut_probs -def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs): +def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs, multihit_model=None): """Constructs the log_pcp_probability function specific to given rates and sub_probs. @@ -442,6 +446,7 @@ def log_pcp_probability(log_branch_length: torch.Tensor): mut_probs.reshape(-1, 3), sub_probs.reshape(-1, 3, 4), sel_matrix, + multihit_model=multihit_model, ) # This is a diagnostic generating data for netam issue #7. From a2e02e4f65368911e929c3871fd868ed55223972 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Mon, 23 Sep 2024 16:48:22 -0700 Subject: [PATCH 02/18] initial modification of functions --- netam/dnsm.py | 7 ++++++- netam/molevol.py | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index ef0658d4..75d19e9e 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -49,11 +49,13 @@ def __init__( all_rates: torch.Tensor, all_subs_probs: torch.Tensor, branch_lengths: torch.Tensor, + multihit_model=None, ): self.nt_parents = nt_parents self.nt_children = nt_children self.all_rates = all_rates self.all_subs_probs = all_subs_probs + self.multihit_model = multihit_model assert len(self.nt_parents) == len(self.nt_children) pcp_count = len(self.nt_parents) @@ -220,6 +222,7 @@ def update_neutral_aa_mut_probs(self): parent_idxs.reshape(-1, 3), mut_probs.reshape(-1, 3), normed_subs_probs.reshape(-1, 3, 4), + multihit_model=self.multihit_model, ) if not torch.isfinite(neutral_aa_mut_prob).all(): @@ -360,13 +363,14 @@ def _find_optimal_branch_length( rates, subs_probs, starting_branch_length, + multihit_model, **optimization_kwargs, ): if parent == child: return 0.0 sel_matrix = self.build_selection_matrix_from_parent(parent) log_pcp_probability = molevol.mutsel_log_pcp_probability_of( - sel_matrix, parent, child, rates, subs_probs + sel_matrix, parent, child, rates, subs_probs, multihit_model ) if type(starting_branch_length) == torch.Tensor: starting_branch_length = starting_branch_length.detach().item() @@ -395,6 +399,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): rates[: len(parent)], subs_probs[: len(parent), :], starting_length, + dataset.multihit_model, **optimization_kwargs, ) diff --git a/netam/molevol.py b/netam/molevol.py index 674b25a4..7e731cd9 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -303,7 +303,7 @@ def build_codon_mutsel( codon_probs = 
codon_probs_of_mutation_matrices(mut_matrices)

     if multihit_model is not None:
-        codon_probs = multihit_model(parent_codon_idxs, codon_probs)
+        codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp()

     # Calculate the codon selection matrix for each sequence via Einstein
     # summation, in which we sum over the repeated indices.
@@ -362,6 +362,9 @@ def neutral_aa_probs(
     )
     codon_probs = codon_probs_of_mutation_matrices(mut_matrices).view(-1, 64)

+    if multihit_model is not None:
+        codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp()
+
     # Get the probability of mutating to each amino acid.
     aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX

@@ -400,6 +403,7 @@ def neutral_aa_mut_probs(
     parent_codon_idxs: Tensor,
     codon_mut_probs: Tensor,
     codon_sub_probs: Tensor,
+    multihit_model=None,
 ) -> Tensor:
     """For every site, what is the probability that the amino acid will have a
     substitution or mutate to a stop under neutral evolution?

From 9d0b7bf32e669c6c40e1a3e614ea35b7ca4ce883 Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Mon, 23 Sep 2024 17:14:52 -0700
Subject: [PATCH 03/18] Add multihit model in a few more places

---
 netam/dnsm.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/netam/dnsm.py b/netam/dnsm.py
index 75d19e9e..e5728ec4 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -101,6 +101,7 @@ def of_seriess(
         all_rates_series: pd.Series,
         all_subs_probs_series: pd.Series,
         branch_length_multiplier=5.0,
+        multihit_model=None,
     ):
         """Alternative constructor that takes the raw data and calculates the initial
         branch lengths.
@@ -121,10 +122,11 @@ def of_seriess(
             stack_heterogeneous(all_rates_series.reset_index(drop=True)),
             stack_heterogeneous(all_subs_probs_series.reset_index(drop=True)),
             initial_branch_lengths,
+            multihit_model=multihit_model,
         )

     @classmethod
-    def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
+    def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
         """Alternative constructor that takes in a pcp_df and calculates the initial
         branch lengths."""
         assert "rates" in pcp_df.columns, "pcp_df must have a neutral rates column"
@@ -134,6 +136,7 @@ def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
             pcp_df["rates"],
             pcp_df["subs_probs"],
             branch_length_multiplier=branch_length_multiplier,
+            multihit_model=multihit_model,
         )

     def clone(self):
@@ -144,6 +147,7 @@ def clone(self):
             self.all_rates.copy(),
             self.all_subs_probs.copy(),
             self._branch_lengths.copy(),
+            multihit_model=self.multihit_model,
         )
         return new_dataset

@@ -160,6 +164,7 @@ def subset_via_indices(self, indices):
             self.all_rates[indices],
             self.all_subs_probs[indices],
             self._branch_lengths[indices],
+            multihit_model=self.multihit_model,
         )
         return new_dataset

@@ -275,7 +280,7 @@ def to(self, device):
         self.all_subs_probs = self.all_subs_probs.to(device)


-def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
+def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0, multihit_model=None):
     """Perform a train-val split based on an "in_train" column.

     Stays here so it can be used in tests.
@@ -283,13 +288,13 @@
     train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
     val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
     val_dataset = DNSMDataset.of_pcp_df(
-        val_df, branch_length_multiplier=branch_length_multiplier
+        val_df, branch_length_multiplier=branch_length_multiplier, multihit_model=multihit_model,
     )
     if len(train_df) == 0:
         return None, val_dataset
     # else:
     train_dataset = DNSMDataset.of_pcp_df(
-        train_df, branch_length_multiplier=branch_length_multiplier
+        train_df, branch_length_multiplier=branch_length_multiplier, multihit_model=multihit_model,
     )
     return train_dataset, val_dataset

From 9d4895558b90ff9b96a25ec2860fca781e3ece9e Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Wed, 25 Sep 2024 04:29:12 -0700
Subject: [PATCH 04/18] make format

---
 netam/dnsm.py    | 12 +++++++++---
 netam/molevol.py |  4 +++-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/netam/dnsm.py b/netam/dnsm.py
index e5728ec4..a8ff35c9 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -280,7 +280,9 @@ def to(self, device):
         self.all_subs_probs = self.all_subs_probs.to(device)


-def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0, multihit_model=None):
+def train_val_datasets_of_pcp_df(
+    pcp_df, branch_length_multiplier=5.0, multihit_model=None
+):
     """Perform a train-val split based on an "in_train" column.

     Stays here so it can be used in tests.
@@ -288,13 +290,17 @@ def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0, multihit_
     train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
     val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
     val_dataset = DNSMDataset.of_pcp_df(
-        val_df, branch_length_multiplier=branch_length_multiplier, multihit_model=multihit_model,
+        val_df,
+        branch_length_multiplier=branch_length_multiplier,
+        multihit_model=multihit_model,
     )
     if len(train_df) == 0:
         return None, val_dataset
     # else:
     train_dataset = DNSMDataset.of_pcp_df(
-        train_df, branch_length_multiplier=branch_length_multiplier, multihit_model=multihit_model,
+        train_df,
+        branch_length_multiplier=branch_length_multiplier,
+        multihit_model=multihit_model,
     )
     return train_dataset, val_dataset

diff --git a/netam/molevol.py b/netam/molevol.py
index 7e731cd9..b72da4d0 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -427,7 +427,9 @@ def neutral_aa_mut_probs(
     return mut_probs


-def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs, multihit_model=None):
+def mutsel_log_pcp_probability_of(
+    sel_matrix, parent, child, rates, sub_probs, multihit_model=None
+):
     """Constructs the log_pcp_probability function specific to given rates and
     sub_probs.

From a73a6ac8c1a37d89ec0c024fdd9df2bb7198c3f1 Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Fri, 27 Sep 2024 15:40:04 -0700
Subject: [PATCH 05/18] fix shape issue

---
 netam/molevol.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/netam/molevol.py b/netam/molevol.py
index b72da4d0..3c504d53 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -360,13 +360,13 @@ def neutral_aa_probs(
     mut_matrices = build_mutation_matrices(
         parent_codon_idxs, codon_mut_probs, codon_sub_probs
     )
-    codon_probs = codon_probs_of_mutation_matrices(mut_matrices).view(-1, 64)
+    codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

     if multihit_model is not None:
         codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp()

     # Get the probability of mutating to each amino acid.
- aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX + aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX return aa_probs From 35e71b2ec55cd18eb1954aa7a5b9f180d3068e9e Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 27 Sep 2024 17:04:26 -0700 Subject: [PATCH 06/18] fix rebase issue, but tests still fail --- netam/molevol.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/netam/molevol.py b/netam/molevol.py index 3c504d53..fe1bbff0 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -343,6 +343,7 @@ def neutral_aa_probs( parent_codon_idxs: Tensor, codon_mut_probs: Tensor, codon_sub_probs: Tensor, + multihit_model=None, ) -> Tensor: """For every site, what is the probability that the amino acid will mutate to every amino acid? @@ -422,7 +423,12 @@ def neutral_aa_mut_probs( Shape: (codon_count,) """ - aa_probs = neutral_aa_probs(parent_codon_idxs, codon_mut_probs, codon_sub_probs) + aa_probs = neutral_aa_probs( + parent_codon_idxs, + codon_mut_probs, + codon_sub_probs, + multihit_model=multihit_model, + ) mut_probs = mut_probs_of_aa_probs(parent_codon_idxs, aa_probs) return mut_probs From 2ff252412eb6e058242e0df8bfaf326e019e41ff Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Mon, 30 Sep 2024 04:11:57 -0700 Subject: [PATCH 07/18] force_spawn in tests --- tests/test_dnsm.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py index 76f9a5bc..e29070da 100644 --- a/tests/test_dnsm.py +++ b/tests/test_dnsm.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import torch @@ -14,6 +15,15 @@ from netam.dnsm import DNSMBurrito, train_val_datasets_of_pcp_df +def force_spawn(): + """ + Force the spawn start method for multiprocessing. + + This is necessary to avoid conflicts with the internal OpenMP-based thread pool in PyTorch. + """ + mp.set_start_method("spawn", force=True) + + def test_aa_idx_tensor_of_str_ambig(): input_seq = "ACX" expected_output = torch.tensor([0, 1, MAX_AMBIG_AA_IDX], dtype=torch.int) @@ -36,6 +46,7 @@ def pcp_df(): @pytest.fixture def dnsm_burrito(pcp_df): """Fixture that returns the DNSM Burrito object.""" + force_spawn() pcp_df["in_train"] = True pcp_df.loc[pcp_df.index[-15:], "in_train"] = False train_dataset, val_dataset = train_val_datasets_of_pcp_df(pcp_df) @@ -57,6 +68,7 @@ def dnsm_burrito(pcp_df): def test_parallel_branch_length_optimization(dnsm_burrito): + force_spawn() dataset = dnsm_burrito.val_dataset parallel_branch_lengths = dnsm_burrito.find_optimal_branch_lengths(dataset) branch_lengths = dnsm_burrito.serial_find_optimal_branch_lengths(dataset) From 10db510f9f85324f78af3f74b159d902a123267e Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Mon, 30 Sep 2024 09:12:45 -0700 Subject: [PATCH 08/18] reformat --- tests/test_dnsm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py index e29070da..9f60821a 100644 --- a/tests/test_dnsm.py +++ b/tests/test_dnsm.py @@ -16,10 +16,10 @@ def force_spawn(): - """ - Force the spawn start method for multiprocessing. + """Force the spawn start method for multiprocessing. - This is necessary to avoid conflicts with the internal OpenMP-based thread pool in PyTorch. + This is necessary to avoid conflicts with the internal OpenMP-based thread pool in + PyTorch. 
""" mp.set_start_method("spawn", force=True) From 9b7a18608518b843f411fd84dc6770fd9f20b0e3 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 1 Oct 2024 14:31:50 -0700 Subject: [PATCH 09/18] switch to serial branch length optimization --- netam/dnsm.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index a8ff35c9..4d4af44e 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -427,15 +427,15 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) # The following can be used when one wants a better traceback. - # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) - # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) - with mp.Pool(worker_count) as pool: - splits = dataset.split(worker_count) - results = pool.starmap( - worker_optimize_branch_length, - [(self.model, split, optimization_kwargs) for split in splits], - ) - return torch.cat(results) + burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) + return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + # with mp.Pool(worker_count) as pool: + # splits = dataset.split(worker_count) + # results = pool.starmap( + # worker_optimize_branch_length, + # [(self.model, split, optimization_kwargs) for split in splits], + # ) + # return torch.cat(results) def to_crepe(self): training_hyperparameters = { From d44a60831c0bb1f52e61d4ae6a7bc16a24597740 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Wed, 2 Oct 2024 13:12:03 -0700 Subject: [PATCH 10/18] multihit works with threading --- netam/dnsm.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index 4d4af44e..bd81cc88 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -55,7 +55,10 @@ def __init__( self.nt_children = nt_children self.all_rates = all_rates self.all_subs_probs = all_subs_probs - self.multihit_model = multihit_model + self.multihit_model = copy.deepcopy(multihit_model) + if multihit_model is not None: + # We want these parameters to act like fixed data + self.multihit_model.values.requires_grad_(False) assert len(self.nt_parents) == len(self.nt_children) pcp_count = len(self.nt_parents) @@ -147,7 +150,7 @@ def clone(self): self.all_rates.copy(), self.all_subs_probs.copy(), self._branch_lengths.copy(), - multihit_model=self.multihit_model, + multihit_model=copy.deepcopy(self.multihit_model), ) return new_dataset @@ -164,7 +167,7 @@ def subset_via_indices(self, indices): self.all_rates[indices], self.all_subs_probs[indices], self._branch_lengths[indices], - multihit_model=self.multihit_model, + multihit_model=copy.deepcopy(self.multihit_model), ) return new_dataset @@ -426,16 +429,16 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # The following can be used when one wants a better traceback. 
- burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) - return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) - # with mp.Pool(worker_count) as pool: - # splits = dataset.split(worker_count) - # results = pool.starmap( - # worker_optimize_branch_length, - # [(self.model, split, optimization_kwargs) for split in splits], - # ) - # return torch.cat(results) + # # The following can be used when one wants a better traceback. + # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) + # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + with mp.Pool(worker_count) as pool: + splits = dataset.split(worker_count) + results = pool.starmap( + worker_optimize_branch_length, + [(self.model, split, optimization_kwargs) for split in splits], + ) + return torch.cat(results) def to_crepe(self): training_hyperparameters = { From e620b1500bee8be702f248fc5e0cf47e95031246 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Thu, 3 Oct 2024 06:44:40 -0700 Subject: [PATCH 11/18] comment --- netam/dnsm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index bd81cc88..681d6603 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -57,7 +57,8 @@ def __init__( self.all_subs_probs = all_subs_probs self.multihit_model = copy.deepcopy(multihit_model) if multihit_model is not None: - # We want these parameters to act like fixed data + # We want these parameters to act like fixed data. This is essential + # for multithreaded branch length optimization to work. self.multihit_model.values.requires_grad_(False) assert len(self.nt_parents) == len(self.nt_children) From 75c4f520b05ca14f3bc8b7aeb822af9ad27e5754 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Thu, 10 Oct 2024 16:47:52 -0700 Subject: [PATCH 12/18] WIP-- nothing works now --- netam/dnsm.py | 29 ++++++++++++++++------------- netam/framework.py | 12 ++++++++++++ netam/hit_class.py | 38 ++++++++++++++++++++++++++++++-------- netam/molevol.py | 13 +++++++++++-- 4 files changed, 69 insertions(+), 23 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index bd81cc88..1029d891 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -56,9 +56,9 @@ def __init__( self.all_rates = all_rates self.all_subs_probs = all_subs_probs self.multihit_model = copy.deepcopy(multihit_model) - if multihit_model is not None: - # We want these parameters to act like fixed data - self.multihit_model.values.requires_grad_(False) + # if multihit_model is not None: + # # We want these parameters to act like fixed data + # self.multihit_model.values.requires_grad_(False) assert len(self.nt_parents) == len(self.nt_children) pcp_count = len(self.nt_parents) @@ -281,6 +281,8 @@ def to(self, device): self.log_neutral_aa_mut_probs = self.log_neutral_aa_mut_probs.to(device) self.all_rates = self.all_rates.to(device) self.all_subs_probs = self.all_subs_probs.to(device) + if self.multihit_model is not None: + self.multihit_model = self.multihit_model.to(device) def train_val_datasets_of_pcp_df( @@ -429,16 +431,17 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # # The following can be used when one wants a better traceback. 
- # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) - # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) - with mp.Pool(worker_count) as pool: - splits = dataset.split(worker_count) - results = pool.starmap( - worker_optimize_branch_length, - [(self.model, split, optimization_kwargs) for split in splits], - ) - return torch.cat(results) + # The following can be used when one wants a better traceback. + burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) + return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + # TODO reenable + # with mp.Pool(worker_count) as pool: + # splits = dataset.split(worker_count) + # results = pool.starmap( + # worker_optimize_branch_length, + # [(self.model, split, optimization_kwargs) for split in splits], + # ) + # return torch.cat(results) def to_crepe(self): training_hyperparameters = { diff --git a/netam/framework.py b/netam/framework.py index d44ab1d0..cf9db9d0 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -27,6 +27,7 @@ VRC01_NT_SEQ, ) from netam import models +from netam.hit_class import MultihitApplier import netam.molevol as molevol @@ -336,6 +337,17 @@ def crepe_exists(prefix): return os.path.exists(f"{prefix}.yml") and os.path.exists(f"{prefix}.pth") +def load_multihit_adjuster(multihit_crepe_prefix, device=None): + if multihit_crepe_prefix is None: + multihit_model = None + else: + print(f"Loading multihit model from {multihit_crepe_prefix}") + multihit_model = None + # multihit_adjust = load_crepe(multihit_crepe_prefix, device=device).model.values.detach().to(device) + # multihit_model = MultihitApplier(multihit_adjust) + return multihit_model + + def trimmed_shm_model_outputs_of_crepe(crepe, parents): """Model outputs trimmed to the length of the parent sequences.""" rates, csp_logits = crepe(parents) diff --git a/netam/hit_class.py b/netam/hit_class.py index c42506d7..2ed42c90 100644 --- a/netam/hit_class.py +++ b/netam/hit_class.py @@ -62,14 +62,36 @@ def apply_multihit_correction( torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible target codon, for each of the N parent codons, after applying the hit class factors. 
""" - per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) - corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]) - reshaped_corrections = corrections[per_parent_hit_class] - unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections - normalizations = torch.logsumexp( - unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True - ) - return unnormalized_corrected_logprobs - normalizations + # TODO: This is for testing + return codon_logprobs + # per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) + # # corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]) + # corrections = torch.tensor([0.0, 0.0, 0.0, 0.0]) + # reshaped_corrections = corrections[per_parent_hit_class] + # unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections + # normalizations = torch.logsumexp( + # unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True + # ) + # result = unnormalized_corrected_logprobs - normalizations + # if torch.any(torch.isnan(result)): + # print("NAN found in multihit correction application") + # assert False + # return result + + +class MultihitApplier: + def __init__(self, hit_class_factors, device=None): + # self.corrections = hit_class_factors.to(device) + # TODO This is for testing + self.corrections = torch.tensor([0.0, 0.0, 0.0]).to(device) + + def to(self, device): + self.corrections = self.corrections.to(device) + return self + + def __call__(self, parent_codon_idxs, codon_logprobs): + return codon_logprobs + # return apply_multihit_correction(parent_codon_idxs, codon_logprobs, self.corrections) def hit_class_probs_tensor( diff --git a/netam/molevol.py b/netam/molevol.py index fe1bbff0..1c2163d4 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -17,6 +17,7 @@ from netam.sequences import CODON_AA_INDICATOR_MATRIX import netam.sequences as sequences +torch.autograd.set_detect_anomaly(True) def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor: @@ -499,6 +500,7 @@ def optimize_branch_length( step_idx = 0 + nan_issue = False for step_idx in range(max_optimization_steps): # For some PCPs, the optimizer works very hard optimizing very tiny branch lengths. 
if log_branch_length < log_branch_length_lower_threshold: @@ -513,7 +515,14 @@ def optimize_branch_length( loss.backward() torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0) optimizer.step() - assert not torch.isnan(log_branch_length) + if not torch.isnan(log_branch_length): + print("branch length optimization resulted in NAN, previous log branch length:", prev_log_branch_length) + if np.isclose(prev_log_branch_length.detach().numpy(), 0): + log_branch_length = prev_log_branch_length + nan_issue = True + break + else: + assert False change_in_log_branch_length = torch.abs( log_branch_length - prev_log_branch_length @@ -523,7 +532,7 @@ def optimize_branch_length( prev_log_branch_length = log_branch_length.clone() - if step_idx == max_optimization_steps - 1: + if step_idx == max_optimization_steps - 1 or nan_issue: print( f"Warning: optimization did not converge after {max_optimization_steps} steps; log branch length is {log_branch_length.detach().item()}" ) From ae6cf4d4e716ba32b5f0a14ff1ce14470418c2d4 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Mon, 14 Oct 2024 14:13:03 -0700 Subject: [PATCH 13/18] test+fake multihit working --- netam/dnsm.py | 21 ++++++++++----------- netam/molevol.py | 2 +- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index c5f3a4ca..6a8172d7 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -432,17 +432,16 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # The following can be used when one wants a better traceback. - burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) - return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) - # TODO reenable - # with mp.Pool(worker_count) as pool: - # splits = dataset.split(worker_count) - # results = pool.starmap( - # worker_optimize_branch_length, - # [(self.model, split, optimization_kwargs) for split in splits], - # ) - # return torch.cat(results) + # # The following can be used when one wants a better traceback. 
+ # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model)) + # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + with mp.Pool(worker_count) as pool: + splits = dataset.split(worker_count) + results = pool.starmap( + worker_optimize_branch_length, + [(self.model, split, optimization_kwargs) for split in splits], + ) + return torch.cat(results) def to_crepe(self): training_hyperparameters = { diff --git a/netam/molevol.py b/netam/molevol.py index 1c2163d4..a28cb752 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -515,7 +515,7 @@ def optimize_branch_length( loss.backward() torch.nn.utils.clip_grad_norm_([log_branch_length], max_norm=5.0) optimizer.step() - if not torch.isnan(log_branch_length): + if torch.isnan(log_branch_length): print("branch length optimization resulted in NAN, previous log branch length:", prev_log_branch_length) if np.isclose(prev_log_branch_length.detach().numpy(), 0): log_branch_length = prev_log_branch_length From 2f83eb63f4544da7c84ac5f715eaffd4e80d8acd Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Mon, 14 Oct 2024 14:51:13 -0700 Subject: [PATCH 14/18] re-enable multihit --- netam/hit_class.py | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/netam/hit_class.py b/netam/hit_class.py index 2ed42c90..e2aa090f 100644 --- a/netam/hit_class.py +++ b/netam/hit_class.py @@ -62,36 +62,30 @@ def apply_multihit_correction( torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible target codon, for each of the N parent codons, after applying the hit class factors. """ - # TODO: This is for testing - return codon_logprobs - # per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) - # # corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]) - # corrections = torch.tensor([0.0, 0.0, 0.0, 0.0]) - # reshaped_corrections = corrections[per_parent_hit_class] - # unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections - # normalizations = torch.logsumexp( - # unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True - # ) - # result = unnormalized_corrected_logprobs - normalizations - # if torch.any(torch.isnan(result)): - # print("NAN found in multihit correction application") - # assert False - # return result + per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) + corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]) + reshaped_corrections = corrections[per_parent_hit_class] + unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections + normalizations = torch.logsumexp( + unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True + ) + result = unnormalized_corrected_logprobs - normalizations + if torch.any(torch.isnan(result)): + print("NAN found in multihit correction application") + assert False + return result class MultihitApplier: def __init__(self, hit_class_factors, device=None): - # self.corrections = hit_class_factors.to(device) - # TODO This is for testing - self.corrections = torch.tensor([0.0, 0.0, 0.0]).to(device) + self.corrections = hit_class_factors.to(device) def to(self, device): self.corrections = self.corrections.to(device) return self def __call__(self, parent_codon_idxs, codon_logprobs): - return codon_logprobs - # return apply_multihit_correction(parent_codon_idxs, codon_logprobs, self.corrections) + return apply_multihit_correction(parent_codon_idxs, codon_logprobs, self.corrections) def hit_class_probs_tensor( From 
4f727780d508e3c6d691673cc3272fcf65adf4be Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 15 Oct 2024 16:41:18 -0700 Subject: [PATCH 15/18] perhaps working, fixed device mismatch and switched to nonlog correction application --- netam/dnsm.py | 11 ++++++----- netam/framework.py | 5 ++--- netam/hit_class.py | 45 ++++++++++++++++++++++++++++++++++++++++-- netam/molevol.py | 13 ++++++++++-- tests/test_multihit.py | 5 +++++ 5 files changed, 67 insertions(+), 12 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index 6a8172d7..32b9755e 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -56,10 +56,10 @@ def __init__( self.all_rates = all_rates self.all_subs_probs = all_subs_probs self.multihit_model = copy.deepcopy(multihit_model) - if multihit_model is not None: - # We want these parameters to act like fixed data. This is essential - # for multithreaded branch length optimization to work. - self.multihit_model.values.requires_grad_(False) + # if multihit_model is not None: + # # We want these parameters to act like fixed data. This is essential + # # for multithreaded branch length optimization to work. + # self.multihit_model.values.requires_grad_(False) assert len(self.nt_parents) == len(self.nt_children) pcp_count = len(self.nt_parents) @@ -217,6 +217,7 @@ def update_neutral_aa_mut_probs(self): mask = mask.to("cpu") rates = rates.to("cpu") subs_probs = subs_probs.to("cpu") + multihit_model = self.multihit_model.to("cpu") # Note we are replacing all Ns with As, which means that we need to be careful # with masking out these positions later. We do this below. parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A")) @@ -231,7 +232,7 @@ def update_neutral_aa_mut_probs(self): parent_idxs.reshape(-1, 3), mut_probs.reshape(-1, 3), normed_subs_probs.reshape(-1, 3, 4), - multihit_model=self.multihit_model, + multihit_model=multihit_model, ) if not torch.isfinite(neutral_aa_mut_prob).all(): diff --git a/netam/framework.py b/netam/framework.py index cf9db9d0..2f4810e9 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -342,9 +342,8 @@ def load_multihit_adjuster(multihit_crepe_prefix, device=None): multihit_model = None else: print(f"Loading multihit model from {multihit_crepe_prefix}") - multihit_model = None - # multihit_adjust = load_crepe(multihit_crepe_prefix, device=device).model.values.detach().to(device) - # multihit_model = MultihitApplier(multihit_adjust) + multihit_adjust = load_crepe(multihit_crepe_prefix, device=device).model.values.detach().to(device) + multihit_model = MultihitApplier(multihit_adjust) return multihit_model diff --git a/netam/hit_class.py b/netam/hit_class.py index e2aa090f..174aba0a 100644 --- a/netam/hit_class.py +++ b/netam/hit_class.py @@ -62,6 +62,7 @@ def apply_multihit_correction( torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible target codon, for each of the N parent codons, after applying the hit class factors. 
""" + warnings.warn("hit_class.py:apply_multihit_correction is deprecated, use apply_multihit_correction_nonlog instead") per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]) reshaped_corrections = corrections[per_parent_hit_class] @@ -75,17 +76,57 @@ def apply_multihit_correction( assert False return result +def apply_multihit_correction_nonlog( + parent_codon_idxs: torch.Tensor, + codon_probs: torch.Tensor, + hit_class_factors: torch.Tensor, +) -> torch.Tensor: + """Multiply codon probabilities by their hit class factors, and renormalize. + + Suppose there are N codons, then the parameters are as follows: + + Args: + parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the + indices of the parent codon's nucleotides. + codon_probs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the probabilities + of mutating to each possible target codon, for each of the N parent codons. + hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The + factor for hit class 0 is assumed to be 1 (that is, 0 in log-space). + + Returns: + torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the probabilities of mutating to each possible + target codon, for each of the N parent codons, after applying the hit class factors. + """ + per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) + corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]).exp() + reshaped_corrections = corrections[per_parent_hit_class] + unnormalized_corrected_probs = codon_probs * reshaped_corrections + normalizations = torch.sum( + unnormalized_corrected_probs, dim=[1, 2, 3], keepdim=True + ) + result = unnormalized_corrected_probs / normalizations + if torch.any(torch.isnan(result)): + print("NAN found in multihit correction application") + assert False + return result + class MultihitApplier: def __init__(self, hit_class_factors, device=None): + if np.allclose(hit_class_factors, 0): + warnings.warn("Hit class factors are all zero, and will not change probabilities") self.corrections = hit_class_factors.to(device) def to(self, device): self.corrections = self.corrections.to(device) return self - def __call__(self, parent_codon_idxs, codon_logprobs): - return apply_multihit_correction(parent_codon_idxs, codon_logprobs, self.corrections) + def __call__(self, parent_codon_idxs, codon_probs): + return apply_multihit_correction_nonlog(parent_codon_idxs, codon_probs, self.corrections) + + # def __call__(self, parent_codon_idxs, codon_logprobs): + # return codon_logprobs + # return apply_multihit_correction(parent_codon_idxs, codon_logprobs, self.corrections) def hit_class_probs_tensor( diff --git a/netam/molevol.py b/netam/molevol.py index a28cb752..9c3266dd 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -10,6 +10,7 @@ """ import numpy as np +import warnings import torch from torch import Tensor, optim @@ -304,7 +305,11 @@ def build_codon_mutsel( codon_probs = codon_probs_of_mutation_matrices(mut_matrices) if multihit_model is not None: - codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp() + # codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp() + # TODO this is for testing + codon_probs = multihit_model(parent_codon_idxs, codon_probs) + else: + warnings.warn("No multihit model provided. 
Using uncorrected probabilities.") # Calculate the codon selection matrix for each sequence via Einstein # summation, in which we sum over the repeated indices. @@ -365,7 +370,11 @@ def neutral_aa_probs( codon_probs = codon_probs_of_mutation_matrices(mut_matrices) if multihit_model is not None: - codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp() + # codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp() + # TODO this is for testing + codon_probs = multihit_model(parent_codon_idxs, codon_probs) + else: + warnings.warn("No multihit model provided. Using uncorrected probabilities.") # Get the probability of mutating to each amino acid. aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX diff --git a/tests/test_multihit.py b/tests/test_multihit.py index 58d3ad35..88c6b002 100644 --- a/tests/test_multihit.py +++ b/tests/test_multihit.py @@ -107,7 +107,11 @@ def test_multihit_correction(): adjusted_codon_probs = hit_class.apply_multihit_correction( codon_idxs, ex_codon_probs.log(), hit_class_factors ).exp() + adjusted_nonlog_codon_probs = hit_class.apply_multihit_correction_nonlog( + codon_idxs, ex_codon_probs, hit_class_factors + ) aggregate_last = hit_class.hit_class_probs_tensor(codon_idxs, adjusted_codon_probs) + aggregate_last_nonlog = hit_class.hit_class_probs_tensor(codon_idxs, adjusted_nonlog_codon_probs) uncorrected_hc_log_probs = hit_class.hit_class_probs_tensor( codon_idxs, ex_codon_probs @@ -121,6 +125,7 @@ def test_multihit_correction(): uncorrected_hc_log_probs += corrections aggregate_first = torch.softmax(uncorrected_hc_log_probs, dim=1) assert torch.allclose(aggregate_first, aggregate_last) + assert torch.allclose(aggregate_first, aggregate_last_nonlog) def test_hit_class_tensor(): From 2aaa52c5cf3974bca007a103092d7bb3a4e3ccbd Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Thu, 17 Oct 2024 10:12:43 -0700 Subject: [PATCH 16/18] I think this might have worked --- netam/dnsm.py | 5 ++++- netam/framework.py | 2 ++ netam/hit_class.py | 6 ++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index 32b9755e..e89c2212 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -217,7 +217,10 @@ def update_neutral_aa_mut_probs(self): mask = mask.to("cpu") rates = rates.to("cpu") subs_probs = subs_probs.to("cpu") - multihit_model = self.multihit_model.to("cpu") + if self.multihit_model is not None: + multihit_model = self.multihit_model.to("cpu") + else: + multihit_model = None # Note we are replacing all Ns with As, which means that we need to be careful # with masking out these positions later. We do this below. parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A")) diff --git a/netam/framework.py b/netam/framework.py index 2f4810e9..5e3800f9 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -341,6 +341,8 @@ def load_multihit_adjuster(multihit_crepe_prefix, device=None): if multihit_crepe_prefix is None: multihit_model = None else: + # # TODO for testing + # multihit_model = None print(f"Loading multihit model from {multihit_crepe_prefix}") multihit_adjust = load_crepe(multihit_crepe_prefix, device=device).model.values.detach().to(device) multihit_model = MultihitApplier(multihit_adjust) diff --git a/netam/hit_class.py b/netam/hit_class.py index 174aba0a..086c07b4 100644 --- a/netam/hit_class.py +++ b/netam/hit_class.py @@ -63,6 +63,7 @@ def apply_multihit_correction( target codon, for each of the N parent codons, after applying the hit class factors. 
""" warnings.warn("hit_class.py:apply_multihit_correction is deprecated, use apply_multihit_correction_nonlog instead") + assert False per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]) reshaped_corrections = corrections[per_parent_hit_class] @@ -116,11 +117,16 @@ def __init__(self, hit_class_factors, device=None): if np.allclose(hit_class_factors, 0): warnings.warn("Hit class factors are all zero, and will not change probabilities") self.corrections = hit_class_factors.to(device) + # # TODO This is just for testing + # self.corrections = torch.tensor([0.0, 0.0, 0.0]).to(device) def to(self, device): self.corrections = self.corrections.to(device) return self + # def __call__(self, parent_codon_idxs, codon_probs): + # return codon_probs + def __call__(self, parent_codon_idxs, codon_probs): return apply_multihit_correction_nonlog(parent_codon_idxs, codon_probs, self.corrections) From 364966523651e9f39a3560b268f3762ca15bef87 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Mon, 21 Oct 2024 14:13:35 -0700 Subject: [PATCH 17/18] verified working here --- netam/dnsm.py | 1 + netam/framework.py | 2 -- netam/hit_class.py | 9 --------- netam/molevol.py | 2 +- 4 files changed, 2 insertions(+), 12 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index e89c2212..d9b6f0bd 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -39,6 +39,7 @@ translate_sequence, translate_sequences, ) +# mpctx = mp.get_context('spawn') class DNSMDataset(Dataset): diff --git a/netam/framework.py b/netam/framework.py index 5e3800f9..2f4810e9 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -341,8 +341,6 @@ def load_multihit_adjuster(multihit_crepe_prefix, device=None): if multihit_crepe_prefix is None: multihit_model = None else: - # # TODO for testing - # multihit_model = None print(f"Loading multihit model from {multihit_crepe_prefix}") multihit_adjust = load_crepe(multihit_crepe_prefix, device=device).model.values.detach().to(device) multihit_model = MultihitApplier(multihit_adjust) diff --git a/netam/hit_class.py b/netam/hit_class.py index 086c07b4..0e79ba5c 100644 --- a/netam/hit_class.py +++ b/netam/hit_class.py @@ -117,23 +117,14 @@ def __init__(self, hit_class_factors, device=None): if np.allclose(hit_class_factors, 0): warnings.warn("Hit class factors are all zero, and will not change probabilities") self.corrections = hit_class_factors.to(device) - # # TODO This is just for testing - # self.corrections = torch.tensor([0.0, 0.0, 0.0]).to(device) def to(self, device): self.corrections = self.corrections.to(device) return self - # def __call__(self, parent_codon_idxs, codon_probs): - # return codon_probs - def __call__(self, parent_codon_idxs, codon_probs): return apply_multihit_correction_nonlog(parent_codon_idxs, codon_probs, self.corrections) - # def __call__(self, parent_codon_idxs, codon_logprobs): - # return codon_logprobs - # return apply_multihit_correction(parent_codon_idxs, codon_logprobs, self.corrections) - def hit_class_probs_tensor( parent_codon_idxs: torch.Tensor, codon_probs: torch.Tensor diff --git a/netam/molevol.py b/netam/molevol.py index 9c3266dd..c8594713 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -18,7 +18,7 @@ from netam.sequences import CODON_AA_INDICATOR_MATRIX import netam.sequences as sequences -torch.autograd.set_detect_anomaly(True) +# torch.autograd.set_detect_anomaly(True) def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor: From 
48da87373fb61fd4a0694ce536b831d07194a660 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 22 Oct 2024 16:00:39 -0700 Subject: [PATCH 18/18] cleanup for PR --- netam/dnsm.py | 9 ++++---- netam/framework.py | 11 --------- netam/hit_class.py | 51 +----------------------------------------- netam/models.py | 5 +++-- netam/molevol.py | 4 ---- netam/multihit.py | 26 ++++++++++----------- tests/test_multihit.py | 5 ----- 7 files changed, 21 insertions(+), 90 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index d9b6f0bd..0b699b3b 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -39,7 +39,6 @@ translate_sequence, translate_sequences, ) -# mpctx = mp.get_context('spawn') class DNSMDataset(Dataset): @@ -57,10 +56,10 @@ def __init__( self.all_rates = all_rates self.all_subs_probs = all_subs_probs self.multihit_model = copy.deepcopy(multihit_model) - # if multihit_model is not None: - # # We want these parameters to act like fixed data. This is essential - # # for multithreaded branch length optimization to work. - # self.multihit_model.values.requires_grad_(False) + if multihit_model is not None: + # We want these parameters to act like fixed data. This is essential + # for multithreaded branch length optimization to work. + self.multihit_model.values.requires_grad_(False) assert len(self.nt_parents) == len(self.nt_children) pcp_count = len(self.nt_parents) diff --git a/netam/framework.py b/netam/framework.py index 2f4810e9..d44ab1d0 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -27,7 +27,6 @@ VRC01_NT_SEQ, ) from netam import models -from netam.hit_class import MultihitApplier import netam.molevol as molevol @@ -337,16 +336,6 @@ def crepe_exists(prefix): return os.path.exists(f"{prefix}.yml") and os.path.exists(f"{prefix}.pth") -def load_multihit_adjuster(multihit_crepe_prefix, device=None): - if multihit_crepe_prefix is None: - multihit_model = None - else: - print(f"Loading multihit model from {multihit_crepe_prefix}") - multihit_adjust = load_crepe(multihit_crepe_prefix, device=device).model.values.detach().to(device) - multihit_model = MultihitApplier(multihit_adjust) - return multihit_model - - def trimmed_shm_model_outputs_of_crepe(crepe, parents): """Model outputs trimmed to the length of the parent sequences.""" rates, csp_logits = crepe(parents) diff --git a/netam/hit_class.py b/netam/hit_class.py index 0e79ba5c..e9462187 100644 --- a/netam/hit_class.py +++ b/netam/hit_class.py @@ -41,43 +41,8 @@ def parent_specific_hit_classes(parent_codon_idxs: torch.Tensor) -> torch.Tensor ] -def apply_multihit_correction( - parent_codon_idxs: torch.Tensor, - codon_logprobs: torch.Tensor, - hit_class_factors: torch.Tensor, -) -> torch.Tensor: - """Multiply codon probabilities by their hit class factors, and renormalize. - - Suppose there are N codons, then the parameters are as follows: - - Args: - parent_codon_idxs (torch.Tensor): A (N, 3) shaped tensor containing for each codon, the - indices of the parent codon's nucleotides. - codon_logprobs (torch.Tensor): A (N, 4, 4, 4) shaped tensor containing the log probabilities - of mutating to each possible target codon, for each of the N parent codons. - hit_class_factors (torch.Tensor): A tensor containing the log hit class factors for hit classes 1, 2, and 3. The - factor for hit class 0 is assumed to be 1 (that is, 0 in log-space). 
- - Returns: - torch.Tensor: A (N, 4, 4, 4) shaped tensor containing the log probabilities of mutating to each possible - target codon, for each of the N parent codons, after applying the hit class factors. - """ - warnings.warn("hit_class.py:apply_multihit_correction is deprecated, use apply_multihit_correction_nonlog instead") - assert False - per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) - corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]) - reshaped_corrections = corrections[per_parent_hit_class] - unnormalized_corrected_logprobs = codon_logprobs + reshaped_corrections - normalizations = torch.logsumexp( - unnormalized_corrected_logprobs, dim=[1, 2, 3], keepdim=True - ) - result = unnormalized_corrected_logprobs - normalizations - if torch.any(torch.isnan(result)): - print("NAN found in multihit correction application") - assert False - return result -def apply_multihit_correction_nonlog( +def apply_multihit_correction( parent_codon_idxs: torch.Tensor, codon_probs: torch.Tensor, hit_class_factors: torch.Tensor, @@ -112,20 +77,6 @@ def apply_multihit_correction_nonlog( return result -class MultihitApplier: - def __init__(self, hit_class_factors, device=None): - if np.allclose(hit_class_factors, 0): - warnings.warn("Hit class factors are all zero, and will not change probabilities") - self.corrections = hit_class_factors.to(device) - - def to(self, device): - self.corrections = self.corrections.to(device) - return self - - def __call__(self, parent_codon_idxs, codon_probs): - return apply_multihit_correction_nonlog(parent_codon_idxs, codon_probs, self.corrections) - - def hit_class_probs_tensor( parent_codon_idxs: torch.Tensor, codon_probs: torch.Tensor ) -> torch.Tensor: diff --git a/netam/models.py b/netam/models.py index 3c73a322..381c8da9 100644 --- a/netam/models.py +++ b/netam/models.py @@ -702,14 +702,15 @@ def __init__(self): def hyperparameters(self): return {} + # TODO changed to nonlog version for testing, need to update all calls to reflect this def forward( - self, parent_codon_idxs: torch.Tensor, uncorrected_log_codon_probs: torch.Tensor + self, parent_codon_idxs: torch.Tensor, uncorrected_codon_probs: torch.Tensor ): """Forward function takes a tensor of target codon distributions, for each observed parent codon, and adjusts the distributions according to the hit class corrections.""" return apply_multihit_correction( - parent_codon_idxs, uncorrected_log_codon_probs, self.values + parent_codon_idxs, uncorrected_codon_probs, self.values ) def reinitialize_weights(self): diff --git a/netam/molevol.py b/netam/molevol.py index c8594713..ca1ae6e7 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -305,8 +305,6 @@ def build_codon_mutsel( codon_probs = codon_probs_of_mutation_matrices(mut_matrices) if multihit_model is not None: - # codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp() - # TODO this is for testing codon_probs = multihit_model(parent_codon_idxs, codon_probs) else: warnings.warn("No multihit model provided. Using uncorrected probabilities.") @@ -370,8 +368,6 @@ def neutral_aa_probs( codon_probs = codon_probs_of_mutation_matrices(mut_matrices) if multihit_model is not None: - # codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp() - # TODO this is for testing codon_probs = multihit_model(parent_codon_idxs, codon_probs) else: warnings.warn("No multihit model provided. 
Using uncorrected probabilities.") diff --git a/netam/multihit.py b/netam/multihit.py index 192d0b03..9544f0fe 100644 --- a/netam/multihit.py +++ b/netam/multihit.py @@ -270,28 +270,28 @@ def child_codon_probs_from_per_parent_probs(per_parent_probs, child_codon_idxs): ] -def child_codon_logprobs_corrected( - uncorrected_per_parent_logprobs, parent_codon_idxs, child_codon_idxs, model +def child_codon_probs_corrected( + uncorrected_per_parent_probs, parent_codon_idxs, child_codon_idxs, model ): """Calculate the probability of each child codon given the parent codon probabilities, corrected by hit class factors. Args: - uncorrected_per_parent_logprobs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the log probabilities + uncorrected_per_parent_probs (torch.Tensor): A (codon_count, 4, 4, 4) shaped tensor containing the probabilities of each possible target codon, for each parent codon. parent_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each parent codon child_codon_idxs (torch.Tensor): A (codon_count, 3) shaped tensor containing the nucleotide indices for each child codon model: A HitClassModel implementing the desired correction. Returns: - torch.Tensor: A (codon_count,) shaped tensor containing the corrected log probabilities of each child codon. + torch.Tensor: A (codon_count,) shaped tensor containing the corrected probabilities of each child codon. """ - corrected_per_parent_logprobs = model( - parent_codon_idxs, uncorrected_per_parent_logprobs + corrected_per_parent_probs = model( + parent_codon_idxs, uncorrected_per_parent_probs ) return child_codon_probs_from_per_parent_probs( - corrected_per_parent_logprobs, child_codon_idxs + corrected_per_parent_probs, child_codon_idxs ) @@ -340,12 +340,12 @@ def loss_of_batch(self, batch): codon_probs, codon_mask=codon_mask ) - child_codon_logprobs = child_codon_logprobs_corrected( - flat_masked_codon_probs.log(), + child_codon_logprobs = child_codon_probs_corrected( + flat_masked_codon_probs, parent_codons_flat, child_codons_flat, self.model, - ) + ).log() return -child_codon_logprobs.sum() def _find_optimal_branch_length( @@ -374,9 +374,9 @@ def log_pcp_probability(log_branch_length): child_codon_idxs = reshape_for_codons(child_idxs)[codon_mask] parent_codon_idxs = reshape_for_codons(parent_idxs)[codon_mask] - return child_codon_logprobs_corrected( - codon_probs.log(), parent_codon_idxs, child_codon_idxs, self.model - ).sum() + return child_codon_probs_corrected( + codon_probs, parent_codon_idxs, child_codon_idxs, self.model + ).log().sum() return optimize_branch_length( log_pcp_probability, diff --git a/tests/test_multihit.py b/tests/test_multihit.py index 88c6b002..f9f887dd 100644 --- a/tests/test_multihit.py +++ b/tests/test_multihit.py @@ -105,13 +105,9 @@ def test_multihit_correction(): # We'll verify that aggregating by hit class then adjusting is the same as adjusting then aggregating by hit class. 
codon_idxs = reshape_for_codons(ex_parent_codon_idxs) adjusted_codon_probs = hit_class.apply_multihit_correction( - codon_idxs, ex_codon_probs.log(), hit_class_factors - ).exp() - adjusted_nonlog_codon_probs = hit_class.apply_multihit_correction_nonlog( codon_idxs, ex_codon_probs, hit_class_factors ) aggregate_last = hit_class.hit_class_probs_tensor(codon_idxs, adjusted_codon_probs) - aggregate_last_nonlog = hit_class.hit_class_probs_tensor(codon_idxs, adjusted_nonlog_codon_probs) uncorrected_hc_log_probs = hit_class.hit_class_probs_tensor( codon_idxs, ex_codon_probs @@ -125,7 +121,6 @@ def test_multihit_correction(): uncorrected_hc_log_probs += corrections aggregate_first = torch.softmax(uncorrected_hc_log_probs, dim=1) assert torch.allclose(aggregate_first, aggregate_last) - assert torch.allclose(aggregate_first, aggregate_last_nonlog) def test_hit_class_tensor():
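The end state of this series multiplies neutral codon probabilities by per-hit-class factors and renormalizes, in probability space. Below is a minimal self-contained sketch of that arithmetic for a single parent codon. The hit_classes_for_parent helper is hypothetical (netam's parent_specific_hit_classes uses a precomputed lookup tensor instead), and the factor values are made up for illustration.

import torch

def hit_classes_for_parent(parent_codon_idxs):
    # Hit class of a target codon = number of nucleotide positions in which
    # it differs from the parent codon (0 through 3).
    nt = torch.arange(4)
    return (
        (nt.view(-1, 1, 1) != parent_codon_idxs[0]).long()
        + (nt.view(1, -1, 1) != parent_codon_idxs[1]).long()
        + (nt.view(1, 1, -1) != parent_codon_idxs[2]).long()
    )

def correct_codon_probs(codon_probs, parent_codon_idxs, hit_class_factors):
    # The factor for hit class 0 is fixed at 1, i.e. 0 in log space.
    corrections = torch.cat([torch.tensor([0.0]), hit_class_factors]).exp()
    per_target = codon_probs * corrections[hit_classes_for_parent(parent_codon_idxs)]
    return per_target / per_target.sum()

parent = torch.tensor([0, 1, 2])  # "ACG" under an assumed A,C,G,T encoding
probs = torch.full((4, 4, 4), 1.0 / 64)  # uniform toy distribution over targets
factors = torch.tensor([-0.5, -1.0, -2.0])  # toy log factors for hit classes 1-3
corrected = correct_codon_probs(probs, parent, factors)
print(corrected.sum())  # ~1.0 after renormalization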
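Patch 15 swaps the log-space correction for the probability-space version sketched above. The two are algebraically identical: adding log factors and renormalizing with logsumexp is the same as multiplying by exponentiated factors and dividing by the sum. A toy check, with all names local to this sketch:

import torch

logp = torch.log(torch.full((64,), 1.0 / 64))  # toy log codon distribution
log_corrections = torch.randn(64)  # stand-in for per-target log factors

# Log-space version: add corrections, renormalize with logsumexp.
log_version = logp + log_corrections
log_version = log_version - torch.logsumexp(log_version, dim=0)

# Probability-space version: multiply by exp(corrections), divide by the sum.
p_version = logp.exp() * log_corrections.exp()
p_version = p_version / p_version.sum()

assert torch.allclose(log_version.exp(), p_version)

The difference is numerical and ergonomic rather than mathematical, which is why the call sites in molevol.py could drop the .log()/.exp() round trip.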
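Patches 12 and 13 also guard optimize_branch_length against NaNs produced by the optimizer step. A condensed sketch of the resulting behavior, simplified from the series (the actual code only rolls back when the previous log branch length is near zero, and warns when optimization does not converge); loss_fn is a placeholder, not netam API:

import torch

def optimize_log_branch_length(loss_fn, start=1.0, steps=100, lr=0.1):
    log_bl = torch.tensor(start, requires_grad=True)
    optimizer = torch.optim.Adam([log_bl], lr=lr)
    prev = log_bl.detach().clone()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(log_bl)
        loss.backward()
        torch.nn.utils.clip_grad_norm_([log_bl], max_norm=5.0)
        optimizer.step()
        if torch.isnan(log_bl):  # patch 12 inverted this test; patch 13 fixed it
            log_bl.data = prev  # roll back instead of asserting
            break
        prev = log_bl.detach().clone()
    return log_bl.detach().exp()

print(optimize_log_branch_length(lambda x: x**2))  # converges toward exp(0) = 1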