Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Dec 13, 2024
1 parent 867575f commit 7fccd1e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
11 changes: 8 additions & 3 deletions netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,12 @@ def predict(self, representation: Tensor):


class TransformerBinarySelectionModelPIE(TransformerBinarySelectionModelWiggleAct):
"""Here the beta parameter is fixed at 0.3."""
"""
This version of the model uses an ESM model to embed the amino acid
sequences as an input to the model rather than training an embedding.
PIE stands for Protein Input Embedding.
"""

def __init__(
self,
Expand All @@ -742,7 +747,7 @@ def __init__(
super().__init__(
nhead=self.pie.num_heads,
d_model_per_head=self.pie.d_model_per_head,
# TODO this is hard coded as per Antoine.
# The transformer paper uses 4 * d_model for the feedforward layer.
dim_feedforward=self.pie.d_model * 4,
layer_count=layer_count,
dropout_prob=dropout_prob,
Expand All @@ -760,7 +765,7 @@ def hyperparameters(self):

def to(self, device):
    """Move this model and its protein-input-embedding (PIE) module to *device*.

    Overridden presumably because ``self.pie`` wraps an external ESM model
    that the base-class ``to`` would not relocate on its own — TODO confirm
    against the ``pie`` attribute's definition.

    Args:
        device: Target device (e.g. ``"cpu"``, ``"cuda:0"``, or a
            ``torch.device``).

    Returns:
        self, so calls can be chained like ``torch.nn.Module.to``.
    """
    super().to(device)
    # Delegate to the embedder's own ``to`` and rebind, in case it returns
    # a replacement object rather than mutating in place.  (The pre-commit
    # form ``self.pie.model = self.pie.model.to(device)`` is superseded.)
    self.pie = self.pie.to(device)
    return self

def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
Expand Down
5 changes: 4 additions & 1 deletion netam/protein_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def __init__(self, model_name: str):
Args:
model_name (str): Name of the pretrained ESM model (e.g., "esm2_t6_8M_UR50D").
device (str): Device to run the model on.
"""
self.model, self.alphabet = pretrained.load_model_and_alphabet(model_name)
self.batch_converter = self.alphabet.get_batch_converter()
Expand All @@ -64,6 +63,10 @@ def d_model_per_head(self):
def num_layers(self):
    """Number of transformer layers reported by the underlying ESM model."""
    underlying = self.model
    return underlying.num_layers

def to(self, device):
    """Relocate the wrapped ESM model to *device*.

    Returns self so calls can be chained, mirroring ``torch.nn.Module.to``.
    """
    model = self.model
    model.to(device)
    return self

def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
"""Embeds a batch of sequences.
Expand Down

0 comments on commit 7fccd1e

Please sign in to comment.