diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml
index d6d41059..0006aac6 100644
--- a/.github/workflows/build-and-test.yml
+++ b/.github/workflows/build-and-test.yml
@@ -34,6 +34,12 @@ jobs:
           cache-environment: false
           post-cleanup: 'none'

+      - name: Check for TODOs
+        shell: bash -l {0}
+        run: |
+          cd main
+          make checktodo
+
       - name: Check format
         shell: bash -l {0}
         run: |
diff --git a/Makefile b/Makefile
index 5ccd44c7..c4f16f5e 100644
--- a/Makefile
+++ b/Makefile
@@ -15,6 +15,9 @@ checkformat:
 	docformatter --check --black --recursive netam tests
 	black --check netam tests

+checktodo:
+	grep -rq --include=\*.{py,Snakemake} "TODO" . && echo "TODOs found" && exit 1 || echo "No TODOs found" && exit 0
+
 lint:
 	flake8 . --max-complexity=30 --ignore=E731,W503,E402,F541,E501,E203,E266 --statistics --exclude=_ignore

@@ -27,4 +30,4 @@ notebooks:
 		jupyter nbconvert --to notebook --execute "$$nb" --output _ignore/"$$(basename $$nb)"; \
 	done

-.PHONY: install test notebooks format lint docs
+.PHONY: install test notebooks format lint docs checktodo
diff --git a/netam/common.py b/netam/common.py
index 4c4fb40b..3bf34169 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -18,7 +18,7 @@
     apply_aa_mask_to_nt_sequence,
     RESERVED_TOKEN_TRANSLATIONS,
     BASES,
-    AA_TOKEN_STR_SORTED,
+    TOKEN_STR_SORTED,
 )

 BIG = 1e9
@@ -66,7 +66,7 @@ def aa_idx_tensor_of_str_ambig(aa_str):
     character."""
     try:
         return torch.tensor(
-            [AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
+            [TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
         )
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
@@ -126,7 +126,7 @@ def aa_strs_from_idx_tensor(idx_tensor):

     Args:
         idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
-            indices into AA_TOKEN_STR_SORTED.
+            indices into TOKEN_STR_SORTED.

     Returns:
         List[str]: A list of amino acid strings with trailing 'X's removed.
@@ -135,7 +135,7 @@ def aa_strs_from_idx_tensor(idx_tensor):
     """
     aa_str_list = []
     for row in idx_tensor:
-        aa_str = "".join(AA_TOKEN_STR_SORTED[idx] for idx in row.tolist())
+        aa_str = "".join(TOKEN_STR_SORTED[idx] for idx in row.tolist())
         aa_str_list.append(aa_str.rstrip("X"))
     return aa_str_list
diff --git a/netam/dnsm.py b/netam/dnsm.py
index bd41e0af..6efeda85 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -163,9 +163,7 @@ def build_selection_matrix_from_parent(self, parent: str):
         """
         parent = sequences.translate_sequence(parent)
         selection_factors = self.model.selection_factors_of_aa_str(parent)
-        selection_matrix = torch.zeros(
-            (len(selection_factors), sequences.MAX_AA_TOKEN_IDX + 1), dtype=torch.float
-        )
+        selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
         # Every "off-diagonal" entry of the selection matrix is set to the selection
         # factor, where "diagonal" means keeping the same amino acid.
         selection_matrix[:, :] = selection_factors[:, None]
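A note on the dnsm.py hunk above: the selection matrix now has exactly 20 columns, one per amino acid, rather than one per token. The sketch below illustrates the fill pattern the in-diff comment describes, with made-up values; resetting the "diagonal" (the parent's own amino acid) to 1.0 is an assumption based on that comment, since the code that does it lies outside this hunk.

```python
import torch

# Toy illustration of the fill pattern (values are made up). Each row is a
# site and each column one of the 20 amino acids. Every entry in a row gets
# that site's selection factor; the parent's own amino acid (the "diagonal")
# is presumably reset to 1.0, an assumption since that code is not in the hunk.
selection_factors = torch.tensor([0.5, 2.0])  # one factor per site
parent_aa_idxs = torch.tensor([3, 7])         # parent amino acid index per site
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
selection_matrix[:, :] = selection_factors[:, None]
selection_matrix[torch.arange(len(parent_aa_idxs)), parent_aa_idxs] = 1.0
```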
diff --git a/netam/dxsm.py b/netam/dxsm.py
index 8539174e..80b6e7cd 100644
--- a/netam/dxsm.py
+++ b/netam/dxsm.py
@@ -29,6 +29,7 @@
     nt_mutation_frequency,
     MAX_AA_TOKEN_IDX,
     RESERVED_TOKEN_REGEX,
+    AA_AMBIG_IDX,
 )


@@ -70,7 +71,7 @@ def __init__(
         # to the ambiguous amino acid, and then will fill in the actual values
         # below.
         self.aa_parents_idxss = torch.full(
-            (pcp_count, self.max_aa_seq_len), MAX_AA_TOKEN_IDX
+            (pcp_count, self.max_aa_seq_len), AA_AMBIG_IDX
         )
         self.aa_children_idxss = self.aa_parents_idxss.clone()
         self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len))
@@ -303,7 +304,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
-        # The following can be used when one wants a better traceback.
+        # # The following can be used when one wants a better traceback.
         # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
         # return burrito.serial_find_optimal_branch_lengths(
         #     dataset, **optimization_kwargs
diff --git a/netam/framework.py b/netam/framework.py
index 93962018..b4db2d7e 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -321,6 +321,11 @@ def load_crepe(prefix, device=None):
             f"Model class '{model_class_name}' not found in 'models' module."
         )

+    if issubclass(model_class, models.TransformerBinarySelectionModelLinAct):
+        if "embedding_dim" not in config["model_hyperparameters"]:
+            # Assume the model is from before any new tokens were added, so 21
+            config["model_hyperparameters"]["embedding_dim"] = 21
+
     model = model_class(**config["model_hyperparameters"])

     model_state_path = f"{prefix}.pth"
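The load_crepe hunk above back-fills a default embedding_dim so crepes serialized before this PR still load. A minimal sketch of the same defaulting pattern on a bare dict follows; the function name is illustrative, not netam's API, and the value 21 matches the diff's stated assumption (20 amino acids plus "X").

```python
# Minimal sketch of the backward-compat defaulting above, on a bare dict.
# fill_legacy_defaults is a hypothetical name; 21 = 20 amino acids + "X".
def fill_legacy_defaults(model_hyperparameters: dict) -> dict:
    # Configs written before embedding_dim existed get the old 21-token size.
    model_hyperparameters.setdefault("embedding_dim", 21)
    return model_hyperparameters

legacy = {"output_dim": 1, "nhead": 4}
assert fill_legacy_defaults(legacy)["embedding_dim"] == 21
```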
diff --git a/netam/models.py b/netam/models.py
index 1edc8989..abd5fbca 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -597,11 +597,29 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
         mask = aa_mask_tensor_of(aa_str)
         mask = mask.to(self.device)

+        # Here we're ignoring sites containing tokens that have index greater
+        # than the embedding dimension. If extra tokens have been added since
+        # this model was defined, they are stripped out before feeding the
+        # sequence to the model, and the returned selection factors will be NaN
+        # at sites containing those unrecognized tokens.
+        model_valid_sites = aa_idxs < self.hyperparameters["embedding_dim"]
+        if self.hyperparameters["output_dim"] == 1:
+            result = torch.full((len(aa_str),), float("nan"), device=self.device)
+        else:
+            result = torch.full(
+                (len(aa_str), self.hyperparameters["output_dim"]),
+                float("nan"),
+                device=self.device,
+            )
+
         with torch.no_grad():
-            model_out = self(aa_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)
-            final_out = torch.exp(model_out)
+            model_out = self(
+                aa_idxs[model_valid_sites].unsqueeze(0),
+                mask[model_valid_sites].unsqueeze(0),
+            ).squeeze(0)
+            result[model_valid_sites] = torch.exp(model_out)[: model_valid_sites.sum()]

-        return final_out[: len(aa_str)]
+        return result


 class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel):
@@ -613,6 +631,7 @@ def __init__(
         layer_count: int,
         dropout_prob: float = 0.5,
         output_dim: int = 1,
+        embedding_dim: int = MAX_AA_TOKEN_IDX + 1,
     ):
         super().__init__()
         # Note that d_model has to be divisible by nhead, so we make that
@@ -620,9 +639,10 @@ def __init__(
         self.d_model_per_head = d_model_per_head
         self.d_model = d_model_per_head * nhead
         self.nhead = nhead
+        self.embedding_dim = embedding_dim
         self.dim_feedforward = dim_feedforward
         self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
-        self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
+        self.amino_acid_embedding = nn.Embedding(self.embedding_dim, self.d_model)
         self.encoder_layer = nn.TransformerEncoderLayer(
             d_model=self.d_model,
             nhead=nhead,
@@ -642,6 +662,7 @@ def hyperparameters(self):
             "layer_count": self.encoder.num_layers,
             "dropout_prob": self.pos_encoder.dropout.p,
             "output_dim": self.linear.out_features,
+            "embedding_dim": self.embedding_dim,
         }

     def init_weights(self) -> None:
@@ -748,14 +769,15 @@ def predict(self, representation: Tensor):
 class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
     """A one parameter selection model as a baseline."""

-    def __init__(self, output_dim: int = 1):
+    def __init__(self, output_dim: int = 1, embedding_dim: int = MAX_AA_TOKEN_IDX + 1):
         super().__init__()
         self.single_value = nn.Parameter(torch.tensor(0.0))
         self.output_dim = output_dim
+        self.embedding_dim = embedding_dim

     @property
     def hyperparameters(self):
-        return {"output_dim": self.output_dim}
+        return {"output_dim": self.output_dim, "embedding_dim": self.embedding_dim}

     def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
         """Build a binary log selection matrix from an index-encoded parent sequence."""
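The selection_factors_of_aa_str rewrite in the models.py hunk above drops sites whose token index exceeds the model's embedding size and reports NaN at those positions. Here is a self-contained toy version of that mask-and-scatter pattern; a stand-in multiplication replaces the real model.

```python
import torch

# Toy mask-and-scatter: sites with token index >= embedding_dim are removed
# before the "model" runs, and come back as NaN in the result.
embedding_dim = 21
aa_idxs = torch.tensor([0, 5, 23, 12])  # 23 is a token this "model" never saw
valid = aa_idxs < embedding_dim

result = torch.full((len(aa_idxs),), float("nan"))
result[valid] = aa_idxs[valid].float() * 0.1  # stand-in for the real model
# result is now tensor([0.0, 0.5, nan, 1.2]): NaN exactly at the unseen token.
```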
diff --git a/netam/molevol.py b/netam/molevol.py
index c089764d..2aef1c10 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -9,7 +9,7 @@
 import torch
 from torch import Tensor, optim

-from netam.sequences import CODON_AA_INDICATOR_MATRIX, MAX_AA_TOKEN_IDX
+from netam.sequences import CODON_AA_INDICATOR_MATRIX

 import netam.sequences as sequences

@@ -444,7 +444,7 @@ def mutsel_log_pcp_probability_of(
     """
     assert len(parent) % 3 == 0
-    assert sel_matrix.shape == (len(parent) // 3, MAX_AA_TOKEN_IDX + 1)
+    assert sel_matrix.shape == (len(parent) // 3, 20)

     parent_idxs = sequences.nt_idx_tensor_of_str(parent)
     child_idxs = sequences.nt_idx_tensor_of_str(child)
diff --git a/netam/sequences.py b/netam/sequences.py
index f9800ff7..3ddf6d88 100644
--- a/netam/sequences.py
+++ b/netam/sequences.py
@@ -17,15 +17,15 @@
 NT_STR_SORTED = "".join(BASES)
 BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")}

-# ambiguous must remain last. It is assumed elsewhere that the max index
-# denotes the ambiguous base
-AA_TOKEN_STR_SORTED = AA_STR_SORTED + RESERVED_TOKENS + "X"
+# Must add new tokens to the end of this string.
+TOKEN_STR_SORTED = AA_STR_SORTED + "X" + RESERVED_TOKENS
+AA_AMBIG_IDX = len(AA_STR_SORTED)
 RESERVED_TOKEN_AA_BOUNDS = (
-    min(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
-    max(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
+    min(TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
+    max(TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
 )
-MAX_AA_TOKEN_IDX = len(AA_TOKEN_STR_SORTED) - 1
+MAX_AA_TOKEN_IDX = len(TOKEN_STR_SORTED) - 1
 CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)]
 STOP_CODONS = ["TAA", "TAG", "TGA"]
 # Each token in RESERVED_TOKENS will appear once in aa strings, and three times
@@ -48,7 +48,7 @@ def nt_idx_array_of_str(nt_str):
 def aa_idx_array_of_str(aa_str):
     """Return the indices of the amino acids in a string."""
     try:
-        return np.array([AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str])
+        return np.array([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
         raise
@@ -66,7 +66,7 @@ def nt_idx_tensor_of_str(nt_str):
 def aa_idx_tensor_of_str(aa_str):
     """Return the indices of the amino acids in a string."""
     try:
-        return torch.tensor([AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str])
+        return torch.tensor([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
         raise
@@ -135,7 +135,7 @@ def translate_sequences(nt_sequences):
 def aa_index_of_codon(codon):
     """Return the index of the amino acid encoded by a codon."""
     aa = translate_sequence(codon)
-    return AA_TOKEN_STR_SORTED.index(aa)
+    return TOKEN_STR_SORTED.index(aa)


 def generic_mutation_frequency(ambig_symb, parent, child):
@@ -185,15 +185,16 @@ def pcp_criteria_check(parent, child, max_mut_freq=0.3):

 def generate_codon_aa_indicator_matrix():
     """Generate a matrix that maps codons (rows) to amino acids (columns)."""
-    matrix = np.zeros((len(CODONS), len(AA_TOKEN_STR_SORTED)))
+    matrix = np.zeros((len(CODONS), len(AA_STR_SORTED)))
     for i, codon in enumerate(CODONS):
         try:
             aa = translate_sequences([codon])[0]
-            aa_idx = AA_TOKEN_STR_SORTED.index(aa)
-            matrix[i, aa_idx] = 1
         except ValueError:  # Handle STOP codon
             pass
+        else:
+            aa_idx = AA_STR_SORTED.index(aa)
+            matrix[i, aa_idx] = 1
     return matrix
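To make the sequences.py reordering concrete, here is a small self-contained sketch of the intended invariants. It assumes the standard sorted 20-letter amino acid alphabet and, purely for illustration, "^" as the only reserved token; neither assumption is a claim about netam's actual constants.

```python
# Invariants of the new token layout (illustrative; "^" as the sole reserved
# token is an assumption of this sketch).
AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"
RESERVED_TOKENS = "^"
TOKEN_STR_SORTED = AA_STR_SORTED + "X" + RESERVED_TOKENS

AA_AMBIG_IDX = len(AA_STR_SORTED)
assert AA_AMBIG_IDX == 20                      # "X" sits right after the amino acids
assert TOKEN_STR_SORTED[AA_AMBIG_IDX] == "X"
assert TOKEN_STR_SORTED[:20] == AA_STR_SORTED  # plain aa indices are unchanged

# Appending tokens later moves MAX_AA_TOKEN_IDX but never AA_AMBIG_IDX, which
# is why the ambiguity fill value in dxsm.py above switched to AA_AMBIG_IDX.
MAX_AA_TOKEN_IDX = len(TOKEN_STR_SORTED) - 1
assert MAX_AA_TOKEN_IDX == 21
```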
diff --git a/tests/old_models/dasm_13k-v1jaffe+v1tang-joint.pth b/tests/old_models/dasm_13k-v1jaffe+v1tang-joint.pth
new file mode 100644
index 00000000..fc251deb
Binary files /dev/null and b/tests/old_models/dasm_13k-v1jaffe+v1tang-joint.pth differ
diff --git a/tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml b/tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml
new file mode 100644
index 00000000..ed42859f
--- /dev/null
+++ b/tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
+encoder_class: PlaceholderEncoder
+encoder_parameters: {}
+model_class: TransformerBinarySelectionModelWiggleAct
+model_hyperparameters:
+  d_model_per_head: 4
+  dim_feedforward: 64
+  dropout_prob: 0.1
+  layer_count: 3
+  nhead: 4
+  output_dim: 20
+serialization_version: 0
+training_hyperparameters:
+  batch_size: 1024
+  learning_rate: 0.001
+  min_learning_rate: 1.0e-06
+  optimizer_name: RMSprop
+  weight_decay: 1.0e-06
diff --git a/tests/old_models/dasm_output b/tests/old_models/dasm_output
new file mode 100644
index 00000000..0db18502
Binary files /dev/null and b/tests/old_models/dasm_output differ
diff --git a/tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.pth b/tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.pth
new file mode 100644
index 00000000..4608c96a
Binary files /dev/null and b/tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.pth differ
diff --git a/tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.yml b/tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.yml
new file mode 100644
index 00000000..cf84623f
--- /dev/null
+++ b/tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
+encoder_class: PlaceholderEncoder
+encoder_parameters: {}
+model_class: TransformerBinarySelectionModelWiggleAct
+model_hyperparameters:
+  d_model_per_head: 4
+  dim_feedforward: 64
+  dropout_prob: 0.1
+  layer_count: 3
+  nhead: 4
+  output_dim: 1
+serialization_version: 0
+training_hyperparameters:
+  batch_size: 1024
+  learning_rate: 0.001
+  min_learning_rate: 1.0e-06
+  optimizer_name: RMSprop
+  weight_decay: 1.0e-06
diff --git a/tests/old_models/dnsm_output b/tests/old_models/dnsm_output
new file mode 100644
index 00000000..00b09ec0
Binary files /dev/null and b/tests/old_models/dnsm_output differ
diff --git a/tests/test_ambiguous.py b/tests/test_ambiguous.py
index c62fbd13..61464ef1 100644
--- a/tests/test_ambiguous.py
+++ b/tests/test_ambiguous.py
@@ -11,7 +11,6 @@
     load_pcp_df,
     add_shm_model_outputs_to_pcp_df,
 )
-from netam.sequences import MAX_AA_TOKEN_IDX
 from netam import pretrained
 import random

@@ -170,7 +169,7 @@ def dasm_model():
         d_model_per_head=4,
         dim_feedforward=256,
         layer_count=2,
-        output_dim=MAX_AA_TOKEN_IDX + 1,
+        output_dim=20,
     )
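The new test file below pins the outputs of crepes serialized before the embedding change against the binary reference tensors added above. As a hypothetical sketch, reference tensors like these could be regenerated on the old branch roughly as follows; the real source is the `test_backward_compat_copy.py` script named in the test's comment, so the exact shape of this snippet is an assumption.

```python
# Hypothetical sketch (not the actual test_backward_compat_copy.py) of how a
# reference tensor like tests/old_models/dasm_output could be produced on the
# old branch, before the embedding dimensions changed.
import torch
from netam.framework import load_crepe

example_seq = "QVQLVESG"  # stand-in; the test below uses a full heavy chain
crepe = load_crepe("tests/old_models/dasm_13k-v1jaffe+v1tang-joint")
torch.save(crepe([example_seq])[0], "tests/old_models/dasm_output")
```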
diff --git a/tests/test_backward_compat.py b/tests/test_backward_compat.py
new file mode 100644
index 00000000..8b756876
--- /dev/null
+++ b/tests/test_backward_compat.py
@@ -0,0 +1,28 @@
+import torch
+
+from netam.framework import load_crepe
+from netam.sequences import set_wt_to_nan
+
+
+# The outputs used for this test are produced by running
+# `test_backward_compat_copy.py` on the wd-old-model-runner branch.
+# This is to ensure that we can still load older crepes, even if we change the
+# dimensions of model layers, as we did with the Embedding layer in
+# https://github.com/matsengrp/netam/pull/92.
+def test_old_model_outputs():
+    example_seq = "QVQLVESGGGVVQPGRSLRLSCAASGFTFSSSGMHWVRQAPGKGLEWVAVIWYDGSNKYYADSVKGRFTISRDNSKNTVYLQMNSLRAEDTAVYYCAREGHSNYPYYYYYMDVWGKGTTVTVSS"
+    dasm_crepe = load_crepe("tests/old_models/dasm_13k-v1jaffe+v1tang-joint")
+    dnsm_crepe = load_crepe("tests/old_models/dnsm_13k-v1jaffe+v1tang-joint")
+
+    dasm_vals = torch.nan_to_num(
+        set_wt_to_nan(
+            torch.load("tests/old_models/dasm_output", weights_only=True), example_seq
+        ),
+        0.0,
+    )
+    dnsm_vals = torch.load("tests/old_models/dnsm_output", weights_only=True)
+
+    dasm_result = torch.nan_to_num(dasm_crepe([example_seq])[0], 0.0)
+    dnsm_result = dnsm_crepe([example_seq])[0]
+    assert torch.allclose(dasm_result, dasm_vals)
+    assert torch.allclose(dnsm_result, dnsm_vals)
diff --git a/tests/test_common.py b/tests/test_common.py
index 14787d5c..22b5fc67 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -33,6 +33,6 @@ def test_codon_mask_tensor_of():


 def test_aa_strs_from_idx_tensor():
-    aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 21, 21]])
+    aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 20, 20]])
     aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor)
-    assert aa_strings == ["ACDE^", "FGY"]
+    assert aa_strings == ["ACDEX^", "FGY^"]
diff --git a/tests/test_dasm.py b/tests/test_dasm.py
index 2c749ed8..6bae92ee 100644
--- a/tests/test_dasm.py
+++ b/tests/test_dasm.py
@@ -8,7 +8,6 @@
     crepe_exists,
     load_crepe,
 )
-from netam.sequences import MAX_AA_TOKEN_IDX
 from netam.models import TransformerBinarySelectionModelWiggleAct
 from netam.dasm import (
     DASMBurrito,
@@ -30,7 +29,7 @@ def dasm_burrito(pcp_df):
         d_model_per_head=4,
         dim_feedforward=256,
         layer_count=2,
-        output_dim=MAX_AA_TOKEN_IDX + 1,
+        output_dim=20,
     )

     burrito = DASMBurrito(
diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py
index 18b449e7..e5ae9fe8 100644
--- a/tests/test_dnsm.py
+++ b/tests/test_dnsm.py
@@ -7,15 +7,15 @@
     crepe_exists,
     load_crepe,
 )
-from netam.sequences import MAX_AA_TOKEN_IDX
 from netam.common import aa_idx_tensor_of_str_ambig, force_spawn
 from netam.models import TransformerBinarySelectionModelWiggleAct
 from netam.dnsm import DNSMBurrito, DNSMDataset
+from netam.sequences import AA_AMBIG_IDX


 def test_aa_idx_tensor_of_str_ambig():
     input_seq = "ACX"
-    expected_output = torch.tensor([0, 1, MAX_AA_TOKEN_IDX], dtype=torch.int)
+    expected_output = torch.tensor([0, 1, AA_AMBIG_IDX], dtype=torch.int)
     output = aa_idx_tensor_of_str_ambig(input_seq)
     assert torch.equal(output, expected_output)
diff --git a/tests/test_molevol.py b/tests/test_molevol.py
index 3d4ab630..0b313fa1 100644
--- a/tests/test_molevol.py
+++ b/tests/test_molevol.py
@@ -7,7 +7,7 @@
 from netam.sequences import (
     nt_idx_tensor_of_str,
     translate_sequence,
-    AA_TOKEN_STR_SORTED,
+    AA_STR_SORTED,
     CODONS,
     NT_STR_SORTED,
 )
@@ -114,7 +114,7 @@ def test_check_csps():
 def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, csps):
     """Original version of codon_to_aa_probabilities, used for testing."""
     aa_probs = {}
-    for aa in AA_TOKEN_STR_SORTED:
+    for aa in AA_STR_SORTED:
         aa_probs[aa] = 0.0

     # iterate through all possible child codons
@@ -139,7 +139,7 @@ def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, csps):
     # since probabilities to STOP codon are dropped
     psum = sum(aa_probs.values())

-    return torch.tensor([aa_probs[aa] / psum for aa in AA_TOKEN_STR_SORTED])
+    return torch.tensor([aa_probs[aa] / psum for aa in AA_STR_SORTED])


 def test_aaprob_of_mut_and_sub():
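One note on the test_molevol.py change above: iterating over AA_STR_SORTED (20 letters) works because probability mass that would flow to stop codons is dropped and the remainder is renormalized. A toy version of that renormalization, with made-up numbers:

```python
import torch

# Made-up numbers: 0.2 of the mass went to STOP codons and is dropped; the
# surviving amino-acid mass is rescaled to sum to one, exactly like the psum
# division in iterative_aaprob_of_mut_and_sub above.
aa_probs = {"K": 0.5, "N": 0.3}
psum = sum(aa_probs.values())
normalized = torch.tensor([aa_probs[aa] / psum for aa in sorted(aa_probs)])
assert torch.isclose(normalized.sum(), torch.tensor(1.0))
```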
diff --git a/tests/test_sequences.py b/tests/test_sequences.py
index 761e92ef..f1a0c8d3 100644
--- a/tests/test_sequences.py
+++ b/tests/test_sequences.py
@@ -8,7 +8,7 @@
     RESERVED_TOKENS,
     AA_STR_SORTED,
     RESERVED_TOKEN_REGEX,
-    AA_TOKEN_STR_SORTED,
+    TOKEN_STR_SORTED,
     CODONS,
     CODON_AA_INDICATOR_MATRIX,
     aa_onehot_tensor_of_str,
@@ -23,11 +23,11 @@ def test_token_order():
     # If we always add additional tokens to the end, then converting to indices
     # will not be affected when we have a proper aa string.
-    assert AA_TOKEN_STR_SORTED[: len(AA_STR_SORTED)] == AA_STR_SORTED
+    assert TOKEN_STR_SORTED[: len(AA_STR_SORTED)] == AA_STR_SORTED


 def test_token_replace():
-    df = pd.DataFrame({"seq": ["AGCGTC" + token for token in AA_TOKEN_STR_SORTED]})
+    df = pd.DataFrame({"seq": ["AGCGTC" + token for token in TOKEN_STR_SORTED]})
     newseqs = df["seq"].str.replace(RESERVED_TOKEN_REGEX, "N", regex=True)
     for seq, nseq in zip(df["seq"], newseqs):
         for token in RESERVED_TOKENS: