PIE tokens
matsen committed Dec 16, 2024
1 parent 9c78482 commit 55fe2a8
Showing 3 changed files with 124 additions and 11 deletions.
69 changes: 69 additions & 0 deletions netam/dasmpie.py
@@ -0,0 +1,69 @@
"""Here we define a model that outputs a vector of 20 amino acid preferences, using a protein model embedding as input."""

import torch
import torch.nn.functional as F

import esm

from netam.dasm import DASMDataset, DASMBurrito
from netam.sequences import (
translate_sequences,
)


class DASMPIEDataset(DASMDataset):
# TODO does this do anything?
prefix = "dasm"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Note that all ESM2 models use the ESM-1b alphabet
# https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L175
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
batch_converter = alphabet.get_batch_converter()
aa_parents = translate_sequences(self.nt_parents)
_, _, self.pie_tokens = batch_converter(
[(f"seq_{i}", seq) for i, seq in enumerate(aa_parents)]
)

def __getitem__(self, idx):
return {
"aa_parents_idxs": self.aa_parents_idxss[idx],
"aa_children_idxs": self.aa_children_idxss[idx],
"pie_tokens": self.pie_tokens[idx],
"subs_indicator": self.aa_subs_indicators[idx],
"mask": self.masks[idx],
"log_neutral_aa_probs": self.log_neutral_aa_probss[idx],
"nt_rates": self.nt_ratess[idx],
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
self.aa_children_idxss = self.aa_children_idxss.to(device)
self.pie_tokens = self.pie_tokens.to(device)
self.aa_subs_indicators = self.aa_subs_indicators.to(device)
self.masks = self.masks.to(device)
self.log_neutral_aa_probss = self.log_neutral_aa_probss.to(device)
self.nt_ratess = self.nt_ratess.to(device)
self.nt_cspss = self.nt_cspss.to(device)
if self.multihit_model is not None:
self.multihit_model = self.multihit_model.to(device)


class DASMPIEBurrito(DASMBurrito):
# TODO does this do anything?
prefix = "dasmpie"

def prediction_pair_of_batch(self, batch):
"""Get log neutral AA probabilities and log selection factors for a batch of
data."""
pie_tokens = batch["pie_tokens"].to(self.device)
mask = batch["mask"].to(self.device)
log_neutral_aa_probs = batch["log_neutral_aa_probs"].to(self.device)
if not torch.isfinite(log_neutral_aa_probs[mask]).all():
raise ValueError(
f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
)
log_selection_factors = self.model(pie_tokens, mask)
return log_neutral_aa_probs, log_selection_factors
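
For orientation, here is a minimal sketch (not part of the commit) of what the ESM batch converter used in DASMPIEDataset.__init__ returns. Only the esm calls shown above are assumed; the parent sequences are hypothetical.

import esm

# Mirror DASMPIEDataset.__init__: all ESM2 models use the ESM-1b alphabet.
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
batch_converter = alphabet.get_batch_converter()

aa_parents = ["MKTAYIAKQR", "MKTAYIV"]  # hypothetical translated parent sequences
labels, strs, pie_tokens = batch_converter(
    [(f"seq_{i}", seq) for i, seq in enumerate(aa_parents)]
)

# For the ESM-1b alphabet the token tensor has shape (batch_size, max_len + 2):
# BOS/EOS tokens are added and shorter sequences are padded, so downstream code
# masks non-sequence positions before using the embeddings.
print(pie_tokens.shape)  # torch.Size([2, 12])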
31 changes: 27 additions & 4 deletions netam/models.py
@@ -767,12 +767,12 @@ def to(self, device):
         self.pie = self.pie.to(device)
         return self

-    def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
+    def represent(self, pie_tokens: Tensor, mask: Tensor) -> Tensor:
         """Represent an index-encoded parent sequence in the model's embedding space.

         Args:
-            amino_acid_indices: A tensor of shape (B, L) containing the
-                indices of parent AA sequences.
+            pie_tokens: A tensor of shape (B, L) containing the
+                tokens of parent AA sequences.
             mask: A tensor of shape (B, L) representing the mask of valid
                 amino acid sites.

@@ -781,7 +781,7 @@ def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
             where E is the dimensionality of the embedding space.
         """
         # Multiply by sqrt(d_model) to match the transformer paper.
-        embedded_amino_acids = self.pie.embed_batch(amino_acid_indices) * math.sqrt(
+        embedded_amino_acids = self.pie.embed_tokens(pie_tokens) * math.sqrt(
             self.d_model
         )
         # Have to do the permutation because the positional encoding expects the
@@ -793,6 +793,29 @@ def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
         # To learn about src_key_padding_mask, see https://stackoverflow.com/q/62170439
         return self.encoder(embedded_amino_acids, src_key_padding_mask=~mask)

+    def selection_factors_of_aa_str(self, aa_str):
+        """Do the forward method then exponentiation without gradients from an amino
+        acid string.
+
+        Args:
+            aa_str: A string of amino acids.
+
+        Returns:
+            A numpy array of the same length as the input string representing
+            the level of selection for each amino acid at each site.
+        """
+        tokens = self.pie.tokenize_sequences([aa_str])
+        tokens = tokens.to(self.device)
+        mask = aa_mask_tensor_of(aa_str)
+        mask = mask.to(self.device)
+
+        with torch.no_grad():
+            # TODO note that I removed unsqueeze(0) from tokens
+            model_out = self(tokens, mask.unsqueeze(0)).squeeze(0)
+            final_out = torch.exp(model_out)
+
+        return final_out[: len(aa_str)]
+

 class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
     """A one parameter selection model as a baseline."""
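A minimal usage sketch (not part of the commit) of the new selection_factors_of_aa_str method; model is assumed to be a trained instance of the PIE selection model edited above, and the amino acid string is hypothetical.

aa_str = "EVQLVESGG"  # hypothetical parent amino acid sequence
selection_factors = model.selection_factors_of_aa_str(aa_str)

# The method tokenizes via self.pie, runs the forward pass without gradients,
# exponentiates the log selection factors, and slices back to the input length.
assert len(selection_factors) == len(aa_str)
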
35 changes: 28 additions & 7 deletions netam/protein_embedders.py
@@ -67,24 +67,45 @@ def to(self, device):
         self.model.to(device)
         return self

-    def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
-        """Embeds a batch of sequences.
+    def tokenize_sequences(self, sequences: list[str]) -> torch.Tensor:
+        """Tokenizes a batch of sequences.

         Args:
             sequences (list[str]): List of amino acid sequences.

         Returns:
-            torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len, embedding_dim).
+            torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len).
         """
         named_sequences = [(f"seq_{i}", seq) for i, seq in enumerate(sequences)]
-        batch_labels, batch_strs, batch_tokens = self.batch_converter(named_sequences)
-        batch_tokens = batch_tokens.to(self.device)
+        _, _, batch_tokens = self.batch_converter(named_sequences)
+        return batch_tokens
+
+    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+        """Embeds a batch of tokens.
+
+        Args:
+            tokens (torch.Tensor): A tensor of shape (batch_size, seq_len).
+
+        Returns:
+            torch.Tensor: A tensor of shape (batch_size, seq_len, embedding_dim).
+        """
+        tokens = tokens.to(self.device)
         with torch.no_grad():
-            results = self.model(batch_tokens, repr_layers=[self.num_layers])
+            results = self.model(tokens, repr_layers=[self.num_layers])
             embeddings = results["representations"][self.num_layers]

         return embeddings

+    def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
+        """Embeds a batch of sequences.
+
+        Args:
+            sequences (list[str]): List of amino acid sequences.
+
+        Returns:
+            torch.Tensor: A tensor of shape (batch_size, max_aa_seq_len, embedding_dim).
+        """
+        return self.embed_tokens(self.tokenize_sequences(sequences))
+
     def embed_batch(self, amino_acid_indices: torch.Tensor) -> torch.Tensor:
         """Embeds a batch of netam amino acid indices.
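A minimal sketch (not part of the commit) of the resulting two-step workflow in protein_embedders.py: tokenization is now separated from embedding, so a dataset can precompute tokens cheaply and the ESM forward pass can run later on the model's device. Here embedder is assumed to be an instance of the embedder class edited above, and the sequences are hypothetical.

sequences = ["MKTAYIAKQR", "MKTAYIV"]  # hypothetical amino acid sequences

tokens = embedder.tokenize_sequences(sequences)  # (batch_size, max_aa_seq_len) token ids
embeddings = embedder.embed_tokens(tokens)  # (batch_size, seq_len, embedding_dim)

# embed_sequence_list is now just the composition of the two steps above.
assert embeddings.shape == embedder.embed_sequence_list(sequences).shape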
