Skip to content

Commit

Permalink
device fix
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Dec 13, 2024
1 parent 442dbae commit fc21c35
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions netam/protein_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,20 @@ def pad_embeddings(embeddings, desired_length):


class ESMEmbedder:
def __init__(self, model_name: str, device: str = "cpu"):
def __init__(self, model_name: str):
"""Initializes the ESMEmbedder object.
Args:
model_name (str): Name of the pretrained ESM model (e.g., "esm2_t6_8M_UR50D").
device (str): Device to run the model on.
"""
self.device = device
self.model, self.alphabet = pretrained.load_model_and_alphabet(model_name)
self.model = self.model.to(device)
self.batch_converter = self.alphabet.get_batch_converter()

@property
def device(self):
return next(self.model.parameters()).device

@property
def num_heads(self):
return self.model.layers[0].self_attn.num_heads
Expand Down

0 comments on commit fc21c35

Please sign in to comment.