From 96025ab2f728cc871986a544aa1e330a162e64af Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Thu, 16 Jan 2025 16:38:41 -0800 Subject: [PATCH] WIP --- netam/dasm.py | 21 ++++++++++++++++++--- netam/dnsm.py | 2 +- netam/dxsm.py | 2 ++ netam/models.py | 7 ++++++- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 6dc60321..473bd289 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -99,7 +99,7 @@ def to(self, device): self.multihit_model = self.multihit_model.to(device) -def zap_predictions_along_diagonal(predictions, aa_parents_idxs): +def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG): """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG, except where aa_parents_idxs >= 20, which indicates no update should be done.""" @@ -116,7 +116,7 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs): batch_indices[valid_mask], sequence_indices[valid_mask], aa_parents_idxs[valid_mask], - ] = -BIG + ] = fill return predictions @@ -204,9 +204,24 @@ def loss_of_batch(self, batch): csp_loss = self.xent_loss(csp_pred, csp_targets) return torch.stack([subs_pos_loss, csp_loss]) - def build_selection_matrix_from_parent(self, parent: str): + # TODO have a close look at these two functions, I'm feeling unsure about + # them + def build_selection_matrix_from_parent_aa(self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor): """Build a selection matrix from a parent amino acid sequence. + Values at ambiguous sites are meaningless. + """ + per_aa_selection_factors = self.model.forward(aa_parent_idxs, mask) + + + # TODO why 1.0? + return zap_predictions_along_diagonal(per_aa_selection_factors, aa_parent_idxs, fill=1.0) + + return per_aa_selection_factors + + def build_selection_matrix_from_parent(self, parent: str): + """Build a selection matrix from a parent nucleotide sequence. + Values at ambiguous sites are meaningless. """ # This is simpler than the equivalent in dnsm.py because we get the selection diff --git a/netam/dnsm.py b/netam/dnsm.py index 6efeda85..d30157ee 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -157,7 +157,7 @@ def loss_of_batch(self, batch): return self.bce_loss(predictions, aa_subs_indicator) def build_selection_matrix_from_parent(self, parent: str): - """Build a selection matrix from a parent amino acid sequence. + """Build a selection matrix from a nucleotide sequence. Values at ambiguous sites are meaningless. """ diff --git a/netam/dxsm.py b/netam/dxsm.py index 0d50aaf1..470a1753 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -282,6 +282,8 @@ def _find_optimal_branch_length( multihit_model, **optimization_kwargs, ): + # TODO finish switching to build_selection_matrix_from_parent_aa + # thing... sel_matrix = self.build_selection_matrix_from_parent(parent) trimmed_aa_mask = aa_mask[: len(sel_matrix)] log_pcp_probability = molevol.mutsel_log_pcp_probability_of( diff --git a/netam/models.py b/netam/models.py index abd5fbca..a842fe62 100644 --- a/netam/models.py +++ b/netam/models.py @@ -584,12 +584,17 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor: """Do the forward method then exponentiation without gradients from an amino acid string. + Insertion of model tokens will be done automatically. + Args: - aa_str: A string of amino acids. + aa_str: A string of amino acids. If a string, we assume this is a light chain sequence. + Otherwise it should be a tuple, with the first element being the heavy chain and the second element being the light chain sequence. Returns: A numpy array of the same length as the input string representing the level of selection for each amino acid at each site. + If the input was a tuple of heavy/light chain sequences, the output will be a tuple of + numpy arrays. """ aa_idxs = aa_idx_tensor_of_str_ambig(aa_str)