From 442dbaeba791c77d7a7e6087c6d62b1f4e23bca2 Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Fri, 13 Dec 2024 05:17:30 -0800
Subject: [PATCH] make format

---
 netam/common.py            |  9 ++++-----
 netam/models.py            | 22 +++++++++++-----------
 netam/protein_embedders.py | 12 ++++--------
 3 files changed, 19 insertions(+), 24 deletions(-)

diff --git a/netam/common.py b/netam/common.py
index 4aee2328..7a1970df 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -110,11 +110,10 @@ def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
 
 
 def aa_strs_from_idx_tensor(idx_tensor):
-    """
-    Convert a tensor of amino acid indices back to a list of amino acid strings.
-    
+    """Convert a tensor of amino acid indices back to a list of amino acid strings.
+
     Args:
-        idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing 
+        idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
             indices into AA_STR_SORTED_AMBIG.
 
     Returns:
@@ -126,7 +125,7 @@
     for row in idx_tensor:
         aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist())
         aa_str_list.append(aa_str.rstrip("X"))
-    
+
     return aa_str_list
 
 
diff --git a/netam/models.py b/netam/models.py
index 4e30f963..5a62e3db 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -730,23 +730,23 @@ def predict(self, representation: Tensor):
 class TransformerBinarySelectionModelPIE(TransformerBinarySelectionModelWiggleAct):
     """Here the beta parameter is fixed at 0.3."""
 
-    def __init__(self,
-                 esm_model_name: str,
-                 layer_count: int,
-                 dropout_prob: float = 0.5,
-                 output_dim: int = 1,
-                 ):
+    def __init__(
+        self,
+        esm_model_name: str,
+        layer_count: int,
+        dropout_prob: float = 0.5,
+        output_dim: int = 1,
+    ):
         self.pie = ESMEmbedder(model_name=esm_model_name)
         super().__init__(
             nhead=self.pie.num_heads,
             d_model_per_head=self.pie.d_model_per_head,
             # TODO this is hard coded as per Antoine.
-            dim_feedforward=self.pie.d_model*4,
+            dim_feedforward=self.pie.d_model * 4,
             layer_count=layer_count,
             dropout_prob=dropout_prob,
             output_dim=output_dim,
         )
-
     def to(self, device):
         super().to(device)
 
@@ -767,9 +767,9 @@ def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
             where E is the dimensionality of the embedding space.
         """
         # Multiply by sqrt(d_model) to match the transformer paper.
-        embedded_amino_acids = self.pie.embed_batch(
-            amino_acid_indices
-        ) * math.sqrt(self.d_model)
+        embedded_amino_acids = self.pie.embed_batch(amino_acid_indices) * math.sqrt(
+            self.d_model
+        )
         # Have to do the permutation because the positional encoding expects the
         # sequence length to be the first dimension.
         embedded_amino_acids = self.pos_encoder(
diff --git a/netam/protein_embedders.py b/netam/protein_embedders.py
index 020fa710..455c6158 100644
--- a/netam/protein_embedders.py
+++ b/netam/protein_embedders.py
@@ -6,8 +6,7 @@
 
 
 def pad_embeddings(embeddings, desired_length):
-    """
-    Pads a batch of embeddings to a specified sequence length with zeros.
+    """Pads a batch of embeddings to a specified sequence length with zeros.
 
     Args:
         embeddings (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).
@@ -36,8 +35,7 @@
 
 class ESMEmbedder:
     def __init__(self, model_name: str, device: str = "cpu"):
-        """
-        Initializes the ESMEmbedder object.
+        """Initializes the ESMEmbedder object.
 
         Args:
             model_name (str): Name of the pretrained ESM model (e.g., "esm2_t6_8M_UR50D").
@@ -65,8 +63,7 @@ def num_layers(self):
         return self.model.num_layers
 
     def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
-        """
-        Embeds a batch of sequences.
+        """Embeds a batch of sequences.
 
         Args:
             sequences (list[str]): List of amino acid sequences.
@@ -84,8 +81,7 @@ def embed_sequence_list(self, sequences: list[str]) -> torch.Tensor:
         return embeddings
 
     def embed_batch(self, amino_acid_indices: torch.Tensor) -> torch.Tensor:
-        """
-        Embeds a batch of netam amino acid indices.
+        """Embeds a batch of netam amino acid indices.
 
         For now, we detokenize the amino acid indices and then use
         embed_sequence_list.
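
Usage sketch (not part of the patch): a minimal example of the ESMEmbedder API touched above, assuming the fair-esm weights for "esm2_t6_8M_UR50D" can be downloaded; the sequences are illustrative only.

    from netam.protein_embedders import ESMEmbedder

    # Load the smallest public ESM-2 checkpoint on CPU, as in the __init__ signature above.
    embedder = ESMEmbedder(model_name="esm2_t6_8M_UR50D", device="cpu")

    # Embed two toy sequences; per the docstrings above, the result is expected to be
    # a padded tensor of shape (batch_size, max_seq_len, embedding_dim).
    embeddings = embedder.embed_sequence_list(["ACDEFGHIK", "ACD"])
    print(embeddings.shape)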