Skip to content

Commit

Permalink
post rebase todos
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jan 15, 2025
1 parent 326605b commit ddb2442
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions netam/dcsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
codon_mask_tensor_of,
)
from netam.dxsm import DXSMDataset, DXSMBurrito
from netam.hyper_burrito import HyperBurrito
import netam.molevol as molevol

# TODO strange to have this in common-- I think that it has been moved already upstream?
from netam.common import aa_idx_tensor_of_str_ambig
from netam.sequences import (
aa_idx_array_of_str,
Expand All @@ -25,9 +23,9 @@
translate_sequence,
translate_sequences,
codon_idx_tensor_of_str_ambig,
AA_AMBIG_IDX,
AMBIGUOUS_CODON_IDX,
CODON_AA_INDICATOR_MATRIX,
MAX_AA_TOKEN_IDX,
RESERVED_TOKEN_REGEX,
)

Expand Down Expand Up @@ -124,12 +122,11 @@ def __init__(
(pcp_count, self.max_codon_seq_len), AMBIGUOUS_CODON_IDX
)
self.codon_children_idxss = self.codon_parents_idxss.clone()
# TODO: want to use ambig token once that's changed.
self.aa_parents_idxss = torch.full(
(pcp_count, self.max_codon_seq_len), MAX_AA_TOKEN_IDX
(pcp_count, self.max_codon_seq_len), AA_AMBIG_IDX
)
self.aa_children_idxss = torch.full(
(pcp_count, self.max_codon_seq_len), MAX_AA_TOKEN_IDX
(pcp_count, self.max_codon_seq_len), AA_AMBIG_IDX
)
# TODO here we are computing the subs indicators. This is handy for OE plots.
self.aa_subs_indicators = torch.zeros((pcp_count, self.max_codon_seq_len))
Expand Down Expand Up @@ -166,8 +163,8 @@ def __init__(
)

assert torch.all(self.masks.sum(dim=1) > 0)
assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX
assert torch.max(self.aa_children_idxss) <= MAX_AA_TOKEN_IDX
assert torch.max(self.aa_parents_idxss) <= AA_AMBIG_IDX
assert torch.max(self.aa_children_idxss) <= AA_AMBIG_IDX
assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX

self._branch_lengths = branch_lengths
Expand Down

0 comments on commit ddb2442

Please sign in to comment.