Skip to content

Commit

Permalink
Add function to convert amino acid index tensor to strings and corres…
Browse files Browse the repository at this point in the history
…ponding tests
  • Loading branch information
matsen committed Dec 13, 2024
1 parent eeb4353 commit a15eba0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
21 changes: 21 additions & 0 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,27 @@ def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
return torch.tensor(mask, dtype=torch.bool)


def aa_strs_from_idx_tensor(idx_tensor):
"""
Convert a tensor of amino acid indices back to a list of amino acid strings.
Args:
idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
indices into AA_STR_SORTED_AMBIG.
Returns:
List[str]: A list of amino acid strings with trailing 'X's removed.
"""
idx_tensor = idx_tensor.cpu()

aa_str_list = []
for row in idx_tensor:
aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist())
aa_str_list.append(aa_str.rstrip("X"))

return aa_str_list


def assert_pcp_valid(parent, child, aa_mask=None):
"""Check that the parent-child pairs are valid.
Expand Down
13 changes: 12 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import torch

from netam.common import nt_mask_tensor_of, aa_mask_tensor_of, codon_mask_tensor_of
from netam.common import (
nt_mask_tensor_of,
aa_mask_tensor_of,
codon_mask_tensor_of,
aa_strs_from_idx_tensor,
)


def test_mask_tensor_of():
Expand All @@ -25,3 +30,9 @@ def test_codon_mask_tensor_of():
expected_output = torch.tensor([0, 0, 1, 0, 0], dtype=torch.bool)
output = codon_mask_tensor_of(input_seq, input_seq2, aa_length=5)
assert torch.equal(output, expected_output)


def test_aa_strs_from_idx_tensor():
aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20], [4, 5, 19, 20, 20]])
aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor)
assert aa_strings == ["ACDE", "FGY"]

0 comments on commit a15eba0

Please sign in to comment.