Restore compatibility with models handling fewer tokens (#104)
This PR addresses #102. It also fixes an issue where the DASM model output dimension grew with the number of amino acid tokens instead of staying fixed at `20`.
willdumm authored Jan 14, 2025
1 parent 35c3efa commit 5cac227
Showing 22 changed files with 140 additions and 44 deletions.
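Before the per-file diffs, here is a compact sketch of the dimension bug being fixed (illustrative values only; the real constants live in `netam/sequences.py`):

```python
# Previously, several output/selection dimensions were tied to the full token
# count, which grows whenever a reserved token is added:
n_amino_acids = 20          # the real amino-acid alphabet, fixed
n_ambig_and_reserved = 2    # "X" plus one reserved token today, more tomorrow

old_dim = n_amino_acids + n_ambig_and_reserved  # drifted upward over time
new_dim = n_amino_acids                         # pinned at 20 by this PR
```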
6 changes: 6 additions & 0 deletions .github/workflows/build-and-test.yml
@@ -34,6 +34,12 @@ jobs:
cache-environment: false
post-cleanup: 'none'

- name: Check for TODOs
shell: bash -l {0}
run: |
cd main
make checktodo
- name: Check format
shell: bash -l {0}
run: |
5 changes: 4 additions & 1 deletion Makefile
@@ -15,6 +15,9 @@ checkformat:
docformatter --check --black --recursive netam tests
black --check netam tests

checktodo:
grep -rq --include=\*.{py,Snakemake} "TODO" . && echo "TODOs found" && exit 1 || echo "No TODOs found" && exit 0

lint:
flake8 . --max-complexity=30 --ignore=E731,W503,E402,F541,E501,E203,E266 --statistics --exclude=_ignore

@@ -27,4 +30,4 @@ notebooks:
jupyter nbconvert --to notebook --execute "$$nb" --output _ignore/"$$(basename $$nb)"; \
done

.PHONY: install test notebooks format lint docs
.PHONY: install test notebooks format lint docs checktodo
8 changes: 4 additions & 4 deletions netam/common.py
@@ -18,7 +18,7 @@
apply_aa_mask_to_nt_sequence,
RESERVED_TOKEN_TRANSLATIONS,
BASES,
AA_TOKEN_STR_SORTED,
TOKEN_STR_SORTED,
)

BIG = 1e9
@@ -66,7 +66,7 @@ def aa_idx_tensor_of_str_ambig(aa_str):
character."""
try:
return torch.tensor(
[AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
[TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
)
except ValueError:
print(f"Found an invalid amino acid in the string: {aa_str}")
@@ -126,7 +126,7 @@ def aa_strs_from_idx_tensor(idx_tensor):
Args:
idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
indices into AA_TOKEN_STR_SORTED.
indices into TOKEN_STR_SORTED.
Returns:
List[str]: A list of amino acid strings with trailing 'X's removed.
@@ -135,7 +135,7 @@ def aa_strs_from_idx_tensor(idx_tensor):

aa_str_list = []
for row in idx_tensor:
aa_str = "".join(AA_TOKEN_STR_SORTED[idx] for idx in row.tolist())
aa_str = "".join(TOKEN_STR_SORTED[idx] for idx in row.tolist())
aa_str_list.append(aa_str.rstrip("X"))

return aa_str_list
4 changes: 1 addition & 3 deletions netam/dnsm.py
@@ -163,9 +163,7 @@ def build_selection_matrix_from_parent(self, parent: str):
"""
parent = sequences.translate_sequence(parent)
selection_factors = self.model.selection_factors_of_aa_str(parent)
selection_matrix = torch.zeros(
(len(selection_factors), sequences.MAX_AA_TOKEN_IDX + 1), dtype=torch.float
)
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
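The `build_selection_matrix_from_parent` change above is the core of the fixed-dimension fix: the selection matrix is now always 20 columns wide. A minimal sketch of the resulting shape, assuming a hypothetical five-site parent:

```python
import torch

# Hypothetical per-site selection factors for a 5-site parent.
selection_factors = torch.rand(5)

# One row per amino-acid site, always 20 columns (one per real amino acid),
# regardless of how many reserved tokens exist.
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
selection_matrix[:, :] = selection_factors[:, None]

assert selection_matrix.shape == (5, 20)
```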
5 changes: 3 additions & 2 deletions netam/dxsm.py
@@ -29,6 +29,7 @@
nt_mutation_frequency,
MAX_AA_TOKEN_IDX,
RESERVED_TOKEN_REGEX,
AA_AMBIG_IDX,
)


@@ -70,7 +71,7 @@ def __init__(
# to the ambiguous amino acid, and then will fill in the actual values
# below.
self.aa_parents_idxss = torch.full(
(pcp_count, self.max_aa_seq_len), MAX_AA_TOKEN_IDX
(pcp_count, self.max_aa_seq_len), AA_AMBIG_IDX
)
self.aa_children_idxss = self.aa_parents_idxss.clone()
self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len))
@@ -303,7 +304,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
# The following can be used when one wants a better traceback.
# # The following can be used when one wants a better traceback.
# burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
# return burrito.serial_find_optimal_branch_lengths(
# dataset, **optimization_kwargs
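The dataset initialization above now pads with `AA_AMBIG_IDX` (the index of `X`) rather than `MAX_AA_TOKEN_IDX`, so the padding value no longer drifts as reserved tokens are appended. A small sketch of that padding step with hypothetical sizes:

```python
import torch

AA_AMBIG_IDX = 20  # assumed: index of the ambiguous amino acid "X"

pcp_count, max_aa_seq_len = 3, 7  # hypothetical dataset dimensions
# Pre-fill every position with the ambiguity index; real indices get written
# in afterwards, so unused tail positions remain ambiguous.
aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_AMBIG_IDX)
aa_children_idxss = aa_parents_idxss.clone()

assert (aa_parents_idxss == AA_AMBIG_IDX).all()
```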
5 changes: 5 additions & 0 deletions netam/framework.py
@@ -321,6 +321,11 @@ def load_crepe(prefix, device=None):
f"Model class '{model_class_name}' not found in 'models' module."
)

if issubclass(model_class, models.TransformerBinarySelectionModelLinAct):
if "embedding_dim" not in config["model_hyperparameters"]:
# Assume the model is from before any new tokens were added, so 21
config["model_hyperparameters"]["embedding_dim"] = 21

model = model_class(**config["model_hyperparameters"])

model_state_path = f"{prefix}.pth"
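The `load_crepe` change above backfills `embedding_dim` for crepes serialized before that hyperparameter existed. A sketch of the defaulting logic on a bare config dict (the dict below is hypothetical, mirroring the old-model YAML files added under `tests/old_models/`):

```python
# Hypothetical config deserialized from a pre-embedding_dim crepe.
config = {
    "model_hyperparameters": {
        "d_model_per_head": 4,
        "dim_feedforward": 64,
        "dropout_prob": 0.1,
        "layer_count": 3,
        "nhead": 4,
        "output_dim": 20,
    }
}

# Older models knew 20 amino acids plus "X", i.e. 21 embeddings, so that is
# the assumed default when the key is missing.
config["model_hyperparameters"].setdefault("embedding_dim", 21)

assert config["model_hyperparameters"]["embedding_dim"] == 21
```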
34 changes: 28 additions & 6 deletions netam/models.py
@@ -597,11 +597,29 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(self.device)

# Here we're ignoring sites containing tokens that have index greater
# than the embedding dimension. If extra tokens have been added since
# this model was defined, they are stripped out before feeding the
# sequence to the model, and the returned selection factors will be NaN
# at sites containing those unrecognized tokens.
model_valid_sites = aa_idxs < self.hyperparameters["embedding_dim"]
if self.hyperparameters["output_dim"] == 1:
result = torch.full((len(aa_str),), float("nan"), device=self.device)
else:
result = torch.full(
(len(aa_str), self.hyperparameters["output_dim"]),
float("nan"),
device=self.device,
)

with torch.no_grad():
model_out = self(aa_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)
final_out = torch.exp(model_out)
model_out = self(
aa_idxs[model_valid_sites].unsqueeze(0),
mask[model_valid_sites].unsqueeze(0),
).squeeze(0)
result[model_valid_sites] = torch.exp(model_out)[: model_valid_sites.sum()]

return final_out[: len(aa_str)]
return result


class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel):
@@ -613,16 +631,18 @@ def __init__(
layer_count: int,
dropout_prob: float = 0.5,
output_dim: int = 1,
embedding_dim: int = MAX_AA_TOKEN_IDX + 1,
):
super().__init__()
# Note that d_model has to be divisible by nhead, so we make that
# automatic here.
self.d_model_per_head = d_model_per_head
self.d_model = d_model_per_head * nhead
self.nhead = nhead
self.embedding_dim = embedding_dim
self.dim_feedforward = dim_feedforward
self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
self.amino_acid_embedding = nn.Embedding(self.embedding_dim, self.d_model)
self.encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=nhead,
@@ -642,6 +662,7 @@ def hyperparameters(self):
"layer_count": self.encoder.num_layers,
"dropout_prob": self.pos_encoder.dropout.p,
"output_dim": self.linear.out_features,
"embedding_dim": self.embedding_dim,
}

def init_weights(self) -> None:
@@ -748,14 +769,15 @@ def predict(self, representation: Tensor):
class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
"""A one parameter selection model as a baseline."""

def __init__(self, output_dim: int = 1):
def __init__(self, output_dim: int = 1, embedding_dim: int = MAX_AA_TOKEN_IDX + 1):
super().__init__()
self.single_value = nn.Parameter(torch.tensor(0.0))
self.output_dim = output_dim
self.embedding_dim = embedding_dim

@property
def hyperparameters(self):
return {"output_dim": self.output_dim}
return {"output_dim": self.output_dim, "embedding_dim": self.embedding_dim}

def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
"""Build a binary log selection matrix from an index-encoded parent sequence."""
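The `model_valid_sites` masking added to `selection_factors_of_aa_str` above returns NaN at sites whose token index is outside the model's embedding range. A standalone sketch of that pattern, with a stand-in for the transformer forward pass:

```python
import torch

embedding_dim = 21  # assumed: an older model that knows only 20 amino acids + "X"

# Hypothetical index-encoded sequence; index 21 is a token added after training.
aa_idxs = torch.tensor([0, 3, 21, 7])
model_valid_sites = aa_idxs < embedding_dim

result = torch.full((len(aa_idxs),), float("nan"))
with torch.no_grad():
    # Stand-in for the model: one log selection factor per recognized site.
    model_out = torch.zeros(int(model_valid_sites.sum()))
    result[model_valid_sites] = torch.exp(model_out)

# Recognized sites get real factors; the unrecognized token stays NaN.
assert not torch.isnan(result[0]) and torch.isnan(result[2])
```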
4 changes: 2 additions & 2 deletions netam/molevol.py
@@ -9,7 +9,7 @@
import torch
from torch import Tensor, optim

from netam.sequences import CODON_AA_INDICATOR_MATRIX, MAX_AA_TOKEN_IDX
from netam.sequences import CODON_AA_INDICATOR_MATRIX

import netam.sequences as sequences

@@ -444,7 +444,7 @@ def mutsel_log_pcp_probability_of(
"""

assert len(parent) % 3 == 0
assert sel_matrix.shape == (len(parent) // 3, MAX_AA_TOKEN_IDX + 1)
assert sel_matrix.shape == (len(parent) // 3, 20)

parent_idxs = sequences.nt_idx_tensor_of_str(parent)
child_idxs = sequences.nt_idx_tensor_of_str(child)
25 changes: 13 additions & 12 deletions netam/sequences.py
@@ -17,15 +17,15 @@

NT_STR_SORTED = "".join(BASES)
BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")}
# ambiguous must remain last. It is assumed elsewhere that the max index
# denotes the ambiguous base
AA_TOKEN_STR_SORTED = AA_STR_SORTED + RESERVED_TOKENS + "X"
# Must add new tokens to the end of this string.
TOKEN_STR_SORTED = AA_STR_SORTED + "X" + RESERVED_TOKENS
AA_AMBIG_IDX = len(AA_STR_SORTED)

RESERVED_TOKEN_AA_BOUNDS = (
min(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
max(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
min(TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
max(TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
)
MAX_AA_TOKEN_IDX = len(AA_TOKEN_STR_SORTED) - 1
MAX_AA_TOKEN_IDX = len(TOKEN_STR_SORTED) - 1
CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)]
STOP_CODONS = ["TAA", "TAG", "TGA"]
# Each token in RESERVED_TOKENS will appear once in aa strings, and three times
@@ -48,7 +48,7 @@ def nt_idx_array_of_str(nt_str):
def aa_idx_array_of_str(aa_str):
"""Return the indices of the amino acids in a string."""
try:
return np.array([AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str])
return np.array([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
except ValueError:
print(f"Found an invalid amino acid in the string: {aa_str}")
raise
@@ -66,7 +66,7 @@ def nt_idx_tensor_of_str(nt_str):
def aa_idx_tensor_of_str(aa_str):
"""Return the indices of the amino acids in a string."""
try:
return torch.tensor([AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str])
return torch.tensor([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
except ValueError:
print(f"Found an invalid amino acid in the string: {aa_str}")
raise
@@ -135,7 +135,7 @@ def translate_sequences(nt_sequences):
def aa_index_of_codon(codon):
"""Return the index of the amino acid encoded by a codon."""
aa = translate_sequence(codon)
return AA_TOKEN_STR_SORTED.index(aa)
return TOKEN_STR_SORTED.index(aa)


def generic_mutation_frequency(ambig_symb, parent, child):
@@ -185,15 +185,16 @@ def pcp_criteria_check(parent, child, max_mut_freq=0.3):
def generate_codon_aa_indicator_matrix():
"""Generate a matrix that maps codons (rows) to amino acids (columns)."""

matrix = np.zeros((len(CODONS), len(AA_TOKEN_STR_SORTED)))
matrix = np.zeros((len(CODONS), len(AA_STR_SORTED)))

for i, codon in enumerate(CODONS):
try:
aa = translate_sequences([codon])[0]
aa_idx = AA_TOKEN_STR_SORTED.index(aa)
matrix[i, aa_idx] = 1
except ValueError: # Handle STOP codon
pass
else:
aa_idx = AA_STR_SORTED.index(aa)
matrix[i, aa_idx] = 1

return matrix

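With the reordering above, `X` sits at index 20 and reserved tokens follow it, which is exactly what the updated `test_common.py` expectations later in this diff rely on. A short worked example (assuming the 20-letter sorted alphabet and `^` as one reserved token):

```python
AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"  # assumed standard 20-letter alphabet
RESERVED_TOKENS = "^"                    # example reserved token

TOKEN_STR_SORTED = AA_STR_SORTED + "X" + RESERVED_TOKENS

# Indices 0-19 are real amino acids, 20 is "X", 21 and up are reserved tokens.
assert TOKEN_STR_SORTED[20] == "X"
assert TOKEN_STR_SORTED[21] == "^"
# So index 20 now decodes to "X" and 21 to "^", matching the updated
# expectations ["ACDEX^", "FGY^"] in test_aa_strs_from_idx_tensor.
```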
17 changes: 17 additions & 0 deletions tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
encoder_class: PlaceholderEncoder
encoder_parameters: {}
model_class: TransformerBinarySelectionModelWiggleAct
model_hyperparameters:
d_model_per_head: 4
dim_feedforward: 64
dropout_prob: 0.1
layer_count: 3
nhead: 4
output_dim: 20
serialization_version: 0
training_hyperparameters:
batch_size: 1024
learning_rate: 0.001
min_learning_rate: 1.0e-06
optimizer_name: RMSprop
weight_decay: 1.0e-06
Binary file added tests/old_models/dasm_output
17 changes: 17 additions & 0 deletions tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
encoder_class: PlaceholderEncoder
encoder_parameters: {}
model_class: TransformerBinarySelectionModelWiggleAct
model_hyperparameters:
d_model_per_head: 4
dim_feedforward: 64
dropout_prob: 0.1
layer_count: 3
nhead: 4
output_dim: 1
serialization_version: 0
training_hyperparameters:
batch_size: 1024
learning_rate: 0.001
min_learning_rate: 1.0e-06
optimizer_name: RMSprop
weight_decay: 1.0e-06
Binary file added tests/old_models/dnsm_output
3 changes: 1 addition & 2 deletions tests/test_ambiguous.py
@@ -11,7 +11,6 @@
load_pcp_df,
add_shm_model_outputs_to_pcp_df,
)
from netam.sequences import MAX_AA_TOKEN_IDX
from netam import pretrained
import random

@@ -170,7 +169,7 @@ def dasm_model():
d_model_per_head=4,
dim_feedforward=256,
layer_count=2,
output_dim=MAX_AA_TOKEN_IDX + 1,
output_dim=20,
)


28 changes: 28 additions & 0 deletions tests/test_backward_compat.py
@@ -0,0 +1,28 @@
import torch

from netam.framework import load_crepe
from netam.sequences import set_wt_to_nan


# The outputs used for this test are produced by running
# `test_backward_compat_copy.py` on the wd-old-model-runner branch.
# This is to ensure that we can still load older crepes, even if we change the
# dimensions of model layers, as we did with the Embedding layer in
# https://github.com/matsengrp/netam/pull/92.
def test_old_model_outputs():
example_seq = "QVQLVESGGGVVQPGRSLRLSCAASGFTFSSSGMHWVRQAPGKGLEWVAVIWYDGSNKYYADSVKGRFTISRDNSKNTVYLQMNSLRAEDTAVYYCAREGHSNYPYYYYYMDVWGKGTTVTVSS"
dasm_crepe = load_crepe("tests/old_models/dasm_13k-v1jaffe+v1tang-joint")
dnsm_crepe = load_crepe("tests/old_models/dnsm_13k-v1jaffe+v1tang-joint")

dasm_vals = torch.nan_to_num(
set_wt_to_nan(
torch.load("tests/old_models/dasm_output", weights_only=True), example_seq
),
0.0,
)
dnsm_vals = torch.load("tests/old_models/dnsm_output", weights_only=True)

dasm_result = torch.nan_to_num(dasm_crepe([example_seq])[0], 0.0)
dnsm_result = dnsm_crepe([example_seq])[0]
assert torch.allclose(dasm_result, dasm_vals)
assert torch.allclose(dnsm_result, dnsm_vals)
4 changes: 2 additions & 2 deletions tests/test_common.py
@@ -33,6 +33,6 @@ def test_codon_mask_tensor_of():


def test_aa_strs_from_idx_tensor():
aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 21, 21]])
aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 20, 20]])
aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor)
assert aa_strings == ["ACDE^", "FGY"]
assert aa_strings == ["ACDEX^", "FGY^"]
3 changes: 1 addition & 2 deletions tests/test_dasm.py
@@ -8,7 +8,6 @@
crepe_exists,
load_crepe,
)
from netam.sequences import MAX_AA_TOKEN_IDX
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dasm import (
DASMBurrito,
@@ -30,7 +29,7 @@ def dasm_burrito(pcp_df):
d_model_per_head=4,
dim_feedforward=256,
layer_count=2,
output_dim=MAX_AA_TOKEN_IDX + 1,
output_dim=20,
)

burrito = DASMBurrito(
(Diffs for the remaining changed files are not shown.)
