Paired heavy-light modeling (#92)
Enables training on paired heavy/light sequences. A separator token `^` is added after heavy and before light chain sequences. For now, only heavy chain sequences can be used for validation.

---------

Co-authored-by: Will Dumm <[email protected]>
matsen and willdumm authored Jan 2, 2025
1 parent 685a84e commit 35c3efa
Showing 16 changed files with 240 additions and 63 deletions.
Binary file not shown.
34 changes: 23 additions & 11 deletions netam/common.py
@@ -13,15 +13,16 @@
 from torch import nn, Tensor
 import multiprocessing as mp
 
-from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence
+from netam.sequences import (
+    iter_codons,
+    apply_aa_mask_to_nt_sequence,
+    RESERVED_TOKEN_TRANSLATIONS,
+    BASES,
+    AA_TOKEN_STR_SORTED,
+)
 
 BIG = 1e9
 SMALL_PROB = 1e-6
-BASES = ["A", "C", "G", "T"]
-BASES_AND_N_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}
-AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"
-AA_STR_SORTED_AMBIG = AA_STR_SORTED + "X"
-MAX_AMBIG_AA_IDX = len(AA_STR_SORTED_AMBIG) - 1
 
 # I needed some sequence to use to normalize the rate of mutation in the SHM model.
 # So, I chose perhaps the most famous antibody sequence, VRC01:
@@ -65,7 +66,7 @@ def aa_idx_tensor_of_str_ambig(aa_str):
     character."""
     try:
         return torch.tensor(
-            [AA_STR_SORTED_AMBIG.index(aa) for aa in aa_str], dtype=torch.int
+            [AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
         )
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
@@ -88,17 +89,28 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
     return mask
 
 
+def _consider_codon(codon):
+    """Return False if codon should be masked, True otherwise."""
+    if "N" in codon:
+        return False
+    elif codon in RESERVED_TOKEN_TRANSLATIONS:
+        return False
+    else:
+        return True
+
+
 def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
     """Return a mask tensor indicating codons which contain at least one N.
 
     Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
-    the "and" mask will be computed for all sequences
+    the "and" mask will be computed for all sequences. Codons containing marker tokens
+    are also masked.
     """
     if aa_length is None:
         aa_length = len(nt_parent) // 3
     sequences = (nt_parent,) + other_nt_seqs
     mask = [
-        all("N" not in codon for codon in codons)
+        all(_consider_codon(codon) for codon in codons)
         for codons in zip(*(iter_codons(sequence) for sequence in sequences))
     ]
     if len(mask) < aa_length:
@@ -114,7 +126,7 @@ def aa_strs_from_idx_tensor(idx_tensor):
     Args:
         idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing
-            indices into AA_STR_SORTED_AMBIG.
+            indices into AA_TOKEN_STR_SORTED.
 
     Returns:
         List[str]: A list of amino acid strings with trailing 'X's removed.
@@ -123,7 +135,7 @@ def aa_strs_from_idx_tensor(idx_tensor):
 
     aa_str_list = []
     for row in idx_tensor:
-        aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist())
+        aa_str = "".join(AA_TOKEN_STR_SORTED[idx] for idx in row.tolist())
         aa_str_list.append(aa_str.rstrip("X"))
 
     return aa_str_list
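To make the new masking rule concrete, here is a minimal runnable sketch of how `_consider_codon` decides which codons survive the mask. The reserved-token table below is a stand-in assumption for the real `RESERVED_TOKEN_TRANSLATIONS` defined in netam.sequences.

# Minimal sketch, not the netam implementation.
RESERVED_TOKEN_TRANSLATIONS = {"^^^": "^"}  # assumption: separator codon -> token


def _consider_codon(codon):
    """Return False if codon should be masked, True otherwise."""
    if "N" in codon:
        return False
    elif codon in RESERVED_TOKEN_TRANSLATIONS:
        return False
    else:
        return True


# A clean codon, an ambiguous codon, the separator codon, and a light-chain codon:
nt_seq = "ATGANA^^^TTG"
codons = [nt_seq[i : i + 3] for i in range(0, len(nt_seq), 3)]
print([_consider_codon(c) for c in codons])  # [True, False, False, True]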
5 changes: 4 additions & 1 deletion netam/dasm.py
@@ -139,7 +139,10 @@ def prediction_pair_of_batch(self, batch):
             raise ValueError(
                 f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
             )
-        log_selection_factors = self.model(aa_parents_idxs, mask)
+        # We need the model to see special tokens here. For every other purpose
+        # they are masked out.
+        keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs)
+        log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
         return log_neutral_aa_probs, log_selection_factors
 
     def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
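The `keep_token_mask` line above is just a boolean OR of two masks. A minimal sketch, with `token_mask` standing in for the result of `sequences.token_mask_of_aa_idxs` (not shown in this diff):

import torch

# The ordinary mask hides ambiguous sites; positions holding special tokens
# are OR-ed back in so the model still sees them.
mask = torch.tensor([True, True, False, False])
token_mask = torch.tensor([False, False, True, False])  # stand-in value
keep_token_mask = mask | token_mask
print(keep_token_mask)  # tensor([ True,  True,  True, False])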
4 changes: 3 additions & 1 deletion netam/dnsm.py
@@ -163,7 +163,9 @@ def build_selection_matrix_from_parent(self, parent: str):
         """
         parent = sequences.translate_sequence(parent)
         selection_factors = self.model.selection_factors_of_aa_str(parent)
-        selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
+        selection_matrix = torch.zeros(
+            (len(selection_factors), sequences.MAX_AA_TOKEN_IDX + 1), dtype=torch.float
+        )
         # Every "off-diagonal" entry of the selection matrix is set to the selection
         # factor, where "diagonal" means keeping the same amino acid.
         selection_matrix[:, :] = selection_factors[:, None]
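To illustrate the resizing above: the selection matrix now has one column per token rather than one per amino acid. A sketch assuming `MAX_AA_TOKEN_IDX` is 21 (20 amino acids plus `X` and the `^` separator; the real constant lives in netam.sequences):

import torch

MAX_AA_TOKEN_IDX = 21  # assumption for this sketch
selection_factors = torch.rand(5)  # one factor per parent site
selection_matrix = torch.zeros(
    (len(selection_factors), MAX_AA_TOKEN_IDX + 1), dtype=torch.float
)
# Broadcast each site's factor across its row, as in the diff above.
selection_matrix[:, :] = selection_factors[:, None]
print(selection_matrix.shape)  # torch.Size([5, 22])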
27 changes: 18 additions & 9 deletions netam/dxsm.py
@@ -15,7 +15,6 @@
 from tqdm import tqdm
 
 from netam.common import (
-    MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
     stack_heterogeneous,
     codon_mask_tensor_of,
@@ -28,6 +27,8 @@
     translate_sequences,
     apply_aa_mask_to_nt_sequence,
     nt_mutation_frequency,
+    MAX_AA_TOKEN_IDX,
+    RESERVED_TOKEN_REGEX,
 )
 
 
@@ -43,8 +44,12 @@ def __init__(
         branch_lengths: torch.Tensor,
         multihit_model=None,
     ):
-        self.nt_parents = nt_parents
-        self.nt_children = nt_children
+        self.nt_parents = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True)
+        # We will replace reserved tokens with Ns but use the unmodified
+        # originals for translation and mask creation.
+        self.nt_children = nt_children.str.replace(
+            RESERVED_TOKEN_REGEX, "N", regex=True
+        )
         self.nt_ratess = nt_ratess
         self.nt_cspss = nt_cspss
         self.multihit_model = copy.deepcopy(multihit_model)
@@ -56,14 +61,16 @@ def __init__(
         assert len(self.nt_parents) == len(self.nt_children)
         pcp_count = len(self.nt_parents)
 
-        aa_parents = translate_sequences(self.nt_parents)
-        aa_children = translate_sequences(self.nt_children)
+        # Important to use the unmodified versions of nt_parents and
+        # nt_children so they still contain special tokens.
+        aa_parents = translate_sequences(nt_parents)
+        aa_children = translate_sequences(nt_children)
         self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
         # We have sequences of varying length, so we start with all tensors set
         # to the ambiguous amino acid, and then will fill in the actual values
         # below.
         self.aa_parents_idxss = torch.full(
-            (pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX
+            (pcp_count, self.max_aa_seq_len), MAX_AA_TOKEN_IDX
         )
         self.aa_children_idxss = self.aa_parents_idxss.clone()
         self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len))
@@ -90,7 +97,7 @@ def __init__(
         )
 
         assert torch.all(self.masks.sum(dim=1) > 0)
-        assert torch.max(self.aa_parents_idxss) <= MAX_AMBIG_AA_IDX
+        assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX
 
         self._branch_lengths = branch_lengths
         self.update_neutral_probs()
@@ -296,9 +303,11 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
-        # # The following can be used when one wants a better traceback.
+        # The following can be used when one wants a better traceback.
         # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
-        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        # return burrito.serial_find_optimal_branch_lengths(
+        #     dataset, **optimization_kwargs
+        # )
         our_optimize_branch_length = partial(
             worker_optimize_branch_length,
             self.__class__,
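The constructor change above scrubs reserved tokens out of the nucleotide sequences handed to the neutral model, while the unmodified originals are kept for translation and mask creation. A sketch with a hypothetical pattern, since `RESERVED_TOKEN_REGEX` itself is defined in netam.sequences:

import pandas as pd

RESERVED_TOKEN_REGEX = r"\^"  # assumption: matches each reserved-token character
nt_parents = pd.Series(["ATGAAA^^^TTGCAA"])
scrubbed = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True)
print(scrubbed[0])  # ATGAAANNNTTGCAA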
81 changes: 75 additions & 6 deletions netam/framework.py
@@ -22,12 +22,12 @@
     optimizer_of_name,
     tensor_to_np_if_needed,
     BASES,
-    BASES_AND_N_TO_INDEX,
     BIG,
     VRC01_NT_SEQ,
     encode_sequences,
     parallelize_function,
 )
+from netam.sequences import BASES_AND_N_TO_INDEX
 from netam import models
 import netam.molevol as molevol
 
@@ -352,31 +352,100 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     return trimmed_rates, trimmed_csps
 
 
+def join_chains(pcp_df):
+    """Join the parent and child chains in the pcp_df.
+
+    Make a parent column that is the parent_h + "^^^" + parent_l, and same for child.
+    If parent_h and parent_l are not present, then we assume that the parent is the
+    heavy chain. If only one of parent_h or parent_l is present, then we place the ^^^
+    padding to the right of heavy, or to the left of light.
+    """
+    cols = pcp_df.columns
+    # Look for heavy chain
+    if "parent_h" in cols:
+        assert "child_h" in cols, "child_h column missing!"
+        assert "v_gene_h" in cols, "v_gene_h column missing!"
+    elif "parent" in cols:
+        assert "child" in cols, "child column missing!"
+        assert "v_gene" in cols, "v_gene column missing!"
+        pcp_df["parent_h"] = pcp_df["parent"]
+        pcp_df["child_h"] = pcp_df["child"]
+        pcp_df["v_gene_h"] = pcp_df["v_gene"]
+    else:
+        pcp_df["parent_h"] = ""
+        pcp_df["child_h"] = ""
+        pcp_df["v_gene_h"] = "N/A"
+    # Look for light chain
+    if "parent_l" in cols:
+        assert "child_l" in cols, "child_l column missing!"
+        assert "v_gene_l" in cols, "v_gene_l column missing!"
+    else:
+        pcp_df["parent_l"] = ""
+        pcp_df["child_l"] = ""
+        pcp_df["v_gene_l"] = "N/A"
+
+    if (pcp_df["parent_h"].str.len() + pcp_df["parent_l"].str.len()).min() < 3:
+        raise ValueError("At least one PCP has fewer than three nucleotides.")
+
+    pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
+    pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]
+
+    pcp_df.drop(
+        columns=["parent_h", "parent_l", "child_h", "child_l", "v_gene"],
+        inplace=True,
+        errors="ignore",
+    )
+    return pcp_df
+
+
 def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
     """Load a PCP dataframe from a gzipped CSV file.
 
     `orig_pcp_idx` is the index column from the original file, even if we subset by
     sampling or by choosing V families.
+
+    We join the heavy and light chain sequences into a single sequence starting with
+    the heavy chain, using a `^^^` separator. If only the heavy or light chain
+    sequence is present, the separator is added to the appropriate side of the
+    available sequence.
     """
     pcp_df = (
         pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0)
         .reset_index()
         .rename(columns={"index": "orig_pcp_idx"})
     )
-    pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0]
+    pcp_df = join_chains(pcp_df)
+
+    pcp_df["v_family_h"] = pcp_df["v_gene_h"].str.split("-").str[0]
+    pcp_df["v_family_l"] = pcp_df["v_gene_l"].str.split("-").str[0]
     if chosen_v_families is not None:
         chosen_v_families = set(chosen_v_families)
-        pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)]
+        pcp_df = pcp_df[
+            pcp_df["v_family_h"].isin(chosen_v_families)
+            & pcp_df["v_family_l"].isin(chosen_v_families)
+        ]
     if sample_count is not None:
         pcp_df = pcp_df.sample(sample_count)
     pcp_df.reset_index(drop=True, inplace=True)
     return pcp_df
 
 
 def add_shm_model_outputs_to_pcp_df(pcp_df, crepe):
-    rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
-    pcp_df["nt_rates"] = rates
-    pcp_df["nt_csps"] = csps
+    # Split parent heavy and light chains to apply neutral model separately
+    split_parents = pcp_df["parent"].str.split(pat="^^^", expand=True, regex=False)
+    # To keep prediction aligned to joined h/l sequence, pad parent
+    h_parents = split_parents[0] + "NNN"
+    l_parents = split_parents[1]
+
+    h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents)
+    l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents)
+    # Join predictions
+    pcp_df["nt_rates"] = [
+        torch.cat([h_rate, l_rate], dim=0) for h_rate, l_rate in zip(h_rates, l_rates)
+    ]
+    pcp_df["nt_csps"] = [
+        torch.cat([h_csp, l_csp], dim=0) for h_csp, l_csp in zip(h_csps, l_csps)
+    ]
     return pcp_df
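A usage sketch for the new `join_chains`, using hypothetical toy sequences and gene names:

import pandas as pd

from netam.framework import join_chains

pcp_df = pd.DataFrame(
    {
        "parent_h": ["ATGAAA"],
        "child_h": ["ATGAAG"],
        "v_gene_h": ["IGHV1-2"],
        "parent_l": ["TTGCAA"],
        "child_l": ["TTGCAA"],
        "v_gene_l": ["IGKV1-5"],
    }
)
joined = join_chains(pcp_df)
print(joined["parent"][0])  # ATGAAA^^^TTGCAA
print(joined["child"][0])   # ATGAAG^^^TTGCAA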
4 changes: 2 additions & 2 deletions netam/models.py
@@ -10,8 +10,8 @@
 from torch import Tensor
 
 from netam.hit_class import apply_multihit_correction
+from netam.sequences import MAX_AA_TOKEN_IDX
 from netam.common import (
-    MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
     PositionalEncoding,
     generate_kmers,
@@ -622,7 +622,7 @@ def __init__(
         self.nhead = nhead
         self.dim_feedforward = dim_feedforward
         self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
-        self.amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model)
+        self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
         self.encoder_layer = nn.TransformerEncoderLayer(
             d_model=self.d_model,
             nhead=nhead,
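The widened embedding table is what lets separator-token indices through without an out-of-range error. A sketch, again assuming `MAX_AA_TOKEN_IDX` is 21:

import torch
from torch import nn

MAX_AA_TOKEN_IDX = 21  # assumption; the real constant lives in netam.sequences
amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, 4)  # d_model = 4
# Index 21 (e.g. the separator) would raise with the old 21-row table.
idxs = torch.tensor([0, 20, 21])  # an amino acid, "X", and the separator
print(amino_acid_embedding(idxs).shape)  # torch.Size([3, 4])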
4 changes: 2 additions & 2 deletions netam/molevol.py
@@ -9,7 +9,7 @@
 import torch
 from torch import Tensor, optim
 
-from netam.sequences import CODON_AA_INDICATOR_MATRIX
+from netam.sequences import CODON_AA_INDICATOR_MATRIX, MAX_AA_TOKEN_IDX
 
 import netam.sequences as sequences
 
@@ -444,7 +444,7 @@ def mutsel_log_pcp_probability_of(
     """
 
     assert len(parent) % 3 == 0
-    assert sel_matrix.shape == (len(parent) // 3, 20)
+    assert sel_matrix.shape == (len(parent) // 3, MAX_AA_TOKEN_IDX + 1)
 
     parent_idxs = sequences.nt_idx_tensor_of_str(parent)
     child_idxs = sequences.nt_idx_tensor_of_str(child)
(Diffs for the remaining changed files are not shown.)
