Restore compatibility with models handling fewer tokens #104

Merged (19 commits) on Jan 14, 2025
6 changes: 6 additions & 0 deletions .github/workflows/build-and-test.yml
@@ -34,6 +34,12 @@ jobs:
cache-environment: false
post-cleanup: 'none'

- name: Check for TODOs
shell: bash -l {0}
run: |
cd main
make checktodo

- name: Check format
shell: bash -l {0}
run: |
5 changes: 4 additions & 1 deletion Makefile
@@ -15,6 +15,9 @@ checkformat:
docformatter --check --black --recursive netam tests
black --check netam tests

checktodo:
grep -rq --include=\*.{py,Snakemake} "TODO" . && echo "TODOs found" && exit 1 || echo "No TODOs found" && exit 0

lint:
flake8 . --max-complexity=30 --ignore=E731,W503,E402,F541,E501,E203,E266 --statistics --exclude=_ignore

@@ -27,4 +30,4 @@ notebooks:
jupyter nbconvert --to notebook --execute "$$nb" --output _ignore/"$$(basename $$nb)"; \
done

.PHONY: install test notebooks format lint docs
.PHONY: install test notebooks format lint docs checktodo
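The `checktodo` rule works because `grep -rq` exits 0 when a match is found, which the recipe inverts into a build failure. A standalone sketch of the same logic (the `check_todos` helper name is hypothetical, and a single `--include` glob is used for simplicity instead of the recipe's brace pattern):

```shell
# Sketch of the checktodo logic: grep -q exits 0 on a match, so a match
# means the check fails.
check_todos() {
  if grep -rq --include='*.py' "TODO" "$1"; then
    echo "TODOs found"
    return 1
  else
    echo "No TODOs found"
    return 0
  fi
}
```

Note that the recipe's `{py,Snakemake}` brace pattern relies on the shell expanding braces, which POSIX `sh` does not do; the sketch sidesteps that by naming one pattern explicitly.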
2 changes: 1 addition & 1 deletion netam/dxsm.py
@@ -303,7 +303,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
# The following can be used when one wants a better traceback.
# # The following can be used when one wants a better traceback.
# burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
# return burrito.serial_find_optimal_branch_lengths(
# dataset, **optimization_kwargs
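The worker-count heuristic shown in this hunk — half the available CPUs, capped at 10 — can be checked in isolation. This is a sketch of the expression only, not the surrounding class:

```python
import multiprocessing as mp

# Half the CPUs, capped at 10, so parallel branch-length optimization
# doesn't oversubscribe the machine.
worker_count = min(mp.cpu_count() // 2, 10)
```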
5 changes: 5 additions & 0 deletions netam/framework.py
@@ -321,6 +321,11 @@ def load_crepe(prefix, device=None):
f"Model class '{model_class_name}' not found in 'models' module."
)

if issubclass(model_class, models.TransformerBinarySelectionModelLinAct):
if "embedding_dim" not in config["model_hyperparameters"]:
# Assume the model is from before any new tokens were added, so 21
config["model_hyperparameters"]["embedding_dim"] = 21

model = model_class(**config["model_hyperparameters"])

model_state_path = f"{prefix}.pth"
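The backward-compatibility pattern here — defaulting a hyperparameter that older serialized configs lack — can be sketched as a free function. `backfill_hyperparameters` and `OLD_EMBEDDING_DIM` are illustrative names, not the library's API:

```python
import copy

# 20 amino acids plus the ambiguous "X": the embedding size before any
# extra tokens were added (assumption matching the comment in load_crepe).
OLD_EMBEDDING_DIM = 21

def backfill_hyperparameters(config: dict) -> dict:
    """Return a copy of config with legacy defaults filled in."""
    config = copy.deepcopy(config)
    hyper = config.setdefault("model_hyperparameters", {})
    # Only fill the default when the key is genuinely missing.
    hyper.setdefault("embedding_dim", OLD_EMBEDDING_DIM)
    return config
```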
28 changes: 24 additions & 4 deletions netam/models.py
@@ -597,11 +597,28 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(self.device)

# Here we're ignoring sites containing tokens that have index greater
# than the embedding dimension.
if "embedding_dim" in self.hyperparameters:
consider_sites = aa_idxs < self.hyperparameters["embedding_dim"]
Contributor:

I'm not sure how to read consider_sites. I think these are meaningful_sites? or aa_sites?

I guess I'm a little confused about what this section of code is doing. It seems like we're making a big empty tensor, partially filling it, then trimming it back again.

Contributor Author:

I renamed to model_valid_sites.

This code applies the model to an amino acid string, considering only the entries that the model was trained to handle (marked by model_valid_sites). We construct a result tensor filled with NaNs, feed the model the string with invalid-token sites stripped out, and then write the model's outputs into the result tensor at the sites that were actually fed to the model.

For instance, if we have a heavy chain sequence QQQQ^, then the model will see QQQQ, but the output will have first dimension size 5, matching the input sequence length, and the last output will be nans.

If we add a start token in the future, then we can still feed sequences containing it to our model. For instance, if we have 1QQQQ, the old model not handling the new 1 token will see QQQQ, but the returned model output will contain nans in the first site, and will have first dimension size 5, matching the input sequence length.
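The NaN-padding behavior described above can be sketched as follows. This is a simplified stand-in for the real method: `apply_with_valid_sites` is a hypothetical name, and `model_fn` abstracts the forward pass:

```python
import torch

def apply_with_valid_sites(model_fn, aa_idxs, embedding_dim):
    # Sites whose token indices the model was trained on.
    model_valid_sites = aa_idxs < embedding_dim
    # Result matches the input length; skipped sites stay NaN.
    result = torch.full((len(aa_idxs),), float("nan"))
    result[model_valid_sites] = model_fn(aa_idxs[model_valid_sites])
    return result
```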

Contributor:

Great!

Let's make a free function that encapsulates the test

if "embedding_dim" in self.hyperparameters:

with a nice name that we can look for when we strip this if out. I'm anticipating that we'll have a release in which we assume that all models have an embedding_dim.

(In other news, we could just go in and add an embedding dim into all the old models and assume they have them, right?)

Contributor Author:

Do you mean a function that takes hyperparameters and aa_idxs and returns model_valid_sites?

The only subclass that doesn't have embedding_dim right now is the Single model. I could just add an embedding dimension for that model that's pinned to MAX_AA_TOKEN_IDX so that it always grows with new tokens. Then I'd be able to remove this test.

Would that be preferable to making a free function?

Contributor Author:

I addressed this by adding an embedding_dim to the Single model.

else:
consider_sites = torch.ones_like(aa_idxs, dtype=torch.bool)
if self.hyperparameters["output_dim"] == 1:
result = torch.full((len(aa_str),), float("nan"), device=self.device)
else:
result = torch.full(
(len(aa_str), self.hyperparameters["output_dim"]),
float("nan"),
device=self.device,
)

with torch.no_grad():
model_out = self(aa_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)
final_out = torch.exp(model_out)
model_out = self(
aa_idxs[consider_sites].unsqueeze(0), mask[consider_sites].unsqueeze(0)
).squeeze(0)
result[consider_sites] = torch.exp(model_out)[: consider_sites.sum()]

return final_out[: len(aa_str)]
return result


class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel):
@@ -613,16 +630,18 @@ def __init__(
layer_count: int,
dropout_prob: float = 0.5,
output_dim: int = 1,
embedding_dim: int = MAX_AA_TOKEN_IDX + 1,
):
super().__init__()
# Note that d_model has to be divisible by nhead, so we make that
# automatic here.
self.d_model_per_head = d_model_per_head
self.d_model = d_model_per_head * nhead
self.nhead = nhead
self.embedding_dim = embedding_dim
self.dim_feedforward = dim_feedforward
self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
self.amino_acid_embedding = nn.Embedding(self.embedding_dim, self.d_model)
self.encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=nhead,
Expand All @@ -642,6 +661,7 @@ def hyperparameters(self):
"layer_count": self.encoder.num_layers,
"dropout_prob": self.pos_encoder.dropout.p,
"output_dim": self.linear.out_features,
"embedding_dim": self.embedding_dim,
}

def init_weights(self) -> None:
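The `embedding_dim` hyperparameter matters because `nn.Embedding` rejects any token index at or above its `num_embeddings`. A minimal illustration of the constraint (sizes here are arbitrary examples):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(21, 8)         # 21 tokens, model dimension 8
out = emb(torch.tensor([0, 20]))  # highest valid index is 20
# emb(torch.tensor([21])) would raise IndexError: index out of range
```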
6 changes: 3 additions & 3 deletions netam/sequences.py
@@ -17,9 +17,9 @@

NT_STR_SORTED = "".join(BASES)
BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")}
# ambiguous must remain last. It is assumed elsewhere that the max index
# denotes the ambiguous base
AA_TOKEN_STR_SORTED = AA_STR_SORTED + RESERVED_TOKENS + "X"
# Must add new tokens to the end of this string.
AA_TOKEN_STR_SORTED = AA_STR_SORTED + "X" + RESERVED_TOKENS
AA_AMBIG_IDX = len(AA_STR_SORTED)

RESERVED_TOKEN_AA_BOUNDS = (
min(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
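The reordering pins the ambiguous `X` directly after the amino acids so its index stays stable while new reserved tokens append at the end. An illustrative reconstruction — the constant values are assumptions, since the library defines its own `AA_STR_SORTED` and `RESERVED_TOKENS`:

```python
AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"  # the 20 amino acids (assumed)
RESERVED_TOKENS = "^"                   # e.g. a separator token (assumed)

# "X" sits at a fixed index right after the amino acids; any future
# tokens extend the tail of the string instead of shifting "X".
AA_TOKEN_STR_SORTED = AA_STR_SORTED + "X" + RESERVED_TOKENS
AA_AMBIG_IDX = len(AA_STR_SORTED)
```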
17 changes: 17 additions & 0 deletions tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
encoder_class: PlaceholderEncoder
encoder_parameters: {}
model_class: TransformerBinarySelectionModelWiggleAct
model_hyperparameters:
d_model_per_head: 4
dim_feedforward: 64
dropout_prob: 0.1
layer_count: 3
nhead: 4
output_dim: 20
serialization_version: 0
training_hyperparameters:
batch_size: 1024
learning_rate: 0.001
min_learning_rate: 1.0e-06
optimizer_name: RMSprop
weight_decay: 1.0e-06
Binary file added tests/old_models/dasm_output
17 changes: 17 additions & 0 deletions tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
encoder_class: PlaceholderEncoder
encoder_parameters: {}
model_class: TransformerBinarySelectionModelWiggleAct
model_hyperparameters:
d_model_per_head: 4
dim_feedforward: 64
dropout_prob: 0.1
layer_count: 3
nhead: 4
output_dim: 1
serialization_version: 0
training_hyperparameters:
batch_size: 1024
learning_rate: 0.001
min_learning_rate: 1.0e-06
optimizer_name: RMSprop
weight_decay: 1.0e-06
Binary file added tests/old_models/dnsm_output
25 changes: 25 additions & 0 deletions tests/test_backward_compat.py
@@ -0,0 +1,25 @@
import torch

from netam.framework import load_crepe
from netam.sequences import set_wt_to_nan


# The outputs used for this test are produced by running
# `test_backward_compat_copy.py` on the wd-old-model-runner branch.
def test_old_model_outputs():
example_seq = "QVQLVESGGGVVQPGRSLRLSCAASGFTFSSSGMHWVRQAPGKGLEWVAVIWYDGSNKYYADSVKGRFTISRDNSKNTVYLQMNSLRAEDTAVYYCAREGHSNYPYYYYYMDVWGKGTTVTVSS"
dasm_crepe = load_crepe("tests/old_models/dasm_13k-v1jaffe+v1tang-joint")
dnsm_crepe = load_crepe("tests/old_models/dnsm_13k-v1jaffe+v1tang-joint")

dasm_vals = torch.nan_to_num(
set_wt_to_nan(
torch.load("tests/old_models/dasm_output", weights_only=True), example_seq
),
0.0,
)
dnsm_vals = torch.load("tests/old_models/dnsm_output", weights_only=True)

dasm_result = torch.nan_to_num(dasm_crepe([example_seq])[0], 0.0)
dnsm_result = dnsm_crepe([example_seq])[0]
assert torch.allclose(dasm_result, dasm_vals)
assert torch.allclose(dnsm_result, dnsm_vals)
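The test zeroes out NaNs before comparing because `torch.allclose` treats NaN as unequal to NaN by default. A small demonstration of that pattern:

```python
import torch

a = torch.tensor([1.0, float("nan"), 3.0])
b = torch.tensor([1.0, float("nan"), 3.0])
# NaNs never compare equal, so a direct allclose fails...
direct = torch.allclose(a, b)
# ...but mapping NaN -> 0.0 first makes the comparison meaningful.
masked = torch.allclose(torch.nan_to_num(a, 0.0), torch.nan_to_num(b, 0.0))
```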
4 changes: 2 additions & 2 deletions tests/test_common.py
@@ -33,6 +33,6 @@ def test_codon_mask_tensor_of():


def test_aa_strs_from_idx_tensor():
aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 21, 21]])
aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 20, 20]])
aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor)
assert aa_strings == ["ACDE^", "FGY"]
assert aa_strings == ["ACDEX^", "FGY^"]
4 changes: 2 additions & 2 deletions tests/test_dnsm.py
@@ -7,15 +7,15 @@
crepe_exists,
load_crepe,
)
from netam.sequences import MAX_AA_TOKEN_IDX
from netam.common import aa_idx_tensor_of_str_ambig, force_spawn
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dnsm import DNSMBurrito, DNSMDataset
from netam.sequences import AA_AMBIG_IDX


def test_aa_idx_tensor_of_str_ambig():
input_seq = "ACX"
expected_output = torch.tensor([0, 1, MAX_AA_TOKEN_IDX], dtype=torch.int)
expected_output = torch.tensor([0, 1, AA_AMBIG_IDX], dtype=torch.int)
output = aa_idx_tensor_of_str_ambig(input_seq)
assert torch.equal(output, expected_output)
