Split forward functions in DXSM models (#74)
* a start at splitting forward function

* edit docstrings

* format

* fix docstring

* adjust lint

* remove TODO
willdumm authored Oct 28, 2024
1 parent 3208899 commit dda7987
Showing 2 changed files with 45 additions and 24 deletions.
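In short, the netam/models.py change below splits the transformer selection model's single forward method into two steps: represent (index-encoded parent sequence to per-site embeddings) and predict (embeddings to per-site log selection factors), with forward now simply composing the two. The following is a minimal, self-contained sketch of that structure; TinySelectionModel, its layer sizes, and its internals are illustrative stand-ins, not the actual netam classes.

```python
import torch
from torch import Tensor, nn


class TinySelectionModel(nn.Module):
    """Toy stand-in for the refactored selection model (not a netam class)."""

    def __init__(self, n_tokens: int = 20, d_model: int = 8):
        super().__init__()
        self.amino_acid_embedding = nn.Embedding(n_tokens, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=2)
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.linear = nn.Linear(d_model, 1)

    def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
        # (B, L) indices -> (L, B, E) for the non-batch-first encoder.
        embedded = self.amino_acid_embedding(amino_acid_indices).permute(1, 0, 2)
        out = self.encoder(embedded, src_key_padding_mask=~mask)
        return out.permute(1, 0, 2)  # back to (B, L, E)

    def predict(self, representation: Tensor) -> Tensor:
        # (B, L, E) -> (B, L) log selection factors; no final activation here.
        return self.linear(representation).squeeze(-1)

    def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
        # forward is now just the composition of the two halves.
        return self.predict(self.represent(amino_acid_indices, mask))


if __name__ == "__main__":
    idx = torch.randint(0, 20, (2, 5))
    mask = torch.ones(2, 5, dtype=torch.bool)
    print(TinySelectionModel()(idx, mask).shape)  # torch.Size([2, 5])
```

As the diff shows, this split lets the wiggle-activation subclasses override only predict while inheriting forward unchanged.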
5 changes: 1 addition & 4 deletions Makefile
@@ -16,10 +16,7 @@ checkformat:
 	black --check netam tests
 
 lint:
-	# stop the build if there are Python syntax errors or undefined names
-	flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=_ignore
-	# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-	flake8 . --count --max-complexity=30 --max-line-length=127 --statistics --exclude=_ignore
+	flake8 . --max-complexity=30 --ignore=E731,W503,E402,F541,E501,E203,E266 --statistics --exclude=_ignore
 
 docs:
 	make -C docs html
64 changes: 44 additions & 20 deletions netam/models.py
@@ -520,7 +520,7 @@ def site_rates(self):
 class AbstractBinarySelectionModel(ABC, nn.Module):
     """A transformer-based model for binary selection.
-    This is a model that takes in a batch of one-hot encoded sequences and
+    This is a model that takes in a batch of index-encoded sequences and
     outputs a vector that represents the log level of selection for each amino
     acid site, which after exponentiating is a multiplier on the probability of
     an amino-acid substitution at that site.
@@ -545,7 +545,7 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
         Returns:
             A numpy array of the same length as the input string representing
-            the level of selection for wildtype at each amino acid site.
+            the level of selection for each amino acid at each site.
         """
 
         model_device = next(self.parameters()).device
@@ -607,19 +607,18 @@ def init_weights(self) -> None:
         self.linear.bias.data.zero_()
         self.linear.weight.data.uniform_(-initrange, initrange)
 
-    def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
-        """Build a binary log selection matrix from a one-hot encoded parent sequence.
-        Because we're predicting log of the selection factor, we don't use an
-        activation function after the transformer.
+    def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
+        """Represent an index-encoded parent sequence in the model's embedding space.
         Args:
-            amino_acid_indices: A tensor of shape (B, L) containing the indices of parent AA sequences.
-            mask: A tensor of shape (B, L) representing the mask of valid amino acid sites.
+            amino_acid_indices: A tensor of shape (B, L) containing the
+                indices of parent AA sequences.
+            mask: A tensor of shape (B, L) representing the mask of valid
+                amino acid sites.
         Returns:
-            A tensor of shape (B, L, 1) representing the log level of selection
-            for each amino acid site.
+            The embedded parent sequences, in a tensor of shape (B, L, E),
+            where E is the dimensionality of the embedding space.
         """
         # Multiply by sqrt(d_model) to match the transformer paper.
         embedded_amino_acids = self.amino_acid_embedding(
@@ -632,9 +631,35 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
         ).permute(1, 0, 2)
 
         # To learn about src_key_padding_mask, see https://stackoverflow.com/q/62170439
-        out = self.encoder(embedded_amino_acids, src_key_padding_mask=~mask)
-        out = self.linear(out)
-        return out.squeeze(-1)
+        return self.encoder(embedded_amino_acids, src_key_padding_mask=~mask)
+
+    def predict(self, representation: Tensor) -> Tensor:
+        """Predict selection from the model embedding of a parent sequence.
+        Args:
+            representation: A tensor of shape (B, L, E) representing the
+                embedded parent sequences.
+        Returns:
+            A tensor of shape (B, L, out_features) representing the log level
+            of selection for each amino acid site.
+        """
+        return self.linear(representation).squeeze(-1)
+
+    def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
+        """Build a binary log selection matrix from an index-encoded parent sequence.
+        Because we're predicting log of the selection factor, we don't use an
+        activation function after the transformer.
+        Args:
+            amino_acid_indices: A tensor of shape (B, L) containing the indices of parent AA sequences.
+            mask: A tensor of shape (B, L) representing the mask of valid amino acid sites.
+        Returns:
+            A tensor of shape (B, L, out_features) representing the log level
+            of selection for each possible amino acid at each site.
+        """
+        return self.predict(self.represent(amino_acid_indices, mask))
 
 
 def wiggle(x, beta):
@@ -650,8 +675,8 @@ def wiggle(x, beta):
 class TransformerBinarySelectionModelWiggleAct(TransformerBinarySelectionModelLinAct):
     """Here the beta parameter is fixed at 0.3."""
 
-    def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
-        return wiggle(super().forward(amino_acid_indices, mask), 0.3)
+    def predict(self, representation: Tensor):
+        return wiggle(super().predict(representation), 0.3)
 
 
 class TransformerBinarySelectionModelTrainableWiggleAct(
@@ -672,10 +697,10 @@ def __init__(self, *args, **kwargs):
             torch.tensor([init_logit_beta], dtype=torch.float32)
         )
 
-    def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
+    def predict(self, representation: Tensor):
         # Apply sigmoid to transform logit_beta back to the range (0, 1)
         beta = torch.sigmoid(self.logit_beta)
-        return wiggle(super().forward(amino_acid_indices, mask), beta)
+        return wiggle(super().predict(representation), beta)
 
 
 class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
@@ -690,8 +715,7 @@ def hyperparameters(self):
         return {}
 
     def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
-        """Build a binary log selection matrix from a one-hot encoded parent
-        sequence."""
+        """Build a binary log selection matrix from an index-encoded parent sequence."""
         replicated_value = self.single_value.expand_as(amino_acid_indices)
         return replicated_value
 
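With the split in place, the wiggle-activation models above override only predict; represent and the inherited forward are untouched. Below is a small, hypothetical sketch of that override pattern: toy_wiggle, ToyBase, and ToyWiggle are stand-ins (the real wiggle body is elided in this diff), shown only to illustrate where the activation now sits.

```python
import torch
from torch import Tensor, nn


def toy_wiggle(x: Tensor, beta: float) -> Tensor:
    # Placeholder activation; netam's actual `wiggle` body is not shown in the diff.
    return beta * torch.tanh(x)


class ToyBase(nn.Module):
    """Stand-in for the base model's prediction head (linear, no activation)."""

    def __init__(self, d_model: int = 8):
        super().__init__()
        self.linear = nn.Linear(d_model, 1)

    def predict(self, representation: Tensor) -> Tensor:
        return self.linear(representation).squeeze(-1)


class ToyWiggle(ToyBase):
    """Mirrors TransformerBinarySelectionModelWiggleAct: only predict is overridden."""

    def predict(self, representation: Tensor) -> Tensor:
        # The activation wraps only the final linear output; any inherited
        # forward/represent machinery is reused unchanged.
        return toy_wiggle(super().predict(representation), beta=0.3)


if __name__ == "__main__":
    rep = torch.randn(2, 5, 8)  # (B, L, E) embedding as produced by represent()
    print(ToyWiggle().predict(rep).shape)  # torch.Size([2, 5])
```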
