
Restore compatibility with models handling fewer tokens #104

Merged · 19 commits · Jan 14, 2025
9 changes: 5 additions & 4 deletions netam/dxsm.py
@@ -303,11 +303,12 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
# # TODO
# The following can be used when one wants a better traceback.
# burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
# return burrito.serial_find_optimal_branch_lengths(
# dataset, **optimization_kwargs
# )
burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
return burrito.serial_find_optimal_branch_lengths(
dataset, **optimization_kwargs
)
our_optimize_branch_length = partial(
worker_optimize_branch_length,
self.__class__,
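The change above routes `find_optimal_branch_lengths` through the serial implementation, which surfaces full tracebacks instead of exceptions re-raised from multiprocessing workers. A toy illustration of that trade-off (not netam code; the names here are illustrative):

```python
# Toy sketch: exceptions raised inside pool workers are pickled and
# re-raised at the call site, losing the worker-side stack frames.
import multiprocessing as mp

def optimize_one(x):
    if x < 0:
        raise ValueError(f"bad branch length input: {x}")
    return x * 0.5

def optimize_all(xs, serial=True):
    if serial:
        # Serial path: a failure points straight at optimize_one's raise site.
        return [optimize_one(x) for x in xs]
    with mp.Pool(2) as pool:
        # Parallel path: the same failure surfaces as a re-raised exception
        # whose traceback starts at pool.map, not inside the worker.
        return pool.map(optimize_one, xs)

if __name__ == "__main__":
    print(optimize_all([1.0, 2.0]))
```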
5 changes: 5 additions & 0 deletions netam/framework.py
@@ -321,6 +321,11 @@ def load_crepe(prefix, device=None):
f"Model class '{model_class_name}' not found in 'models' module."
)

if issubclass(model_class, models.AbstractBinarySelectionModel):
if "embedding_dim" not in config["model_hyperparameters"]:
# Assume the model is from before any new tokens were added, so 21
config["model_hyperparameters"]["embedding_dim"] = 21

model = model_class(**config["model_hyperparameters"])

model_state_path = f"{prefix}.pth"
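This `load_crepe` change backfills `embedding_dim` for models serialized before the new tokens existed. A minimal sketch of just the backfill, assuming a parsed config dict shaped like the YAML fixtures added below (`backfill_embedding_dim` is a hypothetical helper, not netam's API):

```python
# Minimal sketch of the backwards-compatibility backfill, not the full loader.
def backfill_embedding_dim(config: dict) -> dict:
    hparams = config["model_hyperparameters"]
    if "embedding_dim" not in hparams:
        # Pre-expansion models only knew the 20 amino acids plus the
        # ambiguous "X", hence 21 embedding rows.
        hparams["embedding_dim"] = 21
    return config

old_config = {"model_hyperparameters": {"output_dim": 1, "nhead": 4}}
print(backfill_embedding_dim(old_config)["model_hyperparameters"]["embedding_dim"])  # 21
```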
26 changes: 22 additions & 4 deletions netam/models.py
@@ -597,11 +597,26 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(self.device)

# Here we're ignoring sites containing tokens that have index greater
# than the embedding dimension.
consider_sites = aa_idxs < self.hyperparameters["embedding_dim"]
# TODO test with DNSM that this is really how the outputs are shaped
if self.hyperparameters["output_dim"] == 1:
result = torch.full((len(aa_str),), float("nan"), device=self.device)
else:
result = torch.full(
(len(aa_str), self.hyperparameters["output_dim"]),
float("nan"),
device=self.device,
)

with torch.no_grad():
model_out = self(aa_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)
final_out = torch.exp(model_out)
model_out = self(
aa_idxs[consider_sites].unsqueeze(0), mask[consider_sites].unsqueeze(0)
).squeeze(0)
result[consider_sites] = torch.exp(model_out)[: consider_sites.sum()]

return final_out[: len(aa_str)]
return result


class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel):
@@ -613,16 +628,18 @@ def __init__(
layer_count: int,
dropout_prob: float = 0.5,
output_dim: int = 1,
embedding_dim: int = MAX_AA_TOKEN_IDX + 1,
):
super().__init__()
# Note that d_model has to be divisible by nhead, so we make that
# automatic here.
self.d_model_per_head = d_model_per_head
self.d_model = d_model_per_head * nhead
self.nhead = nhead
self.embedding_dim = embedding_dim
self.dim_feedforward = dim_feedforward
self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
self.amino_acid_embedding = nn.Embedding(self.embedding_dim, self.d_model)
self.encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=nhead,
@@ -642,6 +659,7 @@ def hyperparameters(self):
"layer_count": self.encoder.num_layers,
"dropout_prob": self.pos_encoder.dropout.p,
"output_dim": self.linear.out_features,
"embedding_dim": self.embedding_dim,
}

def init_weights(self) -> None:
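`selection_factors_of_aa_str` now NaN-fills sites whose token index falls outside the model's embedding table. A standalone sketch of that masking idea for the `output_dim == 1` case (the tiny stand-in model and the function name are assumptions):

```python
# Sketch of the "consider_sites" masking: positions with token index
# >= embedding_dim get NaN instead of a selection factor.
import torch

def selection_factors_with_nan(aa_idxs: torch.Tensor, embedding_dim: int,
                               model_fn) -> torch.Tensor:
    consider_sites = aa_idxs < embedding_dim
    result = torch.full((len(aa_idxs),), float("nan"))
    # Run the model only on in-vocabulary sites, then scatter back.
    result[consider_sites] = torch.exp(model_fn(aa_idxs[consider_sites]))
    return result

# Example: token 21 (a reserved token) is out of range for embedding_dim=21.
aa_idxs = torch.tensor([0, 4, 21, 19])
out = selection_factors_with_nan(aa_idxs, 21, lambda t: torch.zeros(len(t)))
print(out)  # tensor([1., 1., nan, 1.])
```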
6 changes: 3 additions & 3 deletions netam/sequences.py
@@ -17,9 +17,9 @@

NT_STR_SORTED = "".join(BASES)
BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")}
# ambiguous must remain last. It is assumed elsewhere that the max index
# denotes the ambiguous base
AA_TOKEN_STR_SORTED = AA_STR_SORTED + RESERVED_TOKENS + "X"
# Must add new tokens to the end of this string.
# TODO It is assumed elsewhere that the max index denotes the ambiguous base
AA_TOKEN_STR_SORTED = AA_STR_SORTED + "X" + RESERVED_TOKENS

RESERVED_TOKEN_AA_BOUNDS = (
min(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
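The reordering pins the ambiguous "X" at index 20 so that 21-token models trained before this PR still line up with the vocabulary. A sketch of the index layout, where `AA_STR_SORTED` and the assumption `RESERVED_TOKENS == "^"` are inferred from the tests in this PR, not copied from the module:

```python
AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"  # 20 amino acids, indices 0..19
RESERVED_TOKENS = "^"  # assumption; the real module may define more

# Old layout: X (ambiguous) came last, after the reserved tokens.
old_layout = AA_STR_SORTED + RESERVED_TOKENS + "X"
# New layout: X is pinned at index 20, so pre-existing 21-token models
# still agree with the vocabulary; new tokens append after it.
new_layout = AA_STR_SORTED + "X" + RESERVED_TOKENS

print(old_layout.index("X"), new_layout.index("X"))  # 21 20
```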
Binary file not shown.
17 changes: 17 additions & 0 deletions tests/old_models/dasm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
encoder_class: PlaceholderEncoder
encoder_parameters: {}
model_class: TransformerBinarySelectionModelWiggleAct
model_hyperparameters:
d_model_per_head: 4
dim_feedforward: 64
dropout_prob: 0.1
layer_count: 3
nhead: 4
output_dim: 20
serialization_version: 0
training_hyperparameters:
batch_size: 1024
learning_rate: 0.001
min_learning_rate: 1.0e-06
optimizer_name: RMSprop
weight_decay: 1.0e-06
Binary file not shown.
17 changes: 17 additions & 0 deletions tests/old_models/dnsm_13k-v1jaffe+v1tang-joint.yml
@@ -0,0 +1,17 @@
encoder_class: PlaceholderEncoder
encoder_parameters: {}
model_class: TransformerBinarySelectionModelWiggleAct
model_hyperparameters:
d_model_per_head: 4
dim_feedforward: 64
dropout_prob: 0.1
layer_count: 3
nhead: 4
output_dim: 1
serialization_version: 0
training_hyperparameters:
batch_size: 1024
learning_rate: 0.001
min_learning_rate: 1.0e-06
optimizer_name: RMSprop
weight_decay: 1.0e-06
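A usage sketch of how a test could load this fixture through the updated loader (assuming the matching .pth weights sit beside the YAML and that `load_crepe` returns an object exposing the model):

```python
from netam.framework import load_crepe

crepe = load_crepe("tests/old_models/dnsm_13k-v1jaffe+v1tang-joint")
# The loader backfills embedding_dim for this pre-token-expansion model;
# crepe.model is an assumption about the returned object's shape.
assert crepe.model.hyperparameters["embedding_dim"] == 21
```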
4 changes: 2 additions & 2 deletions tests/test_common.py
@@ -33,6 +33,6 @@ def test_codon_mask_tensor_of():


def test_aa_strs_from_idx_tensor():
aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 21, 21]])
aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 20, 20]])
aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor)
assert aa_strings == ["ACDE^", "FGY"]
assert aa_strings == ["ACDEX^", "FGY^"]
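The updated assertions pin down the new behavior: indices map through the reordered token string, trailing ambiguous "X" entries act as padding and are stripped, and reserved tokens such as "^" survive. A sketch consistent with those assertions (the token string is reconstructed here, not imported):

```python
import torch

AA_TOKEN_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY" + "X" + "^"  # inferred from this PR

def aa_strs_from_idx_tensor_sketch(aa_idx_tensor: torch.Tensor) -> list[str]:
    # Map each index to its token, then strip trailing X used as padding.
    return [
        "".join(AA_TOKEN_STR_SORTED[i] for i in row).rstrip("X")
        for row in aa_idx_tensor.tolist()
    ]

rows = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 20, 20]])
assert aa_strs_from_idx_tensor_sketch(rows) == ["ACDEX^", "FGY^"]
```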
3 changes: 1 addition & 2 deletions tests/test_dnsm.py
@@ -7,15 +7,14 @@
crepe_exists,
load_crepe,
)
from netam.sequences import MAX_AA_TOKEN_IDX
from netam.common import aa_idx_tensor_of_str_ambig, force_spawn
from netam.models import TransformerBinarySelectionModelWiggleAct
from netam.dnsm import DNSMBurrito, DNSMDataset


def test_aa_idx_tensor_of_str_ambig():
input_seq = "ACX"
expected_output = torch.tensor([0, 1, MAX_AA_TOKEN_IDX], dtype=torch.int)
expected_output = torch.tensor([0, 1, 20], dtype=torch.int)
Contributor: Variable please.

output = aa_idx_tensor_of_str_ambig(input_seq)
assert torch.equal(output, expected_output)

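Following the reviewer's "Variable please.", a sketch of the test with the magic 20 behind a named constant (`AMBIGUOUS_AA_IDX` is a hypothetical name; the real constant may live elsewhere under another name):

```python
import torch
from netam.common import aa_idx_tensor_of_str_ambig

AMBIGUOUS_AA_IDX = 20  # index of "X" in the reordered AA_TOKEN_STR_SORTED

def test_aa_idx_tensor_of_str_ambig():
    input_seq = "ACX"
    expected_output = torch.tensor([0, 1, AMBIGUOUS_AA_IDX], dtype=torch.int)
    output = aa_idx_tensor_of_str_ambig(input_seq)
    assert torch.equal(output, expected_output)
```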