
Commit

bring in the 🥧
matsen committed Dec 13, 2024
1 parent a15eba0 commit 2349c7a
Showing 2 changed files with 107 additions and 20 deletions.
54 changes: 54 additions & 0 deletions netam/models.py
@@ -10,6 +10,7 @@
from torch import Tensor

from netam.hit_class import apply_multihit_correction
from netam.protein_embedders import ESMEmbedder
from netam.common import (
    MAX_AMBIG_AA_IDX,
    aa_idx_tensor_of_str_ambig,
@@ -726,6 +727,59 @@ def predict(self, representation: Tensor):
        return wiggle(super().predict(representation), beta)


class TransformerBinarySelectionModelPIE(TransformerBinarySelectionModelWiggleAct):
    """A wiggle-act selection model that embeds parent sequences with a pretrained ESM
    protein embedder (the `pie` attribute); as in the parent class, the beta parameter
    is fixed at 0.3."""

    def __init__(
        self,
        esm_model_name: str,
        layer_count: int,
        dropout_prob: float = 0.5,
        output_dim: int = 1,
    ):
        self.pie = ESMEmbedder(model_name=esm_model_name)
        super().__init__(
            nhead=self.pie.num_heads,
            d_model_per_head=self.pie.d_model_per_head,
            # TODO this is hard coded as per Antoine.
            dim_feedforward=self.pie.d_model * 4,
            layer_count=layer_count,
            dropout_prob=dropout_prob,
            output_dim=output_dim,
        )


    def to(self, device):
        super().to(device)
        self.pie.model = self.pie.model.to(device)
        return self

    def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
        """Represent an index-encoded parent sequence in the model's embedding space.

        Args:
            amino_acid_indices: A tensor of shape (B, L) containing the
                indices of parent AA sequences.
            mask: A tensor of shape (B, L) representing the mask of valid
                amino acid sites.

        Returns:
            The embedded parent sequences, in a tensor of shape (B, L, E),
            where E is the dimensionality of the embedding space.
        """
        # Multiply by sqrt(d_model) to match the transformer paper.
        embedded_amino_acids = self.pie.embed_batch(
            amino_acid_indices
        ) * math.sqrt(self.d_model)

        # Have to do the permutation because the positional encoding expects the
        # sequence length to be the first dimension.
        embedded_amino_acids = self.pos_encoder(
            embedded_amino_acids.permute(1, 0, 2)
        ).permute(1, 0, 2)

        # To learn about src_key_padding_mask, see https://stackoverflow.com/q/62170439
        return self.encoder(embedded_amino_acids, src_key_padding_mask=~mask)
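
Not part of the diff: a minimal usage sketch of the new PIE model, assuming the netam package and the ESM weights are available locally. The checkpoint name is the example given in the ESMEmbedder docstring below; the batch size, sequence length, and layer_count are illustrative only.

import torch

from netam.models import TransformerBinarySelectionModelPIE

model = TransformerBinarySelectionModelPIE(
    esm_model_name="esm2_t6_8M_UR50D",  # example checkpoint named in the ESMEmbedder docstring
    layer_count=3,                      # illustrative choice
)
aa_indices = torch.zeros((2, 128), dtype=torch.long)  # (B, L) index-encoded parent sequences
mask = torch.ones((2, 128), dtype=torch.bool)         # treat every site as valid
hidden = model.represent(aa_indices, mask)            # (B, L, E) with E = model.d_model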


class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
"""A one parameter selection model as a baseline."""

73 changes: 53 additions & 20 deletions netam/protein_embedders.py
@@ -1,19 +1,49 @@
import torch

from esm import pretrained

from netam.common import aa_strs_from_idx_tensor


def pad_embeddings(embeddings, desired_length):
    """
    Pads a batch of embeddings to a specified sequence length with zeros.

    Args:
        embeddings (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).
        desired_length (int): The length to which each sequence should be padded.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, desired_length, embedding_dim).
    """
    batch_size, seq_len, embedding_dim = embeddings.size()

    if desired_length <= 0:
        raise ValueError("desired_length must be a positive integer")

    # Truncate seq_len if it exceeds desired_length
    if seq_len > desired_length:
        embeddings = embeddings[:, :desired_length, :]
        seq_len = desired_length

    device = embeddings.device
    padded_embeddings = torch.zeros(
        (batch_size, desired_length, embedding_dim), device=device
    )
    padded_embeddings[:, :seq_len, :] = embeddings
    return padded_embeddings
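
Not part of the diff: a small sketch of pad_embeddings at its two boundary cases (padding and truncation), with made-up tensor sizes.

import torch

from netam.protein_embedders import pad_embeddings

emb = torch.randn(2, 5, 8)          # (batch_size=2, seq_len=5, embedding_dim=8)

padded = pad_embeddings(emb, 7)     # shorter than desired_length: zero-padded on the right
assert padded.shape == (2, 7, 8)
assert torch.equal(padded[:, :5, :], emb)

truncated = pad_embeddings(emb, 3)  # longer than desired_length: truncated
assert truncated.shape == (2, 3, 8)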


class ESMEmbedder:
-    def __init__(self, model_name: str, max_aa_seq_len: int, device: str = "cpu"):
+    def __init__(self, model_name: str, device: str = "cpu"):
        """
        Initializes the ESMEmbedder object.

        Args:
            model_name (str): Name of the pretrained ESM model (e.g., "esm2_t6_8M_UR50D").
-            max_aa_seq_len (int): Maximum sequence length allowed.
            device (str): Device to run the model on.
        """
        self.device = device
-        self.max_aa_seq_len = max_aa_seq_len
        self.model, self.alphabet = pretrained.load_model_and_alphabet(model_name)
        self.model = self.model.to(device)
        self.batch_converter = self.alphabet.get_batch_converter()
@@ -24,7 +54,7 @@ def num_heads(self):

    @property
    def d_model(self):
-        return self.model.layers[0].self_attn.hidden_size
+        return self.model.embed_dim

    @property
    def d_model_per_head(self):
@@ -34,7 +64,7 @@ def d_model_per_head(self):
    def num_layers(self):
        return self.model.num_layers

-    def embed_batch(self, sequences: list[str]) -> torch.Tensor:
+    def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
        """
        Embeds a batch of sequences.
@@ -43,27 +73,30 @@ def embed_batch(self, sequences: list[str]) -> torch.Tensor:
        Returns:
            torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len, embedding_dim).
-        Raises:
-            ValueError: If any sequence exceeds max_aa_seq_len.
        """
-        for seq in sequences:
-            if len(seq) > self.max_aa_seq_len:
-                raise ValueError(
-                    f"Sequence length {len(seq)} exceeds max_aa_seq_len {self.max_aa_seq_len}"
-                )

        named_sequences = [(f"seq_{i}", seq) for i, seq in enumerate(sequences)]
        batch_labels, batch_strs, batch_tokens = self.batch_converter(named_sequences)
        batch_tokens = batch_tokens.to(self.device)
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[self.num_layers])
        embeddings = results["representations"][self.num_layers]

-        batch_size, seq_len, embedding_dim = embeddings.size()
-        padded_embeddings = torch.zeros(
-            (batch_size, self.max_aa_seq_len, embedding_dim), device=self.device
-        )
-        padded_embeddings[:, :seq_len, :] = embeddings
-        return padded_embeddings
        return embeddings

    def embed_batch(self, amino_acid_indices: torch.Tensor) -> torch.Tensor:
        """
        Embeds a batch of netam amino acid indices.

        For now, we detokenize the amino acid indices and then use embed_sequence_list.

        Args:
            amino_acid_indices (torch.Tensor): A tensor of shape (batch_size, max_aa_seq_len).

        Returns:
            torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len, embedding_dim).
        """
        sequences = aa_strs_from_idx_tensor(amino_acid_indices)
        embedding = self.embed_sequence_list(sequences)
        desired_length = amino_acid_indices.size(1)
        padded_embedding = pad_embeddings(embedding, desired_length)
        return padded_embedding
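
Not part of the diff: a hedged sketch of the detokenize-then-embed-then-pad flow that embed_batch wraps, assuming the ESM checkpoint can be downloaded and treating aa_idx_tensor_of_str_ambig (imported in models.py above) as a string-to-index helper; its exact signature is an assumption.

import torch

from netam.common import aa_idx_tensor_of_str_ambig  # assumed str -> LongTensor helper
from netam.protein_embedders import ESMEmbedder

embedder = ESMEmbedder(model_name="esm2_t6_8M_UR50D")  # example name from the docstring

# Lower-level path: embed explicit amino acid strings; the result is padded to the
# longest tokenized sequence in the batch.
per_site = embedder.embed_sequence_list(["QVQLVQSG", "EVQLVESG"])

# Higher-level path: start from netam's index encoding; embed_batch detokenizes,
# embeds, and pads the result back to the width of the index tensor.
idx = torch.stack([aa_idx_tensor_of_str_ambig(s) for s in ["QVQLVQSG", "EVQLVESG"]])
embedded = embedder.embed_batch(idx)
assert embedded.shape[:2] == idx.shape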
