Skip to content

Commit

Permalink
implementing DCSM
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jan 15, 2025
1 parent 5cac227 commit 326605b
Show file tree
Hide file tree
Showing 7 changed files with 529 additions and 6 deletions.
405 changes: 405 additions & 0 deletions netam/dcsm.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def update_neutral_probs(self):
mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
nt_csps = nt_csps[:parent_len, :]

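# TODO: singular/plural naming mismatch -- the result is stored as `neutral_aa_mut_prob` (singular) but is computed by `neutral_aa_mut_probs` (plural).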
neutral_aa_mut_prob = molevol.neutral_aa_mut_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
Expand Down
3 changes: 3 additions & 0 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,9 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None
self.optimizer.zero_grad()
scalar_loss.backward()

if torch.isnan(scalar_loss):
raise ValueError(f"NaN in loss: {scalar_loss.item()}")

nan_in_gradients = False
for name, param in self.model.named_parameters():
if torch.isnan(param).any():
Expand Down
40 changes: 34 additions & 6 deletions netam/molevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,23 +339,23 @@ def build_codon_mutsel(
return codon_mutsel, sums_too_big


def neutral_aa_probs(
def neutral_codon_probs(
parent_codon_idxs: Tensor,
codon_mut_probs: Tensor,
codon_csps: Tensor,
multihit_model=None,
) -> Tensor:
"""For every site, what is the probability that the amino acid will mutate to every
amino acid?
"""For every site, what is the probability that the site will mutate to every
alternate codon?
Args:
parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
codon_mut_probs (torch.Tensor): The mutation probabilities for each site in each codon. Shape: (codon_count, 3)
codon_csps (torch.Tensor): The substitution probabilities for each site in each codon. Shape: (codon_count, 3, 4)
Returns:
torch.Tensor: The probability that each site will change to each amino acid.
Shape: (codon_count, 20)
torch.Tensor: The probability that each site will change to each codon.
Shape: (codon_count, 64)
"""

mut_matrices = build_mutation_matrices(
Expand All @@ -366,8 +366,36 @@ def neutral_aa_probs(
if multihit_model is not None:
codon_probs = multihit_model(parent_codon_idxs, codon_probs)

return codon_probs.view(-1, 64)


def neutral_aa_probs(
    parent_codon_idxs: Tensor,
    codon_mut_probs: Tensor,
    codon_csps: Tensor,
    multihit_model=None,
) -> Tensor:
    """For every site, what is the probability that the site will mutate to every
    alternate amino acid?

    Args:
        parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
        codon_mut_probs (torch.Tensor): The mutation probabilities for each site in each codon. Shape: (codon_count, 3)
        codon_csps (torch.Tensor): The substitution probabilities for each site in each codon. Shape: (codon_count, 3, 4)
        multihit_model: Optional model, forwarded unchanged to
            ``neutral_codon_probs``. Defaults to None.

    Returns:
        torch.Tensor: The probability that each site will change to each amino acid.
            Shape: (codon_count, 20)
    """
    codon_probs = neutral_codon_probs(
        parent_codon_idxs,
        codon_mut_probs,
        codon_csps,
        multihit_model=multihit_model,
    )

    # Get the probability of mutating to each amino acid.
    # `neutral_codon_probs` already returns a (codon_count, 64) view, so the
    # codon probabilities can be multiplied by the indicator matrix directly.
    aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX

    return aa_probs

Expand Down
24 changes: 24 additions & 0 deletions netam/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,30 @@ def aa_idx_tensor_of_str(aa_str):
raise


# TODO isolating all this stuff here

# Index assigned to any codon containing an ambiguous base ("N"). It is
# len(CODONS), i.e. one past the last valid index into CODONS.
AMBIGUOUS_CODON_IDX = len(CODONS)


def idx_of_codon_allowing_ambiguous(codon):
    """Return the index of `codon` in CODONS, or AMBIGUOUS_CODON_IDX if the codon
    contains an "N"."""
    return AMBIGUOUS_CODON_IDX if "N" in codon else CODONS.index(codon)


def codon_idx_tensor_of_str_ambig(nt_str):
    """Return a tensor holding the index of each codon in a nucleotide string.

    Codons containing an "N" are mapped to AMBIGUOUS_CODON_IDX. The length of
    `nt_str` must be a multiple of 3.
    """
    assert len(nt_str) % 3 == 0
    codon_indices = list(map(idx_of_codon_allowing_ambiguous, iter_codons(nt_str)))
    return torch.tensor(codon_indices)


# TODO end isolating new stuff


def aa_onehot_tensor_of_str(aa_str):
aa_onehot = torch.zeros((len(aa_str), 20))
aa_indices_parent = aa_idx_array_of_str(aa_str)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_dcsm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os

import torch
import pytest

from netam.common import BIG, force_spawn
from netam.framework import (
crepe_exists,
load_crepe,
)
from netam.sequences import MAX_AA_TOKEN_IDX
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dcsm import (
DCSMBurrito,
DCSMDataset,
)


@pytest.fixture(scope="module")
def dcsm_burrito(pcp_df):
    """Fixture that returns a briefly trained DCSM Burrito object."""
    force_spawn()
    # NOTE(review): this adds an `in_train` column to the incoming `pcp_df`
    # fixture in place; confirm that no other test depends on the unmodified frame.
    # All rows are marked for training except the last 15.
    pcp_df["in_train"] = True
    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(pcp_df)

    model = TransformerBinarySelectionModelWiggleAct(
        nhead=2,
        d_model_per_head=4,
        dim_feedforward=256,
        layer_count=2,
        output_dim=MAX_AA_TOKEN_IDX + 1,
    )

    burrito = DCSMBurrito(
        train_dataset,
        val_dataset,
        model,
        batch_size=32,
        learning_rate=0.001,
        min_learning_rate=0.0001,
    )
    # A single epoch over two cycles keeps the fixture cheap to build.
    burrito.joint_train(
        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
    )
    return burrito


def test_parallel_branch_length_optimization(dcsm_burrito):
    """The parallel and serial branch length optimizers should agree on the
    validation dataset."""
    val_data = dcsm_burrito.val_dataset
    from_parallel = dcsm_burrito.find_optimal_branch_lengths(val_data)
    from_serial = dcsm_burrito.serial_find_optimal_branch_lengths(val_data)
    assert torch.allclose(from_serial, from_parallel)
9 changes: 9 additions & 0 deletions tests/test_sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
TOKEN_STR_SORTED,
CODONS,
CODON_AA_INDICATOR_MATRIX,
AMBIGUOUS_CODON_IDX,
aa_onehot_tensor_of_str,
codon_idx_tensor_of_str,
nt_idx_array_of_str,
nt_subs_indicator_tensor_of,
translate_sequences,
Expand Down Expand Up @@ -51,6 +53,13 @@ def test_nucleotide_indices_of_codon():
assert nt_idx_array_of_str("GCG").tolist() == [2, 1, 2]


def test_codon_idx_tensor_of_str():
    """Codon strings map to codon indices, with N-containing codons mapped to
    AMBIGUOUS_CODON_IDX."""
    # The expected output includes AMBIGUOUS_CODON_IDX for "NTT", so the
    # ambiguity-aware variant must be exercised. It is imported locally so this
    # test does not depend on the module-level import list.
    from netam.sequences import codon_idx_tensor_of_str_ambig

    nt_str = "AAAAACTTGTTTNTT"
    expected_output = torch.tensor([0, 1, 62, 63, AMBIGUOUS_CODON_IDX])
    output = codon_idx_tensor_of_str_ambig(nt_str)
    assert torch.equal(output, expected_output)


def test_aa_onehot_tensor_of_str():
aa_str = "QY"

Expand Down

0 comments on commit 326605b

Please sign in to comment.