diff --git a/netam/dcsm.py b/netam/dcsm.py index c766da0c..e73a0192 100644 --- a/netam/dcsm.py +++ b/netam/dcsm.py @@ -12,10 +12,8 @@ codon_mask_tensor_of, ) from netam.dxsm import DXSMDataset, DXSMBurrito -from netam.hyper_burrito import HyperBurrito import netam.molevol as molevol -# TODO strange to have this in common-- I think that it has been moved already upstream? from netam.common import aa_idx_tensor_of_str_ambig from netam.sequences import ( aa_idx_array_of_str, @@ -25,9 +23,9 @@ translate_sequence, translate_sequences, codon_idx_tensor_of_str_ambig, + AA_AMBIG_IDX, AMBIGUOUS_CODON_IDX, CODON_AA_INDICATOR_MATRIX, - MAX_AA_TOKEN_IDX, RESERVED_TOKEN_REGEX, ) @@ -124,12 +122,11 @@ def __init__( (pcp_count, self.max_codon_seq_len), AMBIGUOUS_CODON_IDX ) self.codon_children_idxss = self.codon_parents_idxss.clone() - # TODO: want to use ambig token once that's changed. self.aa_parents_idxss = torch.full( - (pcp_count, self.max_codon_seq_len), MAX_AA_TOKEN_IDX + (pcp_count, self.max_codon_seq_len), AA_AMBIG_IDX ) self.aa_children_idxss = torch.full( - (pcp_count, self.max_codon_seq_len), MAX_AA_TOKEN_IDX + (pcp_count, self.max_codon_seq_len), AA_AMBIG_IDX ) # TODO here we are computing the subs indicators. This is handy for OE plots. self.aa_subs_indicators = torch.zeros((pcp_count, self.max_codon_seq_len)) @@ -166,8 +163,8 @@ def __init__( ) assert torch.all(self.masks.sum(dim=1) > 0) - assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX - assert torch.max(self.aa_children_idxss) <= MAX_AA_TOKEN_IDX + assert torch.max(self.aa_parents_idxss) <= AA_AMBIG_IDX + assert torch.max(self.aa_children_idxss) <= AA_AMBIG_IDX assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX self._branch_lengths = branch_lengths