diff --git a/netam/common.py b/netam/common.py index 3bf34169..f57f66cd 100644 --- a/netam/common.py +++ b/netam/common.py @@ -89,38 +89,6 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None): return mask -def _consider_codon(codon): - """Return False if codon should be masked, True otherwise.""" - if "N" in codon: - return False - elif codon in RESERVED_TOKEN_TRANSLATIONS: - return False - else: - return True - - -def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None): - """Return a mask tensor indicating codons which contain at least one N. - - Codons beyond the length of the sequence are masked. If other_nt_seqs are provided, - the "and" mask will be computed for all sequences. Codons containing marker tokens - are also masked. - """ - if aa_length is None: - aa_length = len(nt_parent) // 3 - sequences = (nt_parent,) + other_nt_seqs - mask = [ - all(_consider_codon(codon) for codon in codons) - for codons in zip(*(iter_codons(sequence) for sequence in sequences)) - ] - if len(mask) < aa_length: - mask += [False] * (aa_length - len(mask)) - else: - mask = mask[:aa_length] - assert len(mask) == aa_length - return torch.tensor(mask, dtype=torch.bool) - - def aa_strs_from_idx_tensor(idx_tensor): """Convert a tensor of amino acid indices back to a list of amino acid strings. @@ -177,6 +145,39 @@ def aa_mask_tensor_of(*args, **kwargs): return generic_mask_tensor_of("X", *args, **kwargs) +def _consider_codon(codon): + """Return False if codon should be masked, True otherwise.""" + if "N" in codon: + return False + elif codon in RESERVED_TOKEN_TRANSLATIONS: + return False + else: + return True + + +def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None): + """Return a mask tensor indicating codons which contain at least one N. + + Codons beyond the length of the sequence are masked. If other_nt_seqs are provided, + the "and" mask will be computed for all sequences. Codons containing marker tokens + are also masked. + """ + if aa_length is None: + aa_length = len(nt_parent) // 3 + sequences = (nt_parent,) + other_nt_seqs + mask = [ + all(_consider_codon(codon) for codon in codons) + for codons in zip(*(iter_codons(sequence) for sequence in sequences)) + ] + if len(mask) < aa_length: + mask += [False] * (aa_length - len(mask)) + else: + mask = mask[:aa_length] + assert len(mask) == aa_length + return torch.tensor(mask, dtype=torch.bool) + + + def informative_site_count(seq_str): return sum(c != "N" for c in seq_str) @@ -429,6 +430,23 @@ def chunked(iterable, n): yield chunk +def assume_single_sequence_is_heavy_chain(function): + """Wraps a function that takes a heavy/light sequence pair as its first argument + and returns a tuple of results. + + The wrapped function will assume that if the first argument is a string, it is a + heavy chain sequence, and in that case will return only the heavy chain result.""" + @wraps(function) + def wrapper(*args, **kwargs): + seq = args[0] + if isinstance(seq, str): + seq = (seq, "") + res = function(seq, *args[1:], **kwargs) + return res[0] + else: + return function(*args, **kwargs) + + def chunk_function( first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None ): diff --git a/netam/dasm.py b/netam/dasm.py index 473bd289..c542acba 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -139,10 +139,7 @@ def prediction_pair_of_batch(self, batch): raise ValueError( f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}" ) - # We need the model to see special tokens here. 
For every other purpose - # they are masked out. - keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs) - log_selection_factors = self.model(aa_parents_idxs, keep_token_mask) + log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask) return log_neutral_aa_probs, log_selection_factors def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors): @@ -207,18 +204,17 @@ def loss_of_batch(self, batch): # TODO have a close look at these two functions, I'm feeling unsure about # them def build_selection_matrix_from_parent_aa(self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor): - """Build a selection matrix from a parent amino acid sequence. + """Build a selection matrix from a single parent amino acid sequence. + Inputs are expected to be as prepared in the Dataset constructor. Values at ambiguous sites are meaningless. """ - per_aa_selection_factors = self.model.forward(aa_parent_idxs, mask) - + with torch.no_grad(): + per_aa_selection_factors = self.selection_factors_of_aa_idxs(aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0).exp() # TODO why 1.0? return zap_predictions_along_diagonal(per_aa_selection_factors, aa_parent_idxs, fill=1.0) - return per_aa_selection_factors - def build_selection_matrix_from_parent(self, parent: str): """Build a selection matrix from a parent nucleotide sequence. diff --git a/netam/dnsm.py b/netam/dnsm.py index d30157ee..85cc4622 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -127,7 +127,8 @@ def prediction_pair_of_batch(self, batch): raise ValueError( f"log_neutral_aa_mut_probs has non-finite values at relevant positions: {log_neutral_aa_mut_probs[mask]}" ) - log_selection_factors = self.model(aa_parents_idxs, mask) + # Right here is where model is evaluated! + log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask) return log_neutral_aa_mut_probs, log_selection_factors def predictions_of_pair(self, log_neutral_aa_mut_probs, log_selection_factors): @@ -156,6 +157,30 @@ def loss_of_batch(self, batch): predictions = self.predictions_of_batch(batch).masked_select(mask) return self.bce_loss(predictions, aa_subs_indicator) + + def _build_selection_matrix_from_selection_factors(self, selection_factors, aa_parent_idxs): + """Build a selection matrix from a selection factor tensor for a single sequence. + upgrades the provided tensor containing a selection factor per site to a matrix + containing a selection factor per site and amino acid. The wildtype aa selection + factor is set ot 1, and the rest are set to the selection factor.""" + selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float) + # Every "off-diagonal" entry of the selection matrix is set to the selection + # factor, where "diagonal" means keeping the same amino acid. + selection_matrix[:, :] = selection_factors[:, None] + selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0 + return selection_matrix + + def build_selection_matrix_from_parent_aa(self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor): + """Build a selection matrix from a single parent amino acid sequence. + + Values at ambiguous sites are meaningless. 
+ """ + with torch.no_grad(): + per_aa_selection_factors = self.selection_factors_of_aa_idxs(aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0).exp() + return self._build_selection_matrix_from_selection_factors( + selection_factors, aa_parent_idxs + + # TODO upgrade this to take pair of heavy and light sequences def build_selection_matrix_from_parent(self, parent: str): """Build a selection matrix from a nucleotide sequence. @@ -163,17 +188,13 @@ def build_selection_matrix_from_parent(self, parent: str): """ parent = sequences.translate_sequence(parent) selection_factors = self.model.selection_factors_of_aa_str(parent) - selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float) - # Every "off-diagonal" entry of the selection matrix is set to the selection - # factor, where "diagonal" means keeping the same amino acid. - selection_matrix[:, :] = selection_factors[:, None] parent = parent.replace("X", "A") # Set "diagonal" elements to one. parent_idxs = sequences.aa_idx_array_of_str(parent) - selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0 - - return selection_matrix + return self._build_selection_matrix_from_selection_factors( + selection_factors, parent_idxs + ) class DNSMHyperBurrito(HyperBurrito): # Note that we have to write the args out explicitly because we use some magic to filter kwargs in the optuna_objective method. diff --git a/netam/dxsm.py b/netam/dxsm.py index 470a1753..35897a34 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -29,6 +29,7 @@ nt_mutation_frequency, strip_unrecognized_tokens_from_series, dataset_inputs_of_pcp_df, + token_mask_of_aa_idxs, MAX_AA_TOKEN_IDX, RESERVED_TOKEN_REGEX, AA_AMBIG_IDX, @@ -36,7 +37,8 @@ class DXSMDataset(framework.BranchLengthDataset, ABC): - prefix = "dxsm" + # Not defining model_type here; instead defining it in subclasses. + # This will raise an error if we aren't using a subclass. def __init__( self, @@ -271,6 +273,16 @@ class DXSMBurrito(framework.Burrito, ABC): # Not defining model_type here; instead defining it in subclasses. # This will raise an error if we aren't using a subclass. + def selection_factors_of_aa_idxs(self, aa_idxs, aa_mask): + """Get the log selection factors for a batch of amino acid indices. + aa_idxs and aa_mask are expected to be as prepared in the Dataset constructor.""" + + # We need the model to see special tokens here. For every other purpose + # they are masked out. + keep_token_mask = mask | token_mask_of_aa_idxs(aa_idxs) + return self.model(aa_idxs, keep_token_mask) + + def _find_optimal_branch_length( self, parent, @@ -278,13 +290,14 @@ def _find_optimal_branch_length( nt_rates, nt_csps, aa_mask, + aa_parents_indices, starting_branch_length, multihit_model, **optimization_kwargs, ): # TODO finish switching to build_selection_matrix_from_parent_aa # thing... 
- sel_matrix = self.build_selection_matrix_from_parent(parent) + sel_matrix = self.build_selection_matrix_from_parent_aa(aa_parents_indices, aa_mask) trimmed_aa_mask = aa_mask[: len(sel_matrix)] log_pcp_probability = molevol.mutsel_log_pcp_probability_of( sel_matrix[trimmed_aa_mask], @@ -304,13 +317,14 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): optimal_lengths = [] failed_count = 0 - for parent, child, nt_rates, nt_csps, aa_mask, starting_length in tqdm( + for parent, child, nt_rates, nt_csps, aa_mask, aa_parents_indices, starting_length in tqdm( zip( dataset.nt_parents, dataset.nt_children, dataset.nt_ratess, dataset.nt_cspss, dataset.masks, + dataset.aa_parents_idxss, dataset.branch_lengths, ), total=len(dataset.nt_parents), @@ -322,6 +336,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): nt_rates[: len(parent)], nt_csps[: len(parent), :], aa_mask, + aa_parents_indices, starting_length, dataset.multihit_model, **optimization_kwargs, diff --git a/netam/models.py b/netam/models.py index a842fe62..beaa217f 100644 --- a/netam/models.py +++ b/netam/models.py @@ -18,10 +18,13 @@ aa_mask_tensor_of, encode_sequences, chunk_function, + assume_single_sequence_is_heavy_chain, ) from netam.sequences import set_wt_to_nan +from typing import Tuple + warnings.filterwarnings( "ignore", category=UserWarning, module="torch.nn.modules.transformer" ) @@ -580,7 +583,10 @@ def predictions_of_sequences(self, sequences, **kwargs): def evaluate_sequences(self, sequences: list[str], **kwargs) -> Tensor: return tuple(self.selection_factors_of_aa_str(seq) for seq in sequences) - def selection_factors_of_aa_str(self, aa_str: str) -> Tensor: + # TODO make sure that insertion of model tokens is actually done here... + # Also check if this is used anymore... + @assume_single_sequence_is_heavy_chain + def selection_factors_of_aa_str(self, aa_sequence: Tuple[str, str]) -> Tensor: """Do the forward method then exponentiation without gradients from an amino acid string. @@ -591,40 +597,34 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor: Otherwise it should be a tuple, with the first element being the heavy chain and the second element being the light chain sequence. Returns: - A numpy array of the same length as the input string representing + A tuple of numpy arrays of the same length as the input strings representing the level of selection for each amino acid at each site. - If the input was a tuple of heavy/light chain sequences, the output will be a tuple of - numpy arrays. """ + aa_str, added_indices = sequences.prepare_heavy_light_pair(*aa_sequence, self.hyperparameters["embedding_dim"]) aa_idxs = aa_idx_tensor_of_str_ambig(aa_str) aa_idxs = aa_idxs.to(self.device) + # This makes the expected mask because of + # test_sequence.py::test_compare_mask_tensors. + # TODO write test that compares for all possible embedding_dim values + # the output of aa_mask_tensor_of and (codon_mask_tensor_of | token_mask_of_aa_idxs). + # (Here we expect those two to be the same) mask = aa_mask_tensor_of(aa_str) mask = mask.to(self.device) - # Here we're ignoring sites containing tokens that have index greater - # than the embedding dimension. If extra tokens have been added since - # this model was defined, they are stripped out before feeding the - # sequence to the model, and the returned selection factors will be NaN - # at sites containing those unrecognized tokens. 
- model_valid_sites = aa_idxs < self.hyperparameters["embedding_dim"] - if self.hyperparameters["output_dim"] == 1: - result = torch.full((len(aa_str),), float("nan"), device=self.device) - else: - result = torch.full( - (len(aa_str), self.hyperparameters["output_dim"]), - float("nan"), - device=self.device, - ) - with torch.no_grad(): model_out = self( - aa_idxs[model_valid_sites].unsqueeze(0), - mask[model_valid_sites].unsqueeze(0), - ).squeeze(0) - result[model_valid_sites] = torch.exp(model_out)[: model_valid_sites.sum()] - - return result + aa_idxs.unsqueeze(0), + mask.unsqueeze(0), + ).squeeze(0).exp() + + # Now split into heavy and light chain results: + sequence_mask = torch.ones(len(model_out), dtype=bool) + sequence_mask[added_indices] = False + masked_model_out = model_out[sequence_mask] + light_chain = masked_model_out[:len(aa_str[0])] + heavy_chain = masked_model_out[len(aa_str[0]):] + return light_chain, heavy_chain class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel): diff --git a/tests/test_dasm.py b/tests/test_dasm.py index 0d2d0b22..eb413663 100644 --- a/tests/test_dasm.py +++ b/tests/test_dasm.py @@ -14,10 +14,12 @@ DASMDataset, zap_predictions_along_diagonal, ) -from netam.sequences import MAX_EMBEDDING_DIM +from netam.sequences import MAX_EMBEDDING_DIM, TOKEN_STR_SORTED -@pytest.fixture(scope="module") +# TODO verify that this loops through both pcp_dfs, even though one is named +# the same as the argument. If not, remember to fix in test_dnsm.py too. +@pytest.fixture(scope="module", params=["pcp_df", "pcp_df_paired"]) def dasm_burrito(pcp_df): force_spawn() """Fixture that returns the DNSM Burrito object.""" @@ -94,3 +96,23 @@ def test_zap_diagonal(dasm_burrito): zeroed_predictions[batch_idx, i, j] == predictions[batch_idx, i, j] ) + + +# TODO this won't work until build_selection_matrix_from_parent is fixed +def test_build_selection_matrix_from_parent(dasm_burrito): + dataset_row = dasm_burrito.val_dataset[0] + + parent = dasm_burrito.val_dataset.nt_parents[0] + parent_aa_idxs = dasm_burrito.val_dataset.aa_parents_idxss[0] + aa_mask = dasm_burrito.val_dataset.masks[0] + aa_parent = "".join(TOKEN_STR_SORTED[i] for i in parent) + + separator_idx = aa_parent.index('^') * 3 + light_chain_seq = parent[:separator_idx] + heavy_chain_seq = parent[separator_idx + 3:] + + direct_val = dasm_burrito.build_selection_matrix_from_parent_aa(parent_aa_idxs, aa_mask) + + indirect_val = dasm_burrito.build_selection_matrix_from_parent((light_chain_seq, heavy_chain_seq)) + + assert torch.allclose(direct_val, indirect_val) diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py index daebf194..3700ca59 100644 --- a/tests/test_dnsm.py +++ b/tests/test_dnsm.py @@ -10,7 +10,7 @@ from netam.common import aa_idx_tensor_of_str_ambig, force_spawn from netam.models import TransformerBinarySelectionModelWiggleAct from netam.dnsm import DNSMBurrito, DNSMDataset -from netam.sequences import AA_AMBIG_IDX, MAX_EMBEDDING_DIM +from netam.sequences import AA_AMBIG_IDX, MAX_EMBEDDING_DIM, TOKEN_STR_SORTED def test_aa_idx_tensor_of_str_ambig(): @@ -20,7 +20,7 @@ def test_aa_idx_tensor_of_str_ambig(): assert torch.equal(output, expected_output) -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", params=["pcp_df", "pcp_df_paired"]) def dnsm_burrito(pcp_df): """Fixture that returns the DNSM Burrito object.""" force_spawn() @@ -68,3 +68,23 @@ def test_crepe_roundtrip(dnsm_burrito): dnsm_burrito.model.state_dict().values(), model.state_dict().values() ): assert torch.equal(t1, 
t2)
+
+
+# TODO this won't work until build_selection_matrix_from_parent is fixed
+def test_build_selection_matrix_from_parent(dnsm_burrito):
+    dataset_row = dnsm_burrito.val_dataset[0]
+
+    parent = dnsm_burrito.val_dataset.nt_parents[0]
+    parent_aa_idxs = dnsm_burrito.val_dataset.aa_parents_idxss[0]
+    aa_mask = dnsm_burrito.val_dataset.masks[0]
+    aa_parent = "".join(TOKEN_STR_SORTED[i] for i in parent_aa_idxs)
+
+    separator_idx = aa_parent.index('^') * 3
+    light_chain_seq = parent[:separator_idx]
+    heavy_chain_seq = parent[separator_idx + 3:]
+
+    direct_val = dnsm_burrito.build_selection_matrix_from_parent_aa(parent_aa_idxs, aa_mask)
+
+    indirect_val = dnsm_burrito.build_selection_matrix_from_parent((light_chain_seq, heavy_chain_seq))
+
+    assert torch.allclose(direct_val, indirect_val)
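
The assume_single_sequence_is_heavy_chain decorator added in netam/common.py never returns wrapper, so the decorated attribute ends up as None; note also that when it is applied to a bound method (as on selection_factors_of_aa_str), the sequence typically arrives as the second positional argument rather than the first. A minimal corrected sketch, assuming the decorator wraps a plain function whose first argument is the heavy/light pair:

from functools import wraps


def assume_single_sequence_is_heavy_chain(function):
    """Wrap a function that takes a heavy/light sequence pair as its first argument
    and returns a tuple of per-chain results.

    If the first argument is a bare string, treat it as a heavy chain paired with
    an empty light chain and return only the heavy-chain result.
    """

    @wraps(function)
    def wrapper(*args, **kwargs):
        seq = args[0]
        if isinstance(seq, str):
            # Promote the single heavy-chain string to a (heavy, light) pair.
            res = function((seq, ""), *args[1:], **kwargs)
            return res[0]
        return function(*args, **kwargs)

    return wrapper

To use it on a method, the sequence check would need to look at the argument after self instead of the first positional argument.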
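
The DNSM helper _build_selection_matrix_from_selection_factors indexes parent_idxs although its parameter is named aa_parent_idxs, and the call in build_selection_matrix_from_parent_aa passes selection_factors where per_aa_selection_factors was computed (and is missing its closing parenthesis). A standalone sketch of the intended expansion, written here as a free function with a toy usage rather than as the DNSMBurrito method:

import torch


def build_selection_matrix_from_selection_factors(selection_factors, aa_parent_idxs):
    """Expand a per-site selection factor vector into a (num_sites, 20) matrix.

    Every "off-diagonal" entry (a mutation away from the parent amino acid) gets
    the site's selection factor; the wild-type ("diagonal") entry is set to 1.
    """
    selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
    selection_matrix[:, :] = selection_factors[:, None]
    selection_matrix[torch.arange(len(aa_parent_idxs)), aa_parent_idxs] = 1.0
    return selection_matrix


# Toy usage: three sites whose parent amino acids have indices 0, 5, and 19.
factors = torch.tensor([0.5, 2.0, 1.3])
parent_idxs = torch.tensor([0, 5, 19])
matrix = build_selection_matrix_from_selection_factors(factors, parent_idxs)
assert matrix.shape == (3, 20)
assert torch.all(matrix[torch.arange(3), parent_idxs] == 1.0)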
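
The selection_factors_of_aa_idxs method added to DXSMBurrito in netam/dxsm.py references mask, but its parameter is named aa_mask, which would raise a NameError. A sketch of the apparent intent, written as a free function taking the model explicitly and assuming token_mask_of_aa_idxs comes from netam.sequences (as the removed dasm.py code suggests):

from netam.sequences import token_mask_of_aa_idxs


def selection_factors_of_aa_idxs(model, aa_idxs, aa_mask):
    """Return log selection factors for a batch of amino-acid index tensors.

    The model needs to see the special tokens here, so they are OR-ed back into
    the mask even though they are masked out for every other purpose.
    """
    keep_token_mask = aa_mask | token_mask_of_aa_idxs(aa_idxs)
    return model(aa_idxs, keep_token_mask)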