Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Dec 13, 2024
1 parent e400070 commit e8d52f6
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 30 deletions.
17 changes: 14 additions & 3 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch import nn, Tensor
import multiprocessing as mp

from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence
from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence, TOKEN_TRANSLATIONS

BIG = 1e9
SMALL_PROB = 1e-6
Expand Down Expand Up @@ -88,17 +88,28 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
return mask


def _consider_codon(codon):
    """Tell whether a codon takes part in the mask computation.

    A codon is excluded (returns False) when it contains an "N" or when it is
    one of the keys of TOKEN_TRANSLATIONS; any other codon is kept (returns
    True).
    """
    has_ambiguous_base = "N" in codon
    return not (has_ambiguous_base or codon in TOKEN_TRANSLATIONS)


def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
"""Return a mask tensor indicating codons which contain at least one N.
Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
the "and" mask will be computed for all sequences
the "and" mask will be computed for all sequences.
Codons containing marker tokens are also masked.
"""
if aa_length is None:
aa_length = len(nt_parent) // 3
sequences = (nt_parent,) + other_nt_seqs
mask = [
all("N" not in codon for codon in codons)
all(_consider_codon(codon) for codon in codons)
for codons in zip(*(iter_codons(sequence) for sequence in sequences))
]
if len(mask) < aa_length:
Expand Down
68 changes: 57 additions & 11 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import os
from time import time
from warnings import warn

import pandas as pd
import numpy as np
Expand Down Expand Up @@ -349,28 +350,49 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):

def join_chains(pcp_df):
    """Join the parent and child chains in the pcp_df.

    Builds `parent` and `child` columns in which the heavy chain comes first
    and the chains are separated by "^^^":

    * `parent_h` and `parent_l` present: parent is
      parent_h + "^^^" + parent_l, and the same for child.
    * only `parent_h` present: "^^^" is appended to the right of the heavy
      chain.
    * only `parent_l` present: "^^^" is prepended to the left of the light
      chain. If a `parent` column also exists it is overwritten and a
      warning is emitted.
    * only `parent` present: it is assumed to be the heavy chain and "^^^"
      is appended.

    The `parent_h`, `parent_l`, `child_h` and `child_l` columns are dropped
    afterwards if they exist. The dataframe is modified in place and also
    returned.

    Raises:
        ValueError: if no parent column is found, or if the child column(s)
            matching the available parent column(s) are missing.
    """
    cols = pcp_df.columns
    has_heavy = "parent_h" in cols
    has_light = "parent_l" in cols

    if has_heavy and has_light:
        _require_columns(
            cols, ("child_h", "child_l"), "child_h or child_l columns missing!"
        )
        pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
        pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]
    elif has_heavy:
        _require_columns(cols, ("child_h",), "child_h column missing!")
        pcp_df["parent"] = pcp_df["parent_h"] + "^^^"
        pcp_df["child"] = pcp_df["child_h"] + "^^^"
    elif has_light:
        if "parent" in cols:
            warn(
                "Both parent and parent_l columns found. Using only parent_l. "
                "To use parent as heavy chain, rename to parent_h.",
                stacklevel=2,
            )
        _require_columns(cols, ("child_l",), "child_l column missing!")
        pcp_df["parent"] = "^^^" + pcp_df["parent_l"]
        pcp_df["child"] = "^^^" + pcp_df["child_l"]
    elif "parent" in cols:
        _require_columns(cols, ("child",), "child column missing!")
        # We assume that this is the heavy chain.
        pcp_df["parent"] = pcp_df["parent"] + "^^^"
        pcp_df["child"] = pcp_df["child"] + "^^^"
    else:
        raise ValueError("Could not find parent and child columns.")

    # errors="ignore" because only some of the per-chain columns may exist.
    pcp_df.drop(
        columns=["parent_h", "parent_l", "child_h", "child_l"],
        inplace=True,
        errors="ignore",
    )
    return pcp_df


def _require_columns(cols, needed, message):
    """Raise ValueError(message) unless every name in `needed` is in `cols`."""
    if not all(name in cols for name in needed):
        raise ValueError(message)


def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None, joined_mode=False):
"""Load a PCP dataframe from a gzipped CSV file.
`orig_pcp_idx` is the index column from the original file, even if we subset by
sampling or by choosing V families.
If `joined_mode` is True, then we will join the heavy and light chain sequences into a single
sequence starting with the heavy chain, using a `^^^` separator. If only heavy or light chain
sequence is present, this separator will be added to the appropriate side of the available sequence.
"""
pcp_df = (
pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0)
Expand All @@ -379,9 +401,19 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None, joine
)
if joined_mode:
pcp_df = join_chains(pcp_df)
# TODO assert that we have parent and child columns
pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0]
if not ("parent" in pcp_df.columns and "child" in pcp_df.columns):
if "parent_h" in pcp_df.columns and "parent_l" in pcp_df.columns:
pcp_df["parent"] = pcp_df["parent_h"]
pcp_df["child"] = pcp_df["child_h"]
pcp_df.drop(columns=["parent_h", "parent_l", "child_h", "child_l"], inplace=True, errors="ignore")
else:
raise ValueError(
"Could not find parent and child columns. "
"Perhaps you want to use joined_mode=True?"
)

if chosen_v_families is not None:
pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0]
chosen_v_families = set(chosen_v_families)
pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)]
if sample_count is not None:
Expand All @@ -391,11 +423,25 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None, joine


def add_shm_model_outputs_to_pcp_df(pcp_df, crepe):
    """Add `nt_rates` and `nt_csps` columns computed from the parent sequences.

    Every "^" character in the `parent` column is replaced by "N" before the
    sequences are passed, together with `crepe`, to
    `trimmed_shm_model_outputs_of_crepe`. The two values it returns are stored
    in the `nt_rates` and `nt_csps` columns. The dataframe is modified in
    place and also returned.
    """
    # regex=False: "^" is a regex anchor, and the default of `regex` differs
    # between pandas versions, so request a literal replacement explicitly.
    parents = pcp_df["parent"].str.replace("^", "N", regex=False)
    rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, parents)
    pcp_df["nt_rates"] = rates
    pcp_df["nt_csps"] = csps
    return pcp_df

# TODO we'll need to do something like this:
def _add_shm_model_outputs_to_pcp_df(pcp_df, crepe):
    """Work-in-progress sketch: compute SHM model outputs per chain.

    Splits each `parent` sequence on the literal "^^^" separator, runs
    `trimmed_shm_model_outputs_of_crepe` separately on the first and second
    parts, and stores combined results in the `nt_rates` and `nt_csps`
    columns. The dataframe is modified in place and also returned.
    """
    # TODO I can't think of a better generic way to do this..
    # regex=False: a multi-character pattern is otherwise interpreted as a
    # regular expression, and "^^^" as a regex is three start-of-string
    # anchors, which never matches the literal separator.
    split_parents = pcp_df["parent"].str.split("^^^", expand=True, regex=False)
    h_parents = split_parents[0]
    # NOTE(review): column 1 only exists if at least one parent contains the
    # separator -- confirm callers always pass joined sequences.
    l_parents = split_parents[1]

    h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents)
    l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents)
    # TODO this is bogus: the correct way to combine the per-chain outputs
    # around the separator depends on the (not visible here) return types.
    pcp_df["nt_rates"] = h_rates + [0.0, 0.0, 0.0] + l_rates
    pcp_df["nt_csps"] = h_csps + [0.0, 0.0, 0.0] + l_csps
    return pcp_df


class Burrito(ABC):
def __init__(
Expand Down
47 changes: 31 additions & 16 deletions netam/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
for codon_list in itertools.product(["A", "C", "G", "T"], repeat=3)
]
STOP_CODONS = ["TAA", "TAG", "TGA"]
# Maps special three-character tokens that can appear in place of a codon to
# the single-character token used in the translated sequence. translate_codon
# looks codons up here before falling back to regular translation.
TOKEN_TRANSLATIONS = {
"^^^": "^",
}


def nt_idx_array_of_str(nt_str):
Expand All @@ -26,11 +29,18 @@ def nt_idx_array_of_str(nt_str):
print(f"Found an invalid nucleotide in the string: {nt_str}")
raise

def aa_idx_array_of_str(aa_str):
    """Return the indices of the amino acids in a string.

    Each character of `aa_str` is looked up in TOKEN_STR_SORTED and the
    resulting positions are returned as a numpy array.

    Raises:
        ValueError: if a character is not found in TOKEN_STR_SORTED; the
            offending string is printed before the error is re-raised.
    """
    try:
        return np.array([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
    except ValueError:
        print(f"Found an invalid amino acid in the string: {aa_str}")
        raise
Expand All @@ -48,7 +58,7 @@ def nt_idx_tensor_of_str(nt_str):
def aa_idx_tensor_of_str(aa_str):
    """Return the indices of the amino acids in a string.

    Each character of `aa_str` is looked up in TOKEN_STR_SORTED and the
    resulting positions are returned as a torch tensor.

    Raises:
        ValueError: if a character is not found in TOKEN_STR_SORTED; the
            offending string is printed before the error is re-raised.
    """
    try:
        return torch.tensor([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
    except ValueError:
        print(f"Found an invalid amino acid in the string: {aa_str}")
        raise
Expand Down Expand Up @@ -91,26 +101,31 @@ def read_fasta_sequences(file_path):
return sequences


def translate_codon(codon):
    """Translate a codon to an amino acid.

    Codons that are keys of TOKEN_TRANSLATIONS are mapped to their token
    symbol; every other codon is translated with Biopython's `Seq.translate`.
    """
    if codon in TOKEN_TRANSLATIONS:
        return TOKEN_TRANSLATIONS[codon]
    return str(Seq(codon).translate())


def translate_sequence(nt_sequence):
    """Translate a nucleotide sequence codon by codon.

    Raises:
        ValueError: if the sequence length is not a multiple of 3, or if the
            translation contains a stop codon ("*").
    """
    if len(nt_sequence) % 3 != 0:
        raise ValueError(f"The sequence '{nt_sequence}' is not a multiple of 3.")
    aa_seq = "".join(
        translate_codon(nt_sequence[i : i + 3])
        for i in range(0, len(nt_sequence), 3)
    )
    if "*" in aa_seq:
        raise ValueError(f"The sequence '{nt_sequence}' contains a stop codon.")
    return aa_seq


def translate_sequences(nt_sequences):
    """Translate each nucleotide sequence; see `translate_sequence`."""
    return [translate_sequence(seq) for seq in nt_sequences]


def aa_index_of_codon(codon):
    """Return the index of the amino acid encoded by a codon.

    The codon is translated with `translate_sequence` and the resulting
    symbol is looked up in TOKEN_STR_SORTED.
    """
    aa = translate_sequence(codon)
    return TOKEN_STR_SORTED.index(aa)


def generic_mutation_frequency(ambig_symb, parent, child):
Expand Down Expand Up @@ -160,12 +175,12 @@ def pcp_criteria_check(parent, child, max_mut_freq=0.3):
def generate_codon_aa_indicator_matrix():
"""Generate a matrix that maps codons (rows) to amino acids (columns)."""

matrix = np.zeros((len(CODONS), len(AA_STR_SORTED)))
matrix = np.zeros((len(CODONS), len(TOKEN_STR_SORTED)))

for i, codon in enumerate(CODONS):
try:
aa = translate_sequences([codon])[0]
aa_idx = AA_STR_SORTED.index(aa)
aa_idx = TOKEN_STR_SORTED.index(aa)
matrix[i, aa_idx] = 1
except ValueError: # Handle STOP codon
pass
Expand Down
7 changes: 7 additions & 0 deletions tests/test_sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from Bio.Data import CodonTable
from netam.sequences import (
AA_STR_SORTED,
TOKEN_STR_SORTED,
CODONS,
CODON_AA_INDICATOR_MATRIX,
aa_onehot_tensor_of_str,
Expand All @@ -14,6 +15,12 @@
)


def test_token_order():
    """Check that TOKEN_STR_SORTED begins with AA_STR_SORTED."""
    # Extra tokens must only ever be appended, so that index conversion of a
    # proper amino acid string gives the same result with either alphabet.
    aa_count = len(AA_STR_SORTED)
    leading_tokens = TOKEN_STR_SORTED[:aa_count]
    assert leading_tokens == AA_STR_SORTED


def test_nucleotide_indices_of_codon():
assert nt_idx_array_of_str("AAA").tolist() == [0, 0, 0]
assert nt_idx_array_of_str("TAC").tolist() == [3, 0, 1]
Expand Down

0 comments on commit e8d52f6

Please sign in to comment.