Commit 685a84e
Dropping PIE models and python version bump (#99)
Also Burrito.prefix -> Burrito.model_type
matsen authored Dec 19, 2024
1 parent 52ee618 commit 685a84e
Showing 7 changed files with 7 additions and 179 deletions.
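
The user-visible API change is the rename in the commit message; a minimal before/after sketch with a stand-in class (not the actual netam implementation, whose constructor takes more arguments):

    # Stand-in for netam.dnsm.DNSMBurrito, just to illustrate the rename.
    class DNSMBurrito:
        model_type = "dnsm"  # was: prefix = "dnsm"

    print(DNSMBurrito.model_type)  # -> "dnsm"
    # Downstream code still reading .prefix will now raise AttributeError.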
2 changes: 1 addition & 1 deletion .github/workflows/build-and-test.yml
@@ -11,7 +11,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        python-version: [3.9, "3.11"]
+        python-version: [3.9, "3.12"]

     runs-on: ${{ matrix.os }}
3 changes: 1 addition & 2 deletions netam/dasm.py
@@ -15,7 +15,6 @@


 class DASMDataset(DXSMDataset):
-    prefix = "dasm"

     def update_neutral_probs(self):
         neutral_aa_probs_l = []

@@ -123,7 +122,7 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs):


 class DASMBurrito(framework.TwoLossMixin, DXSMBurrito):
-    prefix = "dasm"
+    model_type = "dasm"

     def __init__(self, *args, loss_weights: list = [1.0, 0.01], **kwargs):
         super().__init__(*args, **kwargs)
4 changes: 2 additions & 2 deletions netam/dnsm.py
@@ -15,7 +15,6 @@


 class DNSMDataset(DXSMDataset):
-    prefix = "dnsm"

     def update_neutral_probs(self):
         """Update the neutral mutation probabilities for the dataset.

@@ -112,7 +111,8 @@ def to(self, device):


 class DNSMBurrito(DXSMBurrito):
-    prefix = "dnsm"
+
+    model_type = "dnsm"

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
3 changes: 2 additions & 1 deletion netam/dxsm.py
@@ -227,7 +227,8 @@ def update_neutral_probs(self):


 class DXSMBurrito(framework.Burrito, ABC):
-    prefix = "dxsm"
+    # Not defining model_type here; instead defining it in subclasses.
+    # This will raise an error if we aren't using a subclass.

     def _find_optimal_branch_length(
         self,
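
A minimal sketch of the pattern that comment describes, using simplified stand-in classes (the checkpoint_name helper is hypothetical, purely for illustration): the abstract base leaves model_type unset, so the attribute lookup fails unless a concrete subclass defines it.

    from abc import ABC

    class DXSMBurrito(ABC):
        # model_type deliberately left undefined here.
        def checkpoint_name(self) -> str:  # hypothetical helper
            return f"{self.model_type}-checkpoint"

    class DNSMBurrito(DXSMBurrito):
        model_type = "dnsm"

    print(DNSMBurrito().checkpoint_name())  # -> "dnsm-checkpoint"
    # DXSMBurrito().checkpoint_name() would raise AttributeError.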
68 changes: 0 additions & 68 deletions netam/models.py
@@ -10,7 +10,6 @@
 from torch import Tensor

 from netam.hit_class import apply_multihit_correction
-from netam.protein_embedders import ESMEmbedder
 from netam.common import (
     MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,

@@ -746,73 +745,6 @@ def predict(self, representation: Tensor):
         return wiggle(super().predict(representation), beta)


-class TransformerBinarySelectionModelPIE(TransformerBinarySelectionModelWiggleAct):
-    """This version of the model uses an ESM model to embed the amino acid sequences as
-    an input to the model rather than training an embedding.
-
-    PIE stands for Protein Input Embedding.
-    """
-
-    def __init__(
-        self,
-        esm_model_name: str,
-        layer_count: int,
-        dropout_prob: float = 0.5,
-        output_dim: int = 1,
-    ):
-        self.esm_model_name = esm_model_name
-        self.pie = ESMEmbedder(model_name=esm_model_name)
-        super().__init__(
-            nhead=self.pie.num_heads,
-            d_model_per_head=self.pie.d_model_per_head,
-            # The transformer paper uses 4 * d_model for the feedforward layer.
-            dim_feedforward=self.pie.d_model * 4,
-            layer_count=layer_count,
-            dropout_prob=dropout_prob,
-            output_dim=output_dim,
-        )
-
-    @property
-    def hyperparameters(self):
-        return {
-            "esm_model_name": self.esm_model_name,
-            "layer_count": self.encoder.num_layers,
-            "dropout_prob": self.pos_encoder.dropout.p,
-            "output_dim": self.linear.out_features,
-        }
-
-    def to(self, device):
-        super().to(device)
-        self.pie = self.pie.to(device)
-        return self
-
-    def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
-        """Represent an index-encoded parent sequence in the model's embedding space.
-
-        Args:
-            amino_acid_indices: A tensor of shape (B, L) containing the
-                indices of parent AA sequences.
-            mask: A tensor of shape (B, L) representing the mask of valid
-                amino acid sites.
-
-        Returns:
-            The embedded parent sequences, in a tensor of shape (B, L, E),
-            where E is the dimensionality of the embedding space.
-        """
-        # Multiply by sqrt(d_model) to match the transformer paper.
-        embedded_amino_acids = self.pie.embed_batch(amino_acid_indices) * math.sqrt(
-            self.d_model
-        )
-        # Have to do the permutation because the positional encoding expects the
-        # sequence length to be the first dimension.
-        embedded_amino_acids = self.pos_encoder(
-            embedded_amino_acids.permute(1, 0, 2)
-        ).permute(1, 0, 2)
-
-        # To learn about src_key_padding_mask, see https://stackoverflow.com/q/62170439
-        return self.encoder(embedded_amino_acids, src_key_padding_mask=~mask)
-
-
 class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
     """A one parameter selection model as a baseline."""
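
The deleted represent method spells out a reusable pattern: scale the embeddings by sqrt(d_model), apply a positional encoding that expects sequence-first tensors, then pass the encoder an inverted padding mask. A standalone PyTorch sketch of that pattern, with toy dimensions and an identity stand-in for the positional encoder (not netam code):

    import math
    import torch
    from torch import nn

    d_model, nhead = 64, 4
    embed = nn.Embedding(21, d_model)  # vocabulary size 21 is an assumption
    pos_encoder = nn.Identity()  # stand-in for a real positional encoding
    layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
    encoder = nn.TransformerEncoder(layer, num_layers=2)

    idxs = torch.randint(0, 21, (2, 10))        # (B, L) index-encoded sequences
    mask = torch.ones(2, 10, dtype=torch.bool)  # True marks valid sites

    x = embed(idxs) * math.sqrt(d_model)  # scale to match the transformer paper
    # Positional encodings conventionally take (L, B, E), hence the permutes.
    x = pos_encoder(x.permute(1, 0, 2)).permute(1, 0, 2)
    # src_key_padding_mask is True where positions should be ignored, so invert.
    out = encoder(x, src_key_padding_mask=~mask)
    print(out.shape)  # torch.Size([2, 10, 64])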
103 changes: 0 additions & 103 deletions netam/protein_embedders.py

This file was deleted.

3 changes: 1 addition & 2 deletions setup.py
@@ -11,10 +11,9 @@
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
     packages=find_packages(),
-    python_requires=">=3.9,<3.12",
+    python_requires=">=3.9,<3.13",
     install_requires=[
         "biopython",
-        "fair-esm",
         "natsort",
         "optuna",
         "pandas",
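
As a sanity check on the new specifier, ">=3.9,<3.13" is a half-open interval that now admits 3.12 and still excludes 3.13; a small convenience sketch of the comparison it encodes (not part of the repo):

    import sys

    # ">=3.9,<3.13" expressed over (major, minor) tuples.
    supported = (3, 9) <= sys.version_info[:2] < (3, 13)
    print(f"Python {sys.version_info[0]}.{sys.version_info[1]} supported: {supported}")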
