Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Jan 17, 2025
1 parent f03e5d3 commit 96025ab
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
21 changes: 18 additions & 3 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def to(self, device):
self.multihit_model = self.multihit_model.to(device)


def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
"""Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
except where aa_parents_idxs >= 20, which indicates no update should be done."""

Expand All @@ -116,7 +116,7 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
batch_indices[valid_mask],
sequence_indices[valid_mask],
aa_parents_idxs[valid_mask],
] = -BIG
] = fill

return predictions

Expand Down Expand Up @@ -204,9 +204,24 @@ def loss_of_batch(self, batch):
csp_loss = self.xent_loss(csp_pred, csp_targets)
return torch.stack([subs_pos_loss, csp_loss])

def build_selection_matrix_from_parent(self, parent: str):
# TODO have a close look at these two functions, I'm feeling unsure about
# them
def build_selection_matrix_from_parent_aa(self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor):
"""Build a selection matrix from a parent amino acid sequence.
Values at ambiguous sites are meaningless.
"""
per_aa_selection_factors = self.model.forward(aa_parent_idxs, mask)


# TODO why 1.0?
return zap_predictions_along_diagonal(per_aa_selection_factors, aa_parent_idxs, fill=1.0)

return per_aa_selection_factors

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent nucleotide sequence.
Values at ambiguous sites are meaningless.
"""
# This is simpler than the equivalent in dnsm.py because we get the selection
Expand Down
2 changes: 1 addition & 1 deletion netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def loss_of_batch(self, batch):
return self.bce_loss(predictions, aa_subs_indicator)

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
"""Build a selection matrix from a nucleotide sequence.
Values at ambiguous sites are meaningless.
"""
Expand Down
2 changes: 2 additions & 0 deletions netam/dxsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ def _find_optimal_branch_length(
multihit_model,
**optimization_kwargs,
):
# TODO finish switching to build_selection_matrix_from_parent_aa
# thing...
sel_matrix = self.build_selection_matrix_from_parent(parent)
trimmed_aa_mask = aa_mask[: len(sel_matrix)]
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
Expand Down
7 changes: 6 additions & 1 deletion netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,17 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
"""Do the forward method then exponentiation without gradients from an amino
acid string.
Insertion of model tokens will be done automatically.
Args:
aa_str: A string of amino acids.
aa_str: A string of amino acids. If a string, we assume this is a light chain sequence.
Otherwise it should be a tuple, with the first element being the heavy chain and the second element being the light chain sequence.
Returns:
A numpy array of the same length as the input string representing
the level of selection for each amino acid at each site.
If the input was a tuple of heavy/light chain sequences, the output will be a tuple of
numpy arrays.
"""

aa_idxs = aa_idx_tensor_of_str_ambig(aa_str)
Expand Down

0 comments on commit 96025ab

Please sign in to comment.