From e8d52f654fb98c08a83c9e951c4a599971667158 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 13 Dec 2024 14:43:28 -0800 Subject: [PATCH] WIP --- netam/common.py | 17 +++++++++-- netam/framework.py | 68 ++++++++++++++++++++++++++++++++++------- netam/sequences.py | 47 ++++++++++++++++++---------- tests/test_sequences.py | 7 +++++ 4 files changed, 109 insertions(+), 30 deletions(-) diff --git a/netam/common.py b/netam/common.py index 00979d47..d065f03a 100644 --- a/netam/common.py +++ b/netam/common.py @@ -13,7 +13,7 @@ from torch import nn, Tensor import multiprocessing as mp -from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence +from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence, TOKEN_TRANSLATIONS BIG = 1e9 SMALL_PROB = 1e-6 @@ -88,17 +88,28 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None): return mask +def _consider_codon(codon): + """Return False if codon should be masked, True otherwise.""" + if "N" in codon: + return False + elif codon in TOKEN_TRANSLATIONS: + return False + else: + return True + + def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None): """Return a mask tensor indicating codons which contain at least one N. Codons beyond the length of the sequence are masked. If other_nt_seqs are provided, - the "and" mask will be computed for all sequences + the "and" mask will be computed for all sequences. + Codons containing marker tokens are also masked. """ if aa_length is None: aa_length = len(nt_parent) // 3 sequences = (nt_parent,) + other_nt_seqs mask = [ - all("N" not in codon for codon in codons) + all(_consider_codon(codon) for codon in codons) for codons in zip(*(iter_codons(sequence) for sequence in sequences)) ] if len(mask) < aa_length: diff --git a/netam/framework.py b/netam/framework.py index 28eae061..ac3db628 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -2,6 +2,7 @@ import copy import os from time import time +from warnings import warn import pandas as pd import numpy as np @@ -349,28 +350,49 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents): def join_chains(pcp_df): """Join the parent and child chains in the pcp_df. - + Make a parent column that is the parent_h + "^^^" + parent_l, and same for child. - TODO update for case of just parent and child + If parent_h and parent_l are not present, then we assume that the parent is the heavy chain. + If only one of parent_h or parent_l is present, then we place the ^^^ padding to the right of + heavy, or to the left of light. """ - if "parent_h" in pcp_df.columns and "parent_l" in pcp_df.columns and "child_h" in pcp_df.columns and "child_l" in pcp_df.columns: + cols = pcp_df.columns + if "parent_h" in cols and "parent_l" in cols: + assert "child_h" in cols and "child_l" in cols, "child_h or child_l columns missing!" pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"] pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"] - pcp_df.drop(columns=["parent_h", "parent_l", "child_h", "child_l"], inplace=True) - else: - # TODO but there is a chance we'll have some data sets that are just light chain, in which case we'll want to pad on the left. - # I suggest in that case that we just ask for the light chain column to be named parent_l and child_l, and that you can ask for that column name here. - # Let's allow that for the parent_h case too, just in case. + elif "parent_h" in cols and "parent_l" not in cols: + assert "child_h" in cols, "child_h column missing!" + pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["child"] = pcp_df["child_h"] + "^^^" + elif "parent_h" not in cols and "parent_l" in cols: + if "parent" in cols: + warn("Both parent and parent_l columns found. Using only parent_l. " + "To use parent as heavy chain, rename to parent_h.") + assert "child_l" in cols, "child_l column missing!" + pcp_df["parent"] = "^^^" + pcp_df["parent_l"] + pcp_df["child"] = "^^^" + pcp_df["child_l"] + elif "parent" in cols: + assert "child" in cols, "child column missing!" + # We assume that this is the heavy chain. pcp_df["parent"] += "^^^" pcp_df["child"] += "^^^" + else: + raise ValueError("Could not find parent and child columns.") + pcp_df.drop(columns=["parent_h", "parent_l", "child_h", "child_l"], inplace=True, errors="ignore") return pcp_df + def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None, joined_mode=False): """Load a PCP dataframe from a gzipped CSV file. `orig_pcp_idx` is the index column from the original file, even if we subset by sampling or by choosing V families. + + If `joined_mode` is True, then we will join the heavy and light chain sequences into a single + sequence starting with the heavy chain, using a `^^^` separator. If only heavy or light chain + sequence is present, this separator will be added to the appropriate side of the available sequence. """ pcp_df = ( pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0) @@ -379,9 +401,19 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None, joine ) if joined_mode: pcp_df = join_chains(pcp_df) - # TODO assert that we have parent and child columns - pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0] + if not ("parent" in pcp_df.columns and "child" in pcp_df.columns): + if "parent_h" in pcp_df.columns and "parent_l" in pcp_df.columns: + pcp_df["parent"] = pcp_df["parent_h"] + pcp_df["child"] = pcp_df["child_h"] + pcp_df.drop(columns=["parent_h", "parent_l", "child_h", "child_l"], inplace=True, errors="ignore") + else: + raise ValueError( + "Could not find parent and child columns. " + "Perhaps you want to use joined_mode=True?" + ) + if chosen_v_families is not None: + pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0] chosen_v_families = set(chosen_v_families) pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)] if sample_count is not None: @@ -391,11 +423,25 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None, joine def add_shm_model_outputs_to_pcp_df(pcp_df, crepe): - rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"]) + rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"].str.replace("^", "N")) pcp_df["nt_rates"] = rates pcp_df["nt_csps"] = csps return pcp_df +# TODO we'll need to do something like this: +def _add_shm_model_outputs_to_pcp_df(pcp_df, crepe): + # TODO I can't think of a better generic way to do this.. + split_parents = pcp_df["parent"].str.split("^^^", expand=True) + h_parents = split_parents[0] + l_parents = split_parents[1] + + h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents) + l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents) + # TODO this is bogus + pcp_df["nt_rates"] = h_rates + [0.0, 0.0, 0.0] + l_rates + pcp_df["nt_csps"] = h_csps + [0.0, 0.0, 0.0] + l_csps + return pcp_df + class Burrito(ABC): def __init__( diff --git a/netam/sequences.py b/netam/sequences.py index ee9c1adf..e590c978 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -16,6 +16,9 @@ for codon_list in itertools.product(["A", "C", "G", "T"], repeat=3) ] STOP_CODONS = ["TAA", "TAG", "TGA"] +TOKEN_TRANSLATIONS = { + "^^^": "^", +} def nt_idx_array_of_str(nt_str): @@ -26,11 +29,18 @@ def nt_idx_array_of_str(nt_str): print(f"Found an invalid nucleotide in the string: {nt_str}") raise +def aa_idx_array_of_str(aa_str): + """Return the indices of the amino acids in a string.""" + try: + return np.array([TOKEN_STR_SORTED.index(aa) for aa in aa_str]) + except ValueError: + print(f"Found an invalid amino acid in the string: {aa_str}") + raise def aa_idx_array_of_str(aa_str): """Return the indices of the amino acids in a string.""" try: - return np.array([AA_STR_SORTED.index(aa) for aa in aa_str]) + return np.array([TOKEN_STR_SORTED.index(aa) for aa in aa_str]) except ValueError: print(f"Found an invalid amino acid in the string: {aa_str}") raise @@ -48,7 +58,7 @@ def nt_idx_tensor_of_str(nt_str): def aa_idx_tensor_of_str(aa_str): """Return the indices of the amino acids in a string.""" try: - return torch.tensor([AA_STR_SORTED.index(aa) for aa in aa_str]) + return torch.tensor([TOKEN_STR_SORTED.index(aa) for aa in aa_str]) except ValueError: print(f"Found an invalid amino acid in the string: {aa_str}") raise @@ -91,26 +101,31 @@ def read_fasta_sequences(file_path): return sequences -def translate_sequences(nt_sequences): - aa_sequences = [] - for seq in nt_sequences: - if len(seq) % 3 != 0: - raise ValueError(f"The sequence '{seq}' is not a multiple of 3.") - aa_seq = str(Seq(seq).translate()) - if "*" in aa_seq: - raise ValueError(f"The sequence '{seq}' contains a stop codon.") - aa_sequences.append(aa_seq) - return aa_sequences +def translate_codon(codon): + """Translate a codon to an amino acid.""" + if codon in TOKEN_TRANSLATIONS: + return TOKEN_TRANSLATIONS[codon] + else: + return str(Seq(codon).translate()) def translate_sequence(nt_sequence): - return translate_sequences([nt_sequence])[0] + if len(nt_sequence) % 3 != 0: + raise ValueError(f"The sequence '{nt_sequence}' is not a multiple of 3.") + aa_seq = "".join(translate_codon(nt_sequence[i: i + 3]) for i in range(0, len(nt_sequence), 3)) + if "*" in aa_seq: + raise ValueError(f"The sequence '{nt_sequence}' contains a stop codon.") + return aa_seq + + +def translate_sequences(nt_sequences): + return [translate_sequence(seq) for seq in nt_sequences] def aa_index_of_codon(codon): """Return the index of the amino acid encoded by a codon.""" aa = translate_sequence(codon) - return AA_STR_SORTED.index(aa) + return TOKEN_STR_SORTED.index(aa) def generic_mutation_frequency(ambig_symb, parent, child): @@ -160,12 +175,12 @@ def pcp_criteria_check(parent, child, max_mut_freq=0.3): def generate_codon_aa_indicator_matrix(): """Generate a matrix that maps codons (rows) to amino acids (columns).""" - matrix = np.zeros((len(CODONS), len(AA_STR_SORTED))) + matrix = np.zeros((len(CODONS), len(TOKEN_STR_SORTED))) for i, codon in enumerate(CODONS): try: aa = translate_sequences([codon])[0] - aa_idx = AA_STR_SORTED.index(aa) + aa_idx = TOKEN_STR_SORTED.index(aa) matrix[i, aa_idx] = 1 except ValueError: # Handle STOP codon pass diff --git a/tests/test_sequences.py b/tests/test_sequences.py index 8866214e..376154f6 100644 --- a/tests/test_sequences.py +++ b/tests/test_sequences.py @@ -5,6 +5,7 @@ from Bio.Data import CodonTable from netam.sequences import ( AA_STR_SORTED, + TOKEN_STR_SORTED, CODONS, CODON_AA_INDICATOR_MATRIX, aa_onehot_tensor_of_str, @@ -14,6 +15,12 @@ ) +def test_token_order(): + # If we always add additional tokens to the end, then converting to indices + # will not be affected when we have a proper aa string. + assert TOKEN_STR_SORTED[:len(AA_STR_SORTED)] == AA_STR_SORTED + + def test_nucleotide_indices_of_codon(): assert nt_idx_array_of_str("AAA").tolist() == [0, 0, 0] assert nt_idx_array_of_str("TAC").tolist() == [3, 0, 1]