From 9c7848208204deab39e5bf210df6dd0dfecfef33 Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Mon, 16 Dec 2024 03:11:08 -0800
Subject: [PATCH] Preliminary implementation of PIE model (#96)

Preliminary implementation of a DASM model that uses a protein embedding
as input. It hilariously back-translates our AA indexing into a string
and passes it to ESM for pre-embedding. Slow!
---
 netam/common.py            |  20 +++++++
 netam/models.py            |  68 ++++++++++++++++++++++++
 netam/protein_embedders.py | 103 +++++++++++++++++++++++++++++++++++++
 setup.py                   |   1 +
 tests/test_common.py       |  13 ++++-
 5 files changed, 204 insertions(+), 1 deletion(-)
 create mode 100644 netam/protein_embedders.py

diff --git a/netam/common.py b/netam/common.py
index 00979d47..7a1970df 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -109,6 +109,26 @@ def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
     return torch.tensor(mask, dtype=torch.bool)
 
 
+def aa_strs_from_idx_tensor(idx_tensor):
+    """Convert a tensor of amino acid indices back to a list of amino acid strings.
+
+    Args:
+        idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
+            indices into AA_STR_SORTED_AMBIG.
+
+    Returns:
+        List[str]: A list of amino acid strings with trailing 'X's removed.
+    """
+    idx_tensor = idx_tensor.cpu()
+
+    aa_str_list = []
+    for row in idx_tensor:
+        aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist())
+        aa_str_list.append(aa_str.rstrip("X"))
+
+    return aa_str_list
+
+
 def assert_pcp_valid(parent, child, aa_mask=None):
     """Check that the parent-child pairs are valid.
 
diff --git a/netam/models.py b/netam/models.py
index 9a67e6c8..957b52ad 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -10,6 +10,7 @@
 from torch import Tensor
 
 from netam.hit_class import apply_multihit_correction
+from netam.protein_embedders import ESMEmbedder
 from netam.common import (
     MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
@@ -726,6 +727,73 @@ def predict(self, representation: Tensor):
         return wiggle(super().predict(representation), beta)
 
 
+class TransformerBinarySelectionModelPIE(TransformerBinarySelectionModelWiggleAct):
+    """This version of the model uses an ESM model to embed the amino acid sequences as
+    an input to the model rather than training an embedding.
+
+    PIE stands for Protein Input Embedding.
+    """
+
+    def __init__(
+        self,
+        esm_model_name: str,
+        layer_count: int,
+        dropout_prob: float = 0.5,
+        output_dim: int = 1,
+    ):
+        self.esm_model_name = esm_model_name
+        self.pie = ESMEmbedder(model_name=esm_model_name)
+        super().__init__(
+            nhead=self.pie.num_heads,
+            d_model_per_head=self.pie.d_model_per_head,
+            # The transformer paper uses 4 * d_model for the feedforward layer.
+            dim_feedforward=self.pie.d_model * 4,
+            layer_count=layer_count,
+            dropout_prob=dropout_prob,
+            output_dim=output_dim,
+        )
+
+    @property
+    def hyperparameters(self):
+        return {
+            "esm_model_name": self.esm_model_name,
+            "layer_count": self.encoder.num_layers,
+            "dropout_prob": self.pos_encoder.dropout.p,
+            "output_dim": self.linear.out_features,
+        }
+
+    def to(self, device):
+        super().to(device)
+        self.pie = self.pie.to(device)
+        return self
+
+    def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
+        """Represent an index-encoded parent sequence in the model's embedding space.
+
+        Args:
+            amino_acid_indices: A tensor of shape (B, L) containing the
+                indices of parent AA sequences.
+            mask: A tensor of shape (B, L) representing the mask of valid
+                amino acid sites.
+
+        Returns:
+            The embedded parent sequences, in a tensor of shape (B, L, E),
+            where E is the dimensionality of the embedding space.
+        """
+        # Multiply by sqrt(d_model) to match the transformer paper.
+        embedded_amino_acids = self.pie.embed_batch(amino_acid_indices) * math.sqrt(
+            self.d_model
+        )
+        # Have to do the permutation because the positional encoding expects the
+        # sequence length to be the first dimension.
+        embedded_amino_acids = self.pos_encoder(
+            embedded_amino_acids.permute(1, 0, 2)
+        ).permute(1, 0, 2)
+
+        # To learn about src_key_padding_mask, see https://stackoverflow.com/q/62170439
+        return self.encoder(embedded_amino_acids, src_key_padding_mask=~mask)
+
+
 class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
     """A one parameter selection model as a baseline."""
 
diff --git a/netam/protein_embedders.py b/netam/protein_embedders.py
new file mode 100644
index 00000000..7e02149c
--- /dev/null
+++ b/netam/protein_embedders.py
@@ -0,0 +1,103 @@
+import torch
+
+from esm import pretrained
+
+from netam.common import aa_strs_from_idx_tensor
+
+
+def pad_embeddings(embeddings, desired_length):
+    """Pads a batch of embeddings to a specified sequence length with zeros.
+
+    Args:
+        embeddings (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).
+        desired_length (int): The length to which each sequence should be padded.
+
+    Returns:
+        torch.Tensor: A tensor of shape (batch_size, desired_length, embedding_dim).
+    """
+    batch_size, seq_len, embedding_dim = embeddings.size()
+
+    if desired_length <= 0:
+        raise ValueError("desired_length must be a positive integer")
+
+    # Truncate seq_len if it exceeds desired_length
+    if seq_len > desired_length:
+        embeddings = embeddings[:, :desired_length, :]
+        seq_len = desired_length
+
+    device = embeddings.device
+    padded_embeddings = torch.zeros(
+        (batch_size, desired_length, embedding_dim), device=device
+    )
+    padded_embeddings[:, :seq_len, :] = embeddings
+    return padded_embeddings
+
+
+class ESMEmbedder:
+    def __init__(self, model_name: str):
+        """Initializes the ESMEmbedder object.
+
+        Args:
+            model_name (str): Name of the pretrained ESM model (e.g., "esm2_t6_8M_UR50D").
+        """
+        self.model, self.alphabet = pretrained.load_model_and_alphabet(model_name)
+        self.batch_converter = self.alphabet.get_batch_converter()
+
+    @property
+    def device(self):
+        return next(self.model.parameters()).device
+
+    @property
+    def num_heads(self):
+        return self.model.layers[0].self_attn.num_heads
+
+    @property
+    def d_model(self):
+        return self.model.embed_dim
+
+    @property
+    def d_model_per_head(self):
+        return self.d_model // self.num_heads
+
+    @property
+    def num_layers(self):
+        return self.model.num_layers
+
+    def to(self, device):
+        self.model.to(device)
+        return self
+
+    def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
+        """Embeds a batch of sequences.
+
+        Args:
+            sequences (list[str]): List of amino acid sequences.
+
+        Returns:
+            torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len, embedding_dim).
+ """ + named_sequences = [(f"seq_{i}", seq) for i, seq in enumerate(sequences)] + batch_labels, batch_strs, batch_tokens = self.batch_converter(named_sequences) + batch_tokens = batch_tokens.to(self.device) + with torch.no_grad(): + results = self.model(batch_tokens, repr_layers=[self.num_layers]) + embeddings = results["representations"][self.num_layers] + + return embeddings + + def embed_batch(self, amino_acid_indices: torch.Tensor) -> torch.Tensor: + """Embeds a batch of netam amino acid indices. + + For now, we detokenize the amino acid indices and then use embed_sequence_list. + + Args: + amino_acid_indices (torch.Tensor): A tensor of shape (batch_size, max_aa_seq_len). + + Returns: + torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len, embedding_dim). + """ + sequences = aa_strs_from_idx_tensor(amino_acid_indices) + embedding = self.embed_sequence_list(sequences) + desired_length = amino_acid_indices.size(1) + padded_embedding = pad_embeddings(embedding, desired_length) + return padded_embedding diff --git a/setup.py b/setup.py index 8d0f317b..8fdb6018 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ python_requires=">=3.9,<3.12", install_requires=[ "biopython", + "fair-esm", "natsort", "optuna", "pandas", diff --git a/tests/test_common.py b/tests/test_common.py index e7f2f67b..eb22c07f 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,6 +1,11 @@ import torch -from netam.common import nt_mask_tensor_of, aa_mask_tensor_of, codon_mask_tensor_of +from netam.common import ( + nt_mask_tensor_of, + aa_mask_tensor_of, + codon_mask_tensor_of, + aa_strs_from_idx_tensor, +) def test_mask_tensor_of(): @@ -25,3 +30,9 @@ def test_codon_mask_tensor_of(): expected_output = torch.tensor([0, 0, 1, 0, 0], dtype=torch.bool) output = codon_mask_tensor_of(input_seq, input_seq2, aa_length=5) assert torch.equal(output, expected_output) + + +def test_aa_strs_from_idx_tensor(): + aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20], [4, 5, 19, 20, 20]]) + aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor) + assert aa_strings == ["ACDE", "FGY"]