Restore compatibility with models handling fewer tokens #104
Merged (19 commits)
d2d4e30  add embedding_dim for wiggle model (willdumm)
006bf83  fix tests (willdumm)
0025351  format and lint (willdumm)
ed60f8c  add old models (willdumm)
fa5e445  add start of new tests (willdumm)
24fe710  fix old model test (willdumm)
17918c4  add old outputs (willdumm)
2a16057  add explanatory comment (willdumm)
516376d  format and lint (willdumm)
2fc3808  add check for TODOs (willdumm)
79a596a  fix tests (willdumm)
ac9b484  Update build-and-test.yml (willdumm)
ef0a2fb  add aa ambig index (willdumm)
f922f3d  rewrite TODO check (willdumm)
666b765  fix output dimension (willdumm)
057ad67  format (willdumm)
181f5cd  refactor AA_TOKEN_STR_SORTED (willdumm)
d4956bb  format (willdumm)
9d15324  add embedding_dim to single model (willdumm)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
New file (+17 lines):

    encoder_class: PlaceholderEncoder
    encoder_parameters: {}
    model_class: TransformerBinarySelectionModelWiggleAct
    model_hyperparameters:
      d_model_per_head: 4
      dim_feedforward: 64
      dropout_prob: 0.1
      layer_count: 3
      nhead: 4
      output_dim: 20
    serialization_version: 0
    training_hyperparameters:
      batch_size: 1024
      learning_rate: 0.001
      min_learning_rate: 1.0e-06
      optimizer_name: RMSprop
      weight_decay: 1.0e-06
Binary file not shown.
Binary file not shown.
New file (+17 lines):

    encoder_class: PlaceholderEncoder
    encoder_parameters: {}
    model_class: TransformerBinarySelectionModelWiggleAct
    model_hyperparameters:
      d_model_per_head: 4
      dim_feedforward: 64
      dropout_prob: 0.1
      layer_count: 3
      nhead: 4
      output_dim: 1
    serialization_version: 0
    training_hyperparameters:
      batch_size: 1024
      learning_rate: 0.001
      min_learning_rate: 1.0e-06
      optimizer_name: RMSprop
      weight_decay: 1.0e-06
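The two serialized configs above differ chiefly in `output_dim` (20 in one, 1 in the other), and neither records an `embedding_dim`, which is the gap this PR works around. A minimal sketch of the fallback idea, with a hypothetical helper constant and an illustrative default value, not netam's actual loader code:

```python
# Hyperparameters as they appear in an old serialized model (abridged);
# note there is no embedding_dim key.
old_hparams = {
    "d_model_per_head": 4,
    "dim_feedforward": 64,
    "dropout_prob": 0.1,
    "layer_count": 3,
    "nhead": 4,
    "output_dim": 1,
}

# Hypothetical fallback: older checkpoints predate embedding_dim, so a
# loader can default it when the key is absent. 21 is an illustrative
# placeholder value only, not netam's actual embedding dimension.
LEGACY_EMBEDDING_DIM = 21
embedding_dim = old_hparams.get("embedding_dim", LEGACY_EMBEDDING_DIM)
```

A newer checkpoint that does record `embedding_dim` would pass through `.get` untouched, so one code path serves both formats.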
Binary file not shown.
New file (+25 lines):

    import torch

    from netam.framework import load_crepe
    from netam.sequences import set_wt_to_nan


    # The outputs used for this test are produced by running
    # `test_backward_compat_copy.py` on the wd-old-model-runner branch.
    def test_old_model_outputs():
        example_seq = "QVQLVESGGGVVQPGRSLRLSCAASGFTFSSSGMHWVRQAPGKGLEWVAVIWYDGSNKYYADSVKGRFTISRDNSKNTVYLQMNSLRAEDTAVYYCAREGHSNYPYYYYYMDVWGKGTTVTVSS"
        dasm_crepe = load_crepe("tests/old_models/dasm_13k-v1jaffe+v1tang-joint")
        dnsm_crepe = load_crepe("tests/old_models/dnsm_13k-v1jaffe+v1tang-joint")

        dasm_vals = torch.nan_to_num(
            set_wt_to_nan(
                torch.load("tests/old_models/dasm_output", weights_only=True), example_seq
            ),
            0.0,
        )
        dnsm_vals = torch.load("tests/old_models/dnsm_output", weights_only=True)

        dasm_result = torch.nan_to_num(dasm_crepe([example_seq])[0], 0.0)
        dnsm_result = dnsm_crepe([example_seq])[0]
        assert torch.allclose(dasm_result, dasm_vals)
        assert torch.allclose(dnsm_result, dnsm_vals)
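One detail worth noting in the test above: `torch.allclose` treats `nan` as unequal to itself, which is why the DASM values are passed through `torch.nan_to_num` before comparison. A small self-contained illustration:

```python
import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([1.0, float("nan")])

# nan != nan, so a direct comparison fails even for identical tensors.
direct = torch.allclose(a, b)

# allclose can be told to treat nans at matching positions as equal...
with_flag = torch.allclose(a, b, equal_nan=True)

# ...or, as in the test above, the nans can be zeroed out first.
zeroed = torch.allclose(torch.nan_to_num(a, 0.0), torch.nan_to_num(b, 0.0))
```

Zeroing via `nan_to_num` is slightly stricter than `equal_nan=True` here, since it also requires the nans to appear at the same positions in both tensors for the zeros to line up.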
I'm not sure how to read `consider_sites`. I think these are `meaningful_sites`? Or `aa_sites`?

I guess I'm a little confused about what this section of code is doing. It seems like we're making a big empty tensor, partially filling it, then trimming it back again.
I renamed to `model_valid_sites`.

This code is applying the model to an aa string, only considering the entries in that string that the model was trained to handle (marked by `model_valid_sites`). We construct a big empty `result` tensor of `nan`s, then feed the amino acid string to the model, stripping out sites containing tokens that aren't valid for the model, then put the model's outputs into the `result` tensor at the sites that contain tokens that we fed to the model.

For instance, if we have a heavy chain sequence `QQQQ^`, then the model will see `QQQQ`, but the output will have first dimension size 5, matching the input sequence length, and the last output will be `nan`s.

If we add a start token in the future, then we can still feed sequences containing it to our model. For instance, if we have `1QQQQ`, the old model not handling the new `1` token will see `QQQQ`, but the returned model output will contain `nan`s in the first site, and will have first dimension size 5, matching the input sequence length.
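The scatter-into-nans scheme described in this comment can be sketched as follows; the helper name, the toy model, and the token-index cutoff are all hypothetical stand-ins, not netam's actual implementation:

```python
import torch


def apply_with_nan_padding(model_fn, token_idxs, model_valid_sites, output_dim):
    """Apply model_fn only at valid sites; other sites stay nan."""
    # Big empty result tensor of nans, sized to the full input length.
    result = torch.full((len(token_idxs), output_dim), float("nan"))
    # Feed the model only the sites it was trained to handle...
    outputs = model_fn(token_idxs[model_valid_sites])
    # ...then scatter its outputs back into result at those same sites.
    result[model_valid_sites] = outputs
    return result


def toy_model(idxs):
    # One output column per site; the value is just the token index.
    return idxs.float().unsqueeze(-1)


token_idxs = torch.tensor([3, 3, 3, 3, 20])  # e.g. "QQQQ^" with "^" unknown
model_valid_sites = token_idxs < 20          # toy cutoff: model knows indices 0..19
out = apply_with_nan_padding(toy_model, token_idxs, model_valid_sites, output_dim=1)
# out has first dimension size 5, matching the input length,
# with nan at the final (invalid) site.
```

This mirrors the `QQQQ^` example above: the model only ever sees four tokens, but the caller gets a length-5 output with `nan` marking the site the model could not handle.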
Great!

Let's make a free function that encapsulates the test with a nice name that we can look for when we strip this `if` out. I'm anticipating that we'll have a release in which we assume that all models have an `embedding_dim`.

(In other news, we could just go in and add an embedding dim into all the old models and assume they have them, right?)
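The free function suggested here might look like the following; the function name and the hyperparameter-dict access are illustrative guesses, not the actual netam code:

```python
# Hypothetical encapsulation of the compatibility check discussed above.
# The distinctive name makes it easy to grep for when the fallback is
# eventually stripped out in a future release.
def model_predates_embedding_dim(hyperparameters: dict) -> bool:
    """True for models serialized before embedding_dim was recorded."""
    return "embedding_dim" not in hyperparameters


old_hparams = {"d_model_per_head": 4, "nhead": 4, "output_dim": 1}
new_hparams = {**old_hparams, "embedding_dim": 21}

needs_fallback_old = model_predates_embedding_dim(old_hparams)  # True
needs_fallback_new = model_predates_embedding_dim(new_hparams)  # False
```

As the thread goes on to note, an alternative is to retrofit an `embedding_dim` into all old serialized models, which removes the need for this branch entirely.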
Do you mean a function that takes `hyperparameters` and `aa_idxs` and returns `model_valid_sites`?

The only subclass that doesn't have `embedding_dim` right now is the Single model. I could just add an embedding dimension for that model that's pinned to `MAX_AA_TOKEN_IDX` so that it always grows with new tokens. Then I'd be able to remove this test.

Would that be preferable to making a free function?
I addressed this by adding an `embedding_dim` to the Single model.