Skip to content

Commit

Permalink
zapping stop codons
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jan 15, 2025
1 parent 7d6dac2 commit d52a617
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
9 changes: 8 additions & 1 deletion netam/dcsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
assert_pcp_valid,
clamp_probability,
codon_mask_tensor_of,
BIG,
)
from netam.dxsm import DXSMDataset, DXSMBurrito
import netam.molevol as molevol
Expand All @@ -18,6 +19,7 @@
from netam.sequences import (
aa_idx_array_of_str,
aa_subs_indicator_tensor_of,
build_stop_codon_indicator_tensor,
nt_idx_tensor_of_str,
token_mask_of_aa_idxs,
translate_sequence,
Expand Down Expand Up @@ -223,6 +225,7 @@ class DCSMBurrito(DXSMBurrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.xent_loss = torch.nn.CrossEntropyLoss()
self.stop_codon_zapper = build_stop_codon_indicator_tensor() * -BIG

def prediction_pair_of_batch(self, batch):
"""Get log neutral codon substitution probabilities and log selection factors
Expand Down Expand Up @@ -271,7 +274,11 @@ def predictions_of_batch(self, batch):

# This indicator lifts things up from aa land to codon land.
indicator = CODON_AA_INDICATOR_MATRIX.to(self.device).T
log_preds = log_neutral_codon_probs + log_selection_factors @ indicator
log_preds = (
log_neutral_codon_probs
+ log_selection_factors @ indicator
+ self.stop_codon_zapper
)
assert torch.isnan(log_preds).sum() == 0

parent_indices = batch["codon_parents_idxs"] # Shape: [B, L]
Expand Down
8 changes: 8 additions & 0 deletions netam/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
RESERVED_TOKEN_REGEX = f"[{''.join(map(re.escape, list(RESERVED_TOKENS)))}]"


def build_stop_codon_indicator_tensor():
"""Return a tensor indicating the stop codons."""
stop_codon_indicator = torch.zeros(len(CODONS))
for stop_codon in STOP_CODONS:
stop_codon_indicator[CODONS.index(stop_codon)] = 1.0
return stop_codon_indicator


def nt_idx_array_of_str(nt_str):
"""Return the indices of the nucleotides in a string."""
try:
Expand Down

0 comments on commit d52a617

Please sign in to comment.