
Restore compatibility with models handling fewer tokens #104

Merged: 19 commits into main on Jan 14, 2025

Conversation

willdumm (Contributor)

This PR addresses #102. It still needs a unit test (started) and ideally at least a plan for how to apply the neutral model when we add additional tokens (there's some discussion in the issue about this).

from netam.common import aa_idx_tensor_of_str_ambig, force_spawn
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dnsm import DNSMBurrito, DNSMDataset


def test_aa_idx_tensor_of_str_ambig():
    input_seq = "ACX"
-   expected_output = torch.tensor([0, 1, MAX_AA_TOKEN_IDX], dtype=torch.int)
+   expected_output = torch.tensor([0, 1, 20], dtype=torch.int)
Contributor

Variable please.

@matsen matsen left a comment (Contributor)

Nice! A few comments / questions here.

netam/models.py (Outdated)

        # Here we're ignoring sites containing tokens whose index is at or
        # beyond the embedding dimension.
        if "embedding_dim" in self.hyperparameters:
            consider_sites = aa_idxs < self.hyperparameters["embedding_dim"]
Contributor

I'm not sure how to read consider_sites. I think these are meaningful_sites? or aa_sites?

I guess I'm a little confused about what this section of code is doing. It seems like we're making a big empty tensor, partially filling it, then trimming it back again.

Contributor Author

I renamed it to model_valid_sites.

This code applies the model to an amino acid string, considering only the entries in that string that the model was trained to handle (marked by model_valid_sites). We construct a big result tensor of NaNs, feed the amino acid string to the model with the sites containing tokens the model can't handle stripped out, and then write the model's outputs back into the result tensor at the sites that were actually fed to the model.

For instance, if we have a heavy chain sequence QQQQ^, the model will see QQQQ, but the output will have first dimension of size 5, matching the input sequence length, and the last output row will be NaNs.

If we add a start token in the future, we can still feed sequences containing it to our model. For instance, given 1QQQQ, an old model that doesn't handle the new 1 token will see QQQQ, but the returned output will contain NaNs at the first site and will have first dimension of size 5, matching the input sequence length.
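
A minimal sketch of that flow (the helper name, the model call signature, and out_dim are illustrative assumptions, not the actual netam API):

```python
import torch


def apply_model_masking_invalid_tokens(model, aa_idxs, out_dim):
    """Hypothetical helper: run `model` only at sites whose token index the
    model was trained on; all other rows of the result stay NaN."""
    # Sites whose token index fits inside the model's embedding.
    model_valid_sites = aa_idxs < model.hyperparameters["embedding_dim"]
    # The result spans the full input length, pre-filled with NaNs.
    result = torch.full((aa_idxs.shape[0], out_dim), float("nan"))
    # Feed only the valid sites to the model, then scatter its outputs back
    # into the rows for those sites.
    valid_output = model(aa_idxs[model_valid_sites].unsqueeze(0)).squeeze(0)
    result[model_valid_sites] = valid_output
    return result
```

So for QQQQ^ the result has five rows and the row for ^ stays NaN, and for 1QQQQ it would be the first row that stays NaN.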

Contributor

Great!

Let's make a free function that encapsulates the test

if "embedding_dim" in self.hyperparameters:

with a nice name that we can look for when we strip this if out. I'm anticipating that we'll have a release in which we assume that all models have an embedding_dim.

(In other news, we could just go in and add an embedding dim into all the old models and assume they have them, right?)
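
For instance, something along these lines (the name here is hypothetical, and the thread below ends up resolving this differently, by giving the Single model an embedding_dim):

```python
def model_embedding_dim_is_known(hyperparameters) -> bool:
    """Hypothetical free function wrapping the back-compat check, so it's easy
    to find and delete once every serialized model carries embedding_dim."""
    return "embedding_dim" in hyperparameters
```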

Contributor Author

Do you mean a function that takes hyperparameters and aa_idxs and returns model_valid_sites?

The only subclass that doesn't have embedding_dim right now is the Single model. I could just add an embedding dimension for that model that's pinned to MAX_AA_TOKEN_IDX so that it always grows with new tokens. Then I'd be able to remove this test.

Would that be preferable to making a free function?

Contributor Author

I addressed this by adding an embedding_dim to the Single model.
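
Roughly, that resolution looks like the following sketch (the class is a stand-in, MAX_AA_TOKEN_IDX's real module isn't shown in this thread, and the + 1 assumes the valid-site test is aa_idxs < embedding_dim as in the excerpt above):

```python
MAX_AA_TOKEN_IDX = 20  # placeholder value; netam defines the real constant


class SingleModelSketch:
    """Stand-in for the Single model, with an embedding_dim pinned to the
    token set size so the model always grows when new tokens are added."""

    @property
    def hyperparameters(self):
        return {
            # + 1 so the highest token index still satisfies
            # aa_idxs < embedding_dim in the valid-site check.
            "embedding_dim": MAX_AA_TOKEN_IDX + 1,
        }
```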

netam/sequences.py (Outdated; review thread resolved)
tests/test_backward_compat.py (review thread resolved)
@willdumm (Contributor Author)

willdumm commented Jan 14, 2025

This now requires the companion PR https://github.com/matsengrp/dnsm-experiments-1/pull/77, which reverts the DASM output dimension to 20.

@willdumm willdumm marked this pull request as ready for review on January 14, 2025 at 20:06
@willdumm willdumm changed the title from "102 token back compat" to "Restore compatibility with models handling fewer tokens" on Jan 14, 2025
@matsen matsen left a comment (Contributor)

one little pending suggestion!

(The pending suggestion is the same "free function" comment on netam/models.py quoted above.)

@willdumm willdumm merged commit 5cac227 into main Jan 14, 2025
2 checks passed
@willdumm willdumm deleted the 102-token-back-compat branch January 14, 2025 22:43