Restore compatibility with models handling fewer tokens #104
Conversation
tests/test_dnsm.py
Outdated
```diff
 from netam.common import aa_idx_tensor_of_str_ambig, force_spawn
 from netam.models import TransformerBinarySelectionModelWiggleAct
 from netam.dnsm import DNSMBurrito, DNSMDataset


 def test_aa_idx_tensor_of_str_ambig():
     input_seq = "ACX"
-    expected_output = torch.tensor([0, 1, MAX_AA_TOKEN_IDX], dtype=torch.int)
+    expected_output = torch.tensor([0, 1, 20], dtype=torch.int)
```
Variable please.
Nice! A few comments / questions here.
netam/models.py
Outdated
```python
# Here we're ignoring sites containing tokens that have index greater
# than the embedding dimension.
if "embedding_dim" in self.hyperparameters:
    consider_sites = aa_idxs < self.hyperparameters["embedding_dim"]
```
I'm not sure how to read `consider_sites`. I think these are `meaningful_sites`? Or `aa_sites`?
I guess I'm a little confused about what this section of code is doing. It seems like we're making a big empty tensor, partially filling it, then trimming it back again.
I renamed it to `model_valid_sites`.
This code applies the model to an amino acid string, only considering the entries in that string that the model was trained to handle (marked by `model_valid_sites`). We construct a big empty `result` tensor of `nan`s, feed the amino acid string to the model with the sites containing tokens that aren't valid for the model stripped out, then put the model's outputs into the `result` tensor at the sites whose tokens we fed to the model.

For instance, if we have a heavy chain sequence `QQQQ^`, then the model will see `QQQQ`, but the output will have first dimension size 5, matching the input sequence length, and the last output will be `nan`s.

If we add a start token in the future, then we can still feed sequences containing it to our model. For instance, if we have `1QQQQ`, an old model not handling the new `1` token will see `QQQQ`, but the returned model output will contain `nan`s in the first site, and will have first dimension size 5, matching the input sequence length.
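The fill-with-`nan` pattern described above can be sketched in a few lines of PyTorch. The function and parameter names here are illustrative assumptions, not the actual netam API:

```python
import torch


def apply_model_masked(model_fn, aa_idxs, embedding_dim, out_features):
    """Illustrative sketch: run model_fn only on sites whose token index
    is below embedding_dim, returning nan at all other sites."""
    model_valid_sites = aa_idxs < embedding_dim
    # Full-length result tensor, pre-filled with nan.
    result = torch.full((len(aa_idxs), out_features), float("nan"))
    # Feed only the valid sites to the model, then scatter the model's
    # outputs back into the result at those same sites.
    result[model_valid_sites] = model_fn(aa_idxs[model_valid_sites])
    return result
```

With a toy `model_fn` and a sequence ending in an out-of-vocabulary token, the output keeps the full input length, with `nan`s at the invalid site.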
Great!

Let's make a free function that encapsulates the test `if "embedding_dim" in self.hyperparameters:` with a nice name that we can look for when we strip this `if` out. I'm anticipating that we'll have a release in which we assume that all models have an `embedding_dim`.

(In other news, we could just go in and add an embedding dim into all the old models and assume they have them, right?)
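One possible shape for that free function (the name, signature, and fallback behavior here are my assumptions, not code from the PR):

```python
import torch


def model_valid_sites_of_aa_idxs(aa_idxs, hyperparameters):
    """Mask of sites whose token index the model can embed.

    Models without an embedding_dim hyperparameter are assumed to
    handle every token, so all sites are marked valid.
    """
    if "embedding_dim" in hyperparameters:
        return aa_idxs < hyperparameters["embedding_dim"]
    return torch.ones_like(aa_idxs, dtype=torch.bool)
```

Having the `"embedding_dim" in hyperparameters` check live in one named function makes it easy to grep for and delete once every model is guaranteed to have an `embedding_dim`.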
Do you mean a function that takes `hyperparameters` and `aa_idxs` and returns `model_valid_sites`?
The only subclass that doesn't have `embedding_dim` right now is the Single model. I could just add an embedding dimension for that model that's pinned to `MAX_AA_TOKEN_IDX` so that it always grows with new tokens. Then I'd be able to remove this test.

Would that be preferable to making a free function?
I addressed this by adding an `embedding_dim` to the Single model.

This now requires the companion PR https://github.com/matsengrp/dnsm-experiments-1/pull/77, which reverts the DASM output dimension to 20.
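A minimal sketch of what the pinned hyperparameter might look like. The class name, the concrete value of `MAX_AA_TOKEN_IDX`, and the exact `+ 1` convention are all illustrative assumptions, not the actual netam code:

```python
# In the test snippet above, the ambiguity token "X" maps to index 20,
# so we assume MAX_AA_TOKEN_IDX == 20 for illustration.
MAX_AA_TOKEN_IDX = 20


class SingleModelSketch:
    @property
    def hyperparameters(self):
        # Pinned to the token count so the Single model always reports
        # an embedding_dim that grows as new tokens are added (the + 1
        # keeps index MAX_AA_TOKEN_IDX itself valid under an
        # `aa_idxs < embedding_dim` check).
        return {"embedding_dim": MAX_AA_TOKEN_IDX + 1}
```

Because the value tracks the token constant rather than being hard-coded, adding tokens later automatically widens the set of sites the model reports as valid.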
one little pending suggestion!
This PR addresses #102. It still needs a unit test (started) and ideally at least a plan for how to apply the neutral model when we add additional tokens (there's some discussion in the issue about this).