
Restore compatibility with models handling fewer tokens #104

Merged · 19 commits · Jan 14, 2025
6 changes: 6 additions & 0 deletions .github/workflows/build-and-test.yml
@@ -34,6 +34,12 @@ jobs:
cache-environment: false
post-cleanup: 'none'

+      - name: Check for TODOs
+        shell: bash -l {0}
+        run: |
+          cd main
+          make checktodo

- name: Check format
shell: bash -l {0}
run: |
5 changes: 4 additions & 1 deletion Makefile
@@ -15,6 +15,9 @@ checkformat:
docformatter --check --black --recursive netam tests
black --check netam tests

+ checktodo:
+ 	grep -rq --include=\*.{py,Snakemake} "TODO" . && echo "TODOs found" && exit 1 || echo "No TODOs found" && exit 0
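Reading the recipe: `grep -rq` exits 0 as soon as any match is found, so the `&&` chain prints "TODOs found" and fails the target via `exit 1`; when grep finds nothing, the `||` branch prints "No TODOs found" and exits 0. Note that `--include=\*.{py,Snakemake}` relies on the shell's brace expansion (a bash feature) to become two `--include` globs restricting the search to Python and Snakemake files.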

lint:
flake8 . --max-complexity=30 --ignore=E731,W503,E402,F541,E501,E203,E266 --statistics --exclude=_ignore

@@ -27,4 +30,4 @@ notebooks:
jupyter nbconvert --to notebook --execute "$$nb" --output _ignore/"$$(basename $$nb)"; \
done

- .PHONY: install test notebooks format lint docs
+ .PHONY: install test notebooks format lint docs checktodo
8 changes: 4 additions & 4 deletions netam/common.py
@@ -18,7 +18,7 @@
apply_aa_mask_to_nt_sequence,
RESERVED_TOKEN_TRANSLATIONS,
BASES,
-    AA_TOKEN_STR_SORTED,
+    TOKEN_STR_SORTED,
)

BIG = 1e9
@@ -66,7 +66,7 @@ def aa_idx_tensor_of_str_ambig(aa_str):
character."""
try:
return torch.tensor(
-            [AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
+            [TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
)
except ValueError:
print(f"Found an invalid amino acid in the string: {aa_str}")
@@ -126,7 +126,7 @@ def aa_strs_from_idx_tensor(idx_tensor):

Args:
idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
-            indices into AA_TOKEN_STR_SORTED.
+            indices into TOKEN_STR_SORTED.

Returns:
List[str]: A list of amino acid strings with trailing 'X's removed.
@@ -135,7 +135,7 @@

aa_str_list = []
for row in idx_tensor:
-        aa_str = "".join(AA_TOKEN_STR_SORTED[idx] for idx in row.tolist())
+        aa_str = "".join(TOKEN_STR_SORTED[idx] for idx in row.tolist())
aa_str_list.append(aa_str.rstrip("X"))

return aa_str_list
4 changes: 1 addition & 3 deletions netam/dnsm.py
@@ -163,9 +163,7 @@ def build_selection_matrix_from_parent(self, parent: str):
"""
parent = sequences.translate_sequence(parent)
selection_factors = self.model.selection_factors_of_aa_str(parent)
-        selection_matrix = torch.zeros(
-            (len(selection_factors), sequences.MAX_AA_TOKEN_IDX + 1), dtype=torch.float
-        )
+        selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
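In other words, DNSM selection matrices are now sized to the 20 standard amino acids instead of the full token alphabet. A minimal sketch of the shape contract, with illustrative names (the real method also adjusts the "diagonal" entries for keeping the parent amino acid, which falls outside this hunk):

import torch

# Hedged sketch, not netam's exact code: expand per-site selection
# factors of shape (L,) into an (L, 20) matrix over the 20 standard
# amino acids; reserved tokens no longer get columns.
def expand_selection_factors(selection_factors: torch.Tensor) -> torch.Tensor:
    matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
    matrix[:, :] = selection_factors[:, None]  # same factor for every substitution
    return matrix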
5 changes: 3 additions & 2 deletions netam/dxsm.py
@@ -29,6 +29,7 @@
nt_mutation_frequency,
MAX_AA_TOKEN_IDX,
RESERVED_TOKEN_REGEX,
+    AA_AMBIG_IDX,
)


@@ -70,7 +71,7 @@ def __init__(
# to the ambiguous amino acid, and then will fill in the actual values
# below.
self.aa_parents_idxss = torch.full(
-            (pcp_count, self.max_aa_seq_len), MAX_AA_TOKEN_IDX
+            (pcp_count, self.max_aa_seq_len), AA_AMBIG_IDX
)
self.aa_children_idxss = self.aa_parents_idxss.clone()
self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len))
Expand Down Expand Up @@ -303,7 +304,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
-        # The following can be used when one wants a better traceback.
+        # # The following can be used when one wants a better traceback.
# burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
# return burrito.serial_find_optimal_branch_lengths(
# dataset, **optimization_kwargs
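The switch from MAX_AA_TOKEN_IDX to AA_AMBIG_IDX as the fill value is the crux of the reordering: "X" used to sit at the end of the token string, so the maximum index happened to denote ambiguity, but with reserved tokens now appended after "X", the maximum index is a reserved token and the padding must name the ambiguous index (20) explicitly.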
5 changes: 5 additions & 0 deletions netam/framework.py
@@ -321,6 +321,11 @@ def load_crepe(prefix, device=None):
f"Model class '{model_class_name}' not found in 'models' module."
)

+    if issubclass(model_class, models.TransformerBinarySelectionModelLinAct):
+        if "embedding_dim" not in config["model_hyperparameters"]:
+            # Assume the model is from before any new tokens were added, so 21
+            config["model_hyperparameters"]["embedding_dim"] = 21

model = model_class(**config["model_hyperparameters"])

model_state_path = f"{prefix}.pth"
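The effect is that crepes serialized before the embedding dimension became configurable are loaded with the vocabulary size they were trained on: the 20 amino acids plus the ambiguity token "X". A hedged sketch of the idea, with an illustrative helper name:

import copy

# Illustrative sketch, not the exact netam code: configs that predate
# `embedding_dim` come from models whose token alphabet was the 20 amino
# acids plus "X", so a missing value defaults to 21.
def patch_legacy_config(config: dict) -> dict:
    patched = copy.deepcopy(config)
    patched["model_hyperparameters"].setdefault("embedding_dim", 21)
    return patched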
34 changes: 28 additions & 6 deletions netam/models.py
@@ -597,11 +597,29 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(self.device)

+        # Here we're ignoring sites containing tokens that have index greater
+        # than the embedding dimension. If extra tokens have been added since
+        # this model was defined, they are stripped out before feeding the
+        # sequence to the model, and the returned selection factors will be NaN
+        # at sites containing those unrecognized tokens.
+        model_valid_sites = aa_idxs < self.hyperparameters["embedding_dim"]
+        if self.hyperparameters["output_dim"] == 1:
+            result = torch.full((len(aa_str),), float("nan"), device=self.device)
+        else:
+            result = torch.full(
+                (len(aa_str), self.hyperparameters["output_dim"]),
+                float("nan"),
+                device=self.device,
+            )

with torch.no_grad():
-            model_out = self(aa_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)
-            final_out = torch.exp(model_out)
+            model_out = self(
+                aa_idxs[model_valid_sites].unsqueeze(0),
+                mask[model_valid_sites].unsqueeze(0),
+            ).squeeze(0)
+            result[model_valid_sites] = torch.exp(model_out)[: model_valid_sites.sum()]

-        return final_out[: len(aa_str)]
+        return result
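The pattern above, stripped to its essentials: restrict the forward pass to token indices the embedding layer was built for, and report NaN everywhere else. A minimal sketch under illustrative names:

import torch

# Hedged sketch of the masking logic, assuming a 1-D tensor of token
# indices and a callable that scores only recognized sites.
def score_with_unknown_tokens(model_fn, aa_idxs, embedding_dim):
    valid = aa_idxs < embedding_dim           # sites the old model knows
    result = torch.full((len(aa_idxs),), float("nan"))
    result[valid] = model_fn(aa_idxs[valid])  # evaluate recognized sites only
    return result                             # NaN marks unknown tokens

So a model saved with embedding_dim=21 returns NaN at any site carrying a later-added reserved token, rather than indexing out of range.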


class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel):
@@ -613,16 +631,18 @@ def __init__(
layer_count: int,
dropout_prob: float = 0.5,
output_dim: int = 1,
+        embedding_dim: int = MAX_AA_TOKEN_IDX + 1,
):
super().__init__()
# Note that d_model has to be divisible by nhead, so we make that
# automatic here.
self.d_model_per_head = d_model_per_head
self.d_model = d_model_per_head * nhead
self.nhead = nhead
+        self.embedding_dim = embedding_dim
self.dim_feedforward = dim_feedforward
self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
-        self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
+        self.amino_acid_embedding = nn.Embedding(self.embedding_dim, self.d_model)
self.encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=nhead,
@@ -642,6 +662,7 @@ def hyperparameters(self):
"layer_count": self.encoder.num_layers,
"dropout_prob": self.pos_encoder.dropout.p,
"output_dim": self.linear.out_features,
"embedding_dim": self.embedding_dim,
}

def init_weights(self) -> None:
@@ -748,14 +769,15 @@ def predict(self, representation: Tensor):
class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
"""A one parameter selection model as a baseline."""

-    def __init__(self, output_dim: int = 1):
+    def __init__(self, output_dim: int = 1, embedding_dim: int = MAX_AA_TOKEN_IDX + 1):
super().__init__()
self.single_value = nn.Parameter(torch.tensor(0.0))
self.output_dim = output_dim
+        self.embedding_dim = embedding_dim

@property
def hyperparameters(self):
-        return {"output_dim": self.output_dim}
+        return {"output_dim": self.output_dim, "embedding_dim": self.embedding_dim}

def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
"""Build a binary log selection matrix from an index-encoded parent sequence."""
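Recording embedding_dim under hyperparameters matters because load_crepe rebuilds models as model_class(**config["model_hyperparameters"]) (see the framework.py hunk above), so any constructor argument that changes layer shapes must round-trip through that dict. A sketch of the round trip:

from netam.models import SingleValueBinarySelectionModel

# Round-trip sketch: the hyperparameters dict must carry everything
# needed to rebuild a model with identically shaped layers.
model = SingleValueBinarySelectionModel(output_dim=1, embedding_dim=21)
config = {"model_hyperparameters": model.hyperparameters}
restored = SingleValueBinarySelectionModel(**config["model_hyperparameters"])
assert restored.embedding_dim == model.embedding_dim == 21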
4 changes: 2 additions & 2 deletions netam/molevol.py
@@ -9,7 +9,7 @@
import torch
from torch import Tensor, optim

- from netam.sequences import CODON_AA_INDICATOR_MATRIX, MAX_AA_TOKEN_IDX
+ from netam.sequences import CODON_AA_INDICATOR_MATRIX

import netam.sequences as sequences

@@ -444,7 +444,7 @@ def mutsel_log_pcp_probability_of(
"""

assert len(parent) % 3 == 0
-    assert sel_matrix.shape == (len(parent) // 3, MAX_AA_TOKEN_IDX + 1)
+    assert sel_matrix.shape == (len(parent) // 3, 20)

parent_idxs = sequences.nt_idx_tensor_of_str(parent)
child_idxs = sequences.nt_idx_tensor_of_str(child)
25 changes: 13 additions & 12 deletions netam/sequences.py
@@ -17,15 +17,15 @@

NT_STR_SORTED = "".join(BASES)
BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")}
- # ambiguous must remain last. It is assumed elsewhere that the max index
- # denotes the ambiguous base
- AA_TOKEN_STR_SORTED = AA_STR_SORTED + RESERVED_TOKENS + "X"
+ # Must add new tokens to the end of this string.
+ TOKEN_STR_SORTED = AA_STR_SORTED + "X" + RESERVED_TOKENS
+ AA_AMBIG_IDX = len(AA_STR_SORTED)

RESERVED_TOKEN_AA_BOUNDS = (
-    min(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
-    max(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
+    min(TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
+    max(TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
)
- MAX_AA_TOKEN_IDX = len(AA_TOKEN_STR_SORTED) - 1
+ MAX_AA_TOKEN_IDX = len(TOKEN_STR_SORTED) - 1
CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)]
STOP_CODONS = ["TAA", "TAG", "TGA"]
# Each token in RESERVED_TOKENS will appear once in aa strings, and three times
@@ -48,7 +48,7 @@ def nt_idx_array_of_str(nt_str):
def aa_idx_array_of_str(aa_str):
"""Return the indices of the amino acids in a string."""
try:
-        return np.array([AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str])
+        return np.array([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
except ValueError:
print(f"Found an invalid amino acid in the string: {aa_str}")
raise
@@ -66,7 +66,7 @@ def nt_idx_tensor_of_str(nt_str):
def aa_idx_tensor_of_str(aa_str):
"""Return the indices of the amino acids in a string."""
try:
-        return torch.tensor([AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str])
+        return torch.tensor([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
except ValueError:
print(f"Found an invalid amino acid in the string: {aa_str}")
raise
@@ -135,7 +135,7 @@ def translate_sequences(nt_sequences):
def aa_index_of_codon(codon):
"""Return the index of the amino acid encoded by a codon."""
aa = translate_sequence(codon)
-    return AA_TOKEN_STR_SORTED.index(aa)
+    return TOKEN_STR_SORTED.index(aa)


def generic_mutation_frequency(ambig_symb, parent, child):
@@ -185,15 +185,16 @@ def pcp_criteria_check(parent, child, max_mut_freq=0.3):
def generate_codon_aa_indicator_matrix():
"""Generate a matrix that maps codons (rows) to amino acids (columns)."""

-    matrix = np.zeros((len(CODONS), len(AA_TOKEN_STR_SORTED)))
+    matrix = np.zeros((len(CODONS), len(AA_STR_SORTED)))

for i, codon in enumerate(CODONS):
try:
aa = translate_sequences([codon])[0]
-            aa_idx = AA_TOKEN_STR_SORTED.index(aa)
-            matrix[i, aa_idx] = 1
        except ValueError:  # Handle STOP codon
            pass
+        else:
+            aa_idx = AA_STR_SORTED.index(aa)
+            matrix[i, aa_idx] = 1

return matrix

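Taken together: the 20 standard amino acids keep indices 0-19, "X" sits at AA_AMBIG_IDX (20), and reserved tokens follow, so amino-acid-only arrays such as the indicator matrix above (whose stop-codon rows stay all-zero) can use a fixed 20 columns no matter how many tokens are later appended. An illustrative check, assuming the usual sorted one-letter alphabet and "^" as the lone reserved token at the time of this PR:

from netam.sequences import TOKEN_STR_SORTED, AA_AMBIG_IDX, MAX_AA_TOKEN_IDX

assert TOKEN_STR_SORTED.startswith("ACDEFGHIKLMNPQRSTVWY")  # 20 amino acids first
assert TOKEN_STR_SORTED[AA_AMBIG_IDX] == "X"                # ambiguity at index 20
assert MAX_AA_TOKEN_IDX == len(TOKEN_STR_SORTED) - 1        # reserved tokens last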
Binary file added tests/old_models/dasm_13k-v1jaffe+v1tang-joint.pth
17 changes: 17 additions & 0 deletions tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
encoder_class: PlaceholderEncoder
encoder_parameters: {}
model_class: TransformerBinarySelectionModelWiggleAct
model_hyperparameters:
d_model_per_head: 4
dim_feedforward: 64
dropout_prob: 0.1
layer_count: 3
nhead: 4
output_dim: 20
serialization_version: 0
training_hyperparameters:
batch_size: 1024
learning_rate: 0.001
min_learning_rate: 1.0e-06
optimizer_name: RMSprop
weight_decay: 1.0e-06
Binary file added tests/old_models/dasm_output
Binary file added tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.pth
17 changes: 17 additions & 0 deletions tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
encoder_class: PlaceholderEncoder
encoder_parameters: {}
model_class: TransformerBinarySelectionModelWiggleAct
model_hyperparameters:
d_model_per_head: 4
dim_feedforward: 64
dropout_prob: 0.1
layer_count: 3
nhead: 4
output_dim: 1
serialization_version: 0
training_hyperparameters:
batch_size: 1024
learning_rate: 0.001
min_learning_rate: 1.0e-06
optimizer_name: RMSprop
weight_decay: 1.0e-06
Binary file added tests/old_models/dnsm_output
3 changes: 1 addition & 2 deletions tests/test_ambiguous.py
@@ -11,7 +11,6 @@
load_pcp_df,
add_shm_model_outputs_to_pcp_df,
)
- from netam.sequences import MAX_AA_TOKEN_IDX
from netam import pretrained
import random

@@ -170,7 +169,7 @@ def dasm_model():
d_model_per_head=4,
dim_feedforward=256,
layer_count=2,
-        output_dim=MAX_AA_TOKEN_IDX + 1,
+        output_dim=20,
)


28 changes: 28 additions & 0 deletions tests/test_backward_compat.py
@@ -0,0 +1,28 @@
import torch

from netam.framework import load_crepe
from netam.sequences import set_wt_to_nan


# The outputs used for this test are produced by running
# `test_backward_compat_copy.py` on the wd-old-model-runner branch.
# This is to ensure that we can still load older crepes, even if we change the
# dimensions of model layers, as we did with the Embedding layer in
# https://github.com/matsengrp/netam/pull/92.
def test_old_model_outputs():
example_seq = "QVQLVESGGGVVQPGRSLRLSCAASGFTFSSSGMHWVRQAPGKGLEWVAVIWYDGSNKYYADSVKGRFTISRDNSKNTVYLQMNSLRAEDTAVYYCAREGHSNYPYYYYYMDVWGKGTTVTVSS"
dasm_crepe = load_crepe("tests/old_models/dasm_13k-v1jaffe+v1tang-joint")
dnsm_crepe = load_crepe("tests/old_models/dnsm_13k-v1jaffe+v1tang-joint")

dasm_vals = torch.nan_to_num(
set_wt_to_nan(
torch.load("tests/old_models/dasm_output", weights_only=True), example_seq
),
0.0,
)
dnsm_vals = torch.load("tests/old_models/dnsm_output", weights_only=True)

dasm_result = torch.nan_to_num(dasm_crepe([example_seq])[0], 0.0)
dnsm_result = dnsm_crepe([example_seq])[0]
assert torch.allclose(dasm_result, dasm_vals)
assert torch.allclose(dnsm_result, dnsm_vals)
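(The torch.nan_to_num calls keep the comparison meaningful: torch.allclose treats NaN as unequal, and set_wt_to_nan masks the wild-type positions, so entries that may legitimately differ in representation between model versions are zeroed on both sides before the assertion.)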
4 changes: 2 additions & 2 deletions tests/test_common.py
@@ -33,6 +33,6 @@ def test_codon_mask_tensor_of():


def test_aa_strs_from_idx_tensor():
-    aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 21, 21]])
+    aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 20, 20]])
    aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor)
-    assert aa_strings == ["ACDE^", "FGY"]
+    assert aa_strings == ["ACDEX^", "FGY^"]
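The expected strings change because the reordering swaps indices 20 and 21: TOKEN_STR_SORTED[20] is now "X" and TOKEN_STR_SORTED[21] is "^". Since aa_strs_from_idx_tensor strips only trailing "X"s, [0, 1, 2, 3, 20, 21] decodes to "ACDEX^" (that "X" is protected by the "^" after it), and [4, 5, 19, 21, 20, 20] decodes to "FGY^" once the trailing "X"s are removed.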
3 changes: 1 addition & 2 deletions tests/test_dasm.py
@@ -8,7 +8,6 @@
crepe_exists,
load_crepe,
)
- from netam.sequences import MAX_AA_TOKEN_IDX
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dasm import (
DASMBurrito,
@@ -30,7 +29,7 @@ def dasm_burrito(pcp_df):
d_model_per_head=4,
dim_feedforward=256,
layer_count=2,
-        output_dim=MAX_AA_TOKEN_IDX + 1,
+        output_dim=20,
)

burrito = DASMBurrito(