
make format
matsen committed Dec 13, 2024
1 parent 2349c7a commit 442dbae
Showing 3 changed files with 19 additions and 24 deletions.
9 changes: 4 additions & 5 deletions netam/common.py
@@ -110,11 +110,10 @@ def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
 
 
 def aa_strs_from_idx_tensor(idx_tensor):
-    """
-    Convert a tensor of amino acid indices back to a list of amino acid strings.
+    """Convert a tensor of amino acid indices back to a list of amino acid strings.
 
     Args:
-        idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
+        idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
             indices into AA_STR_SORTED_AMBIG.
 
     Returns:
@@ -126,7 +125,7 @@ def aa_strs_from_idx_tensor(idx_tensor):
     for row in idx_tensor:
         aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist())
         aa_str_list.append(aa_str.rstrip("X"))
 
     return aa_str_list
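
For context on the function touched above, here is a hypothetical round-trip sketch. The alphabet literal is an assumption (AA_STR_SORTED_AMBIG appears to be the sorted 20-letter amino-acid alphabet with "X" appended for padding/ambiguity, judging from the rstrip("X") call); the function body is copied from the diff.

```python
import torch

# Assumed alphabet: sorted amino acids plus "X" (padding/ambiguous) at index 20.
AA_STR_SORTED_AMBIG = "ACDEFGHIKLMNPQRSTVWYX"

def aa_strs_from_idx_tensor(idx_tensor):
    """Convert a tensor of amino acid indices back to a list of amino acid strings."""
    aa_str_list = []
    for row in idx_tensor:
        aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist())
        aa_str_list.append(aa_str.rstrip("X"))
    return aa_str_list

# A padded batch of two sequences; trailing "X" padding is stripped on decode.
idx = torch.tensor([[0, 1, 2, 20, 20], [3, 4, 5, 6, 20]])
print(aa_strs_from_idx_tensor(idx))  # ['ACD', 'EFGH']
```
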


22 changes: 11 additions & 11 deletions netam/models.py
@@ -730,23 +730,23 @@ def predict(self, representation: Tensor):
 class TransformerBinarySelectionModelPIE(TransformerBinarySelectionModelWiggleAct):
     """Here the beta parameter is fixed at 0.3."""
 
-    def __init__(self,
-                 esm_model_name: str,
-                 layer_count: int,
-                 dropout_prob: float = 0.5,
-                 output_dim: int = 1,
-    ):
+    def __init__(
+        self,
+        esm_model_name: str,
+        layer_count: int,
+        dropout_prob: float = 0.5,
+        output_dim: int = 1,
+    ):
         self.pie = ESMEmbedder(model_name=esm_model_name)
         super().__init__(
             nhead=self.pie.num_heads,
             d_model_per_head=self.pie.d_model_per_head,
             # TODO this is hard coded as per Antoine.
-            dim_feedforward=self.pie.d_model*4,
+            dim_feedforward=self.pie.d_model * 4,
             layer_count=layer_count,
             dropout_prob=dropout_prob,
             output_dim=output_dim,
         )
 
-
     def to(self, device):
         super().to(device)
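
The constructor above sizes the wrapped transformer from the ESM embedder it holds. A small arithmetic sketch of that derivation, with assumed figures for esm2_t6_8M_UR50D (a 320-dimensional embedding with 20 attention heads):

```python
# Assumed dimensions for esm2_t6_8M_UR50D; substitute the real self.pie values.
esm_d_model = 320
esm_num_heads = 20

d_model_per_head = esm_d_model // esm_num_heads  # 16, passed as d_model_per_head
dim_feedforward = esm_d_model * 4                # 1280, the 4x-d_model convention
                                                 # from the original transformer paper
print(d_model_per_head, dim_feedforward)
```
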
@@ -767,9 +767,9 @@ def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
         where E is the dimensionality of the embedding space.
         """
         # Multiply by sqrt(d_model) to match the transformer paper.
-        embedded_amino_acids = self.pie.embed_batch(
-            amino_acid_indices
-        ) * math.sqrt(self.d_model)
+        embedded_amino_acids = self.pie.embed_batch(amino_acid_indices) * math.sqrt(
+            self.d_model
+        )
         # Have to do the permutation because the positional encoding expects the
         # sequence length to be the first dimension.
         embedded_amino_acids = self.pos_encoder(
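
A minimal, self-contained sketch of the scaling step reformatted above; the shapes and d_model value are made up for illustration:

```python
import math
import torch

d_model = 320                               # assumed embedding dimension
embedded = torch.randn(8, 128, d_model)     # (batch, seq_len, d_model)

# Scale by sqrt(d_model), as in "Attention Is All You Need".
scaled = embedded * math.sqrt(d_model)

# Positional encoders written for (seq_len, batch, d_model) input need a permute first.
scaled_seq_first = scaled.permute(1, 0, 2)
print(scaled_seq_first.shape)               # torch.Size([128, 8, 320])
```
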
12 changes: 4 additions & 8 deletions netam/protein_embedders.py
@@ -6,8 +6,7 @@
 
 
 def pad_embeddings(embeddings, desired_length):
-    """
-    Pads a batch of embeddings to a specified sequence length with zeros.
+    """Pads a batch of embeddings to a specified sequence length with zeros.
 
     Args:
         embeddings (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).
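
A minimal sketch of the zero-padding described in the docstring above, assuming the same (batch_size, seq_len, embedding_dim) layout; this is an illustration, not the repository's implementation:

```python
import torch
import torch.nn.functional as F

def pad_embeddings_sketch(embeddings: torch.Tensor, desired_length: int) -> torch.Tensor:
    """Zero-pad (or truncate) embeddings along the sequence dimension."""
    seq_len = embeddings.shape[1]
    if seq_len >= desired_length:
        return embeddings[:, :desired_length, :]
    # F.pad pads (left, right) pairs starting from the last dimension,
    # so (0, 0) leaves embedding_dim alone and (0, n) extends seq_len.
    return F.pad(embeddings, (0, 0, 0, desired_length - seq_len))

print(pad_embeddings_sketch(torch.randn(2, 5, 320), 10).shape)  # torch.Size([2, 10, 320])
```
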
@@ -36,8 +35,7 @@ def pad_embeddings(embeddings, desired_length):
 
 class ESMEmbedder:
     def __init__(self, model_name: str, device: str = "cpu"):
-        """
-        Initializes the ESMEmbedder object.
+        """Initializes the ESMEmbedder object.
 
         Args:
             model_name (str): Name of the pretrained ESM model (e.g., "esm2_t6_8M_UR50D").
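
For readers unfamiliar with the embedder being initialized above, a hedged sketch of loading the example model via the fair-esm package (assuming that is the backend; the attributes shown are fair-esm's, not necessarily ESMEmbedder's):

```python
import esm  # fair-esm package

model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
model.eval()

# Wrapper properties such as num_layers or per-head dimensions would be derived
# from figures like these.
print(model.num_layers, model.embed_dim)
```
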
Expand Down Expand Up @@ -65,8 +63,7 @@ def num_layers(self):
return self.model.num_layers

def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
"""
Embeds a batch of sequences.
"""Embeds a batch of sequences.
Args:
sequences (list[str]): List of amino acid sequences.
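
A self-contained sketch of what embedding a list of sequences with fair-esm roughly looks like; this is an assumption about the underlying calls, not the repository's embed_sequence_list:

```python
import torch
import esm

model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

sequences = ["ACDEFGHIK", "MNPQRSTVWY"]
data = [(f"seq{i}", s) for i, s in enumerate(sequences)]
_, _, tokens = batch_converter(data)

with torch.no_grad():
    out = model(tokens, repr_layers=[model.num_layers])
reps = out["representations"][model.num_layers]  # (batch, tokens, embed_dim)

# Drop the BOS/EOS tokens that the ESM batch converter adds around each sequence.
embeddings = reps[:, 1:-1, :]
print(embeddings.shape)
```
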
@@ -84,8 +81,7 @@ def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
         return embeddings
 
     def embed_batch(self, amino_acid_indices: torch.Tensor) -> torch.Tensor:
-        """
-        Embeds a batch of netam amino acid indices.
+        """Embeds a batch of netam amino acid indices.
 
         For now, we detokenize the amino acid indices and then use embed_sequence_list.
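
The detokenize-then-reuse-embed_sequence_list approach mentioned above, as a rough sketch; the alphabet and padding convention are assumptions carried over from the common.py example:

```python
import torch

AA_STR_SORTED_AMBIG = "ACDEFGHIKLMNPQRSTVWYX"  # assumed alphabet; "X" = padding/ambiguous

def embed_batch_sketch(embedder, amino_acid_indices: torch.Tensor) -> torch.Tensor:
    """Detokenize netam indices to strings, then embed via the string-based path."""
    sequences = [
        "".join(AA_STR_SORTED_AMBIG[i] for i in row.tolist()).rstrip("X")
        for row in amino_acid_indices
    ]
    return embedder.embed_sequence_list(sequences)
```
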
