From d52a617b55a461f8186b8064a7c567e9d396dd25 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Wed, 15 Jan 2025 13:04:19 -0800 Subject: [PATCH] zapping stop codons --- netam/dcsm.py | 9 ++++++++- netam/sequences.py | 8 ++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/netam/dcsm.py b/netam/dcsm.py index 433d1f2b..58840fee 100644 --- a/netam/dcsm.py +++ b/netam/dcsm.py @@ -10,6 +10,7 @@ assert_pcp_valid, clamp_probability, codon_mask_tensor_of, + BIG, ) from netam.dxsm import DXSMDataset, DXSMBurrito import netam.molevol as molevol @@ -18,6 +19,7 @@ from netam.sequences import ( aa_idx_array_of_str, aa_subs_indicator_tensor_of, + build_stop_codon_indicator_tensor, nt_idx_tensor_of_str, token_mask_of_aa_idxs, translate_sequence, @@ -223,6 +225,7 @@ class DCSMBurrito(DXSMBurrito): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.xent_loss = torch.nn.CrossEntropyLoss() + self.stop_codon_zapper = build_stop_codon_indicator_tensor() * -BIG def prediction_pair_of_batch(self, batch): """Get log neutral codon substitution probabilities and log selection factors @@ -271,7 +274,11 @@ def predictions_of_batch(self, batch): # This indicator lifts things up from aa land to codon land. indicator = CODON_AA_INDICATOR_MATRIX.to(self.device).T - log_preds = log_neutral_codon_probs + log_selection_factors @ indicator + log_preds = ( + log_neutral_codon_probs + + log_selection_factors @ indicator + + self.stop_codon_zapper + ) assert torch.isnan(log_preds).sum() == 0 parent_indices = batch["codon_parents_idxs"] # Shape: [B, L] diff --git a/netam/sequences.py b/netam/sequences.py index 726e1b92..01cbca58 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -36,6 +36,14 @@ RESERVED_TOKEN_REGEX = f"[{''.join(map(re.escape, list(RESERVED_TOKENS)))}]" +def build_stop_codon_indicator_tensor(): + """Return a tensor indicating the stop codons.""" + stop_codon_indicator = torch.zeros(len(CODONS)) + for stop_codon in STOP_CODONS: + stop_codon_indicator[CODONS.index(stop_codon)] = 1.0 + return stop_codon_indicator + + def nt_idx_array_of_str(nt_str): """Return the indices of the nucleotides in a string.""" try: