Preliminary implementation of PIE model (#96)
Preliminary implementation of a DASM model that uses a protein embedding as input; it hilariously back-translates our AA indexing into a string and passes it to ESM for pre-embedding. Slow!
matsen authored Dec 16, 2024
1 parent 22c8873 commit 9c78482
Showing 5 changed files with 204 additions and 1 deletion.
20 changes: 20 additions & 0 deletions netam/common.py
@@ -109,6 +109,26 @@ def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
    return torch.tensor(mask, dtype=torch.bool)


def aa_strs_from_idx_tensor(idx_tensor):
    """Convert a tensor of amino acid indices back to a list of amino acid strings.

    Args:
        idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
            indices into AA_STR_SORTED_AMBIG.

    Returns:
        List[str]: A list of amino acid strings with trailing 'X's removed.
    """
    idx_tensor = idx_tensor.cpu()

    aa_str_list = []
    for row in idx_tensor:
        aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist())
        aa_str_list.append(aa_str.rstrip("X"))

    return aa_str_list


def assert_pcp_valid(parent, child, aa_mask=None):
    """Check that the parent-child pairs are valid.
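The new helper is exercised by the test added to tests/test_common.py below; here is a minimal usage sketch (illustrative, not part of the commit) that round-trips netam indices back to strings, with index 20 acting as the ambiguous "X":

import torch

from netam.common import aa_strs_from_idx_tensor

# Index 20 is the ambiguous "X"; only trailing X's are stripped by rstrip.
idx = torch.tensor([[0, 1, 2, 3, 20], [4, 5, 19, 20, 20]])
print(aa_strs_from_idx_tensor(idx))  # ['ACDE', 'FGY']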
68 changes: 68 additions & 0 deletions netam/models.py
@@ -10,6 +10,7 @@
from torch import Tensor

from netam.hit_class import apply_multihit_correction
from netam.protein_embedders import ESMEmbedder
from netam.common import (
    MAX_AMBIG_AA_IDX,
    aa_idx_tensor_of_str_ambig,
@@ -726,6 +727,73 @@ def predict(self, representation: Tensor):
        return wiggle(super().predict(representation), beta)


class TransformerBinarySelectionModelPIE(TransformerBinarySelectionModelWiggleAct):
    """This version of the model uses an ESM model to embed the amino acid sequences as
    an input to the model rather than training an embedding.

    PIE stands for Protein Input Embedding.
    """

    def __init__(
        self,
        esm_model_name: str,
        layer_count: int,
        dropout_prob: float = 0.5,
        output_dim: int = 1,
    ):
        self.esm_model_name = esm_model_name
        self.pie = ESMEmbedder(model_name=esm_model_name)
        super().__init__(
            nhead=self.pie.num_heads,
            d_model_per_head=self.pie.d_model_per_head,
            # The transformer paper uses 4 * d_model for the feedforward layer.
            dim_feedforward=self.pie.d_model * 4,
            layer_count=layer_count,
            dropout_prob=dropout_prob,
            output_dim=output_dim,
        )

    @property
    def hyperparameters(self):
        return {
            "esm_model_name": self.esm_model_name,
            "layer_count": self.encoder.num_layers,
            "dropout_prob": self.pos_encoder.dropout.p,
            "output_dim": self.linear.out_features,
        }

    def to(self, device):
        super().to(device)
        self.pie = self.pie.to(device)
        return self

    def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
        """Represent an index-encoded parent sequence in the model's embedding space.

        Args:
            amino_acid_indices: A tensor of shape (B, L) containing the
                indices of parent AA sequences.
            mask: A tensor of shape (B, L) representing the mask of valid
                amino acid sites.

        Returns:
            The embedded parent sequences, in a tensor of shape (B, L, E),
            where E is the dimensionality of the embedding space.
        """
        # Multiply by sqrt(d_model) to match the transformer paper.
        embedded_amino_acids = self.pie.embed_batch(amino_acid_indices) * math.sqrt(
            self.d_model
        )
        # Have to do the permutation because the positional encoding expects the
        # sequence length to be the first dimension.
        embedded_amino_acids = self.pos_encoder(
            embedded_amino_acids.permute(1, 0, 2)
        ).permute(1, 0, 2)

        # To learn about src_key_padding_mask, see https://stackoverflow.com/q/62170439
        return self.encoder(embedded_amino_acids, src_key_padding_mask=~mask)


class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
    """A one parameter selection model as a baseline."""

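A construction sketch for the new model (illustrative, not part of the commit): the ESM checkpoint name is assumed to be "esm2_t6_8M_UR50D" as in the ESMEmbedder docstring below, and layer_count/dropout are arbitrary choices for the sketch; the head count, per-head dimension, and feedforward width are all derived from the ESM model itself inside __init__.

from netam.models import TransformerBinarySelectionModelPIE

# Loads the ESM checkpoint up front so that nhead and d_model can be read from it.
model = TransformerBinarySelectionModelPIE(
    esm_model_name="esm2_t6_8M_UR50D",  # assumed checkpoint name
    layer_count=2,  # arbitrary value for the sketch
    dropout_prob=0.5,
    output_dim=1,
)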
103 changes: 103 additions & 0 deletions netam/protein_embedders.py
@@ -0,0 +1,103 @@
import torch

from esm import pretrained

from netam.common import aa_strs_from_idx_tensor


def pad_embeddings(embeddings, desired_length):
    """Pads a batch of embeddings to a specified sequence length with zeros.

    Args:
        embeddings (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).
        desired_length (int): The length to which each sequence should be padded.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, desired_length, embedding_dim).
    """
    batch_size, seq_len, embedding_dim = embeddings.size()

    if desired_length <= 0:
        raise ValueError("desired_length must be a positive integer")

    # Truncate seq_len if it exceeds desired_length
    if seq_len > desired_length:
        embeddings = embeddings[:, :desired_length, :]
        seq_len = desired_length

    device = embeddings.device
    padded_embeddings = torch.zeros(
        (batch_size, desired_length, embedding_dim), device=device
    )
    padded_embeddings[:, :seq_len, :] = embeddings
    return padded_embeddings


class ESMEmbedder:
    def __init__(self, model_name: str):
        """Initializes the ESMEmbedder object.

        Args:
            model_name (str): Name of the pretrained ESM model (e.g., "esm2_t6_8M_UR50D").
        """
        self.model, self.alphabet = pretrained.load_model_and_alphabet(model_name)
        self.batch_converter = self.alphabet.get_batch_converter()

    @property
    def device(self):
        return next(self.model.parameters()).device

    @property
    def num_heads(self):
        return self.model.layers[0].self_attn.num_heads

    @property
    def d_model(self):
        return self.model.embed_dim

    @property
    def d_model_per_head(self):
        return self.d_model // self.num_heads

    @property
    def num_layers(self):
        return self.model.num_layers

    def to(self, device):
        self.model.to(device)
        return self

    def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
        """Embeds a batch of sequences.

        Args:
            sequences (list[str]): List of amino acid sequences.

        Returns:
            torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len, embedding_dim).
        """
        named_sequences = [(f"seq_{i}", seq) for i, seq in enumerate(sequences)]
        batch_labels, batch_strs, batch_tokens = self.batch_converter(named_sequences)
        batch_tokens = batch_tokens.to(self.device)
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[self.num_layers])
            embeddings = results["representations"][self.num_layers]

        return embeddings

    def embed_batch(self, amino_acid_indices: torch.Tensor) -> torch.Tensor:
        """Embeds a batch of netam amino acid indices.

        For now, we detokenize the amino acid indices and then use embed_sequence_list.

        Args:
            amino_acid_indices (torch.Tensor): A tensor of shape (batch_size, max_aa_seq_len).

        Returns:
            torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len, embedding_dim).
        """
        sequences = aa_strs_from_idx_tensor(amino_acid_indices)
        embedding = self.embed_sequence_list(sequences)
        desired_length = amino_acid_indices.size(1)
        padded_embedding = pad_embeddings(embedding, desired_length)
        return padded_embedding
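A usage sketch for the embedder (illustrative, not part of the commit), again assuming the "esm2_t6_8M_UR50D" checkpoint named in the docstring above: embed_batch detokenizes the netam indices with aa_strs_from_idx_tensor, runs them through ESM, and pads or truncates the result back to the input width, so the output keeps the (batch_size, max_aa_seq_len) shape of the indices with embedding channels appended.

import torch

from netam.protein_embedders import ESMEmbedder

embedder = ESMEmbedder(model_name="esm2_t6_8M_UR50D")  # assumed checkpoint
indices = torch.tensor([[0, 1, 2, 3, 20], [4, 5, 19, 20, 20]])  # "ACDE" and "FGY", X-padded
embeddings = embedder.embed_batch(indices)
print(embeddings.shape)  # (2, 5, embedder.d_model)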
1 change: 1 addition & 0 deletions setup.py
@@ -14,6 +14,7 @@
    python_requires=">=3.9,<3.12",
    install_requires=[
        "biopython",
        "fair-esm",
        "natsort",
        "optuna",
        "pandas",
13 changes: 12 additions & 1 deletion tests/test_common.py
@@ -1,6 +1,11 @@
import torch

from netam.common import nt_mask_tensor_of, aa_mask_tensor_of, codon_mask_tensor_of
from netam.common import (
    nt_mask_tensor_of,
    aa_mask_tensor_of,
    codon_mask_tensor_of,
    aa_strs_from_idx_tensor,
)


def test_mask_tensor_of():
@@ -25,3 +30,9 @@ def test_codon_mask_tensor_of():
    expected_output = torch.tensor([0, 0, 1, 0, 0], dtype=torch.bool)
    output = codon_mask_tensor_of(input_seq, input_seq2, aa_length=5)
    assert torch.equal(output, expected_output)


def test_aa_strs_from_idx_tensor():
    aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20], [4, 5, 19, 20, 20]])
    aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor)
    assert aa_strings == ["ACDE", "FGY"]
