Commit dda468c

a sketch
WIP

WIP

next: fix nonzero subs prob for ^

WIP

WIP

remove joined_mode, make everything be joined_mode always

EOD commit: working on recovering OE plotting with changes

the final solution?

partial cleanup

some cleanup

more cleanup
matsen authored and willdumm committed Dec 19, 2024
1 parent 685a84e commit dda468c
Showing 10 changed files with 189 additions and 53 deletions.
31 changes: 22 additions & 9 deletions netam/common.py
@@ -13,15 +13,17 @@
 from torch import nn, Tensor
 import multiprocessing as mp
 
-from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence
+from netam.sequences import (
+    iter_codons,
+    apply_aa_mask_to_nt_sequence,
+    TOKEN_TRANSLATIONS,
+    BASES,
+    BASES_AND_N_TO_INDEX,
+    TOKEN_STR_SORTED
+)
 
 BIG = 1e9
 SMALL_PROB = 1e-6
-BASES = ["A", "C", "G", "T"]
-BASES_AND_N_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}
-AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"
-AA_STR_SORTED_AMBIG = AA_STR_SORTED + "X"
-MAX_AMBIG_AA_IDX = len(AA_STR_SORTED_AMBIG) - 1
 
 # I needed some sequence to use to normalize the rate of mutation in the SHM model.
 # So, I chose perhaps the most famous antibody sequence, VRC01:
@@ -65,7 +67,7 @@ def aa_idx_tensor_of_str_ambig(aa_str):
     character."""
     try:
         return torch.tensor(
-            [AA_STR_SORTED_AMBIG.index(aa) for aa in aa_str], dtype=torch.int
+            [TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
         )
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
@@ -88,17 +90,28 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
     return mask
 
 
+def _consider_codon(codon):
+    """Return False if codon should be masked, True otherwise."""
+    if "N" in codon:
+        return False
+    elif codon in TOKEN_TRANSLATIONS:
+        return False
+    else:
+        return True
+
+
 def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
     """Return a mask tensor indicating codons which contain at least one N.
 
     Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
-    the "and" mask will be computed for all sequences
+    the "and" mask will be computed for all sequences.
+    Codons containing marker tokens are also masked.
     """
     if aa_length is None:
         aa_length = len(nt_parent) // 3
     sequences = (nt_parent,) + other_nt_seqs
     mask = [
-        all("N" not in codon for codon in codons)
+        all(_consider_codon(codon) for codon in codons)
        for codons in zip(*(iter_codons(sequence) for sequence in sequences))
     ]
     if len(mask) < aa_length:
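
As a sanity check on the new masking rule, here is a minimal sketch in plain Python. TOKEN_TRANSLATIONS is a stand-in here; its real contents are defined in netam/sequences.py and are not part of this diff.

    TOKEN_TRANSLATIONS = {"^^^"}  # hypothetical contents, for illustration only

    def consider_codon(codon):
        """Return False if codon should be masked, True otherwise."""
        return "N" not in codon and codon not in TOKEN_TRANSLATIONS

    parent = "ATGNNTGGA^^^"  # three ordinary codons plus one separator token
    codons = [parent[i : i + 3] for i in range(0, len(parent), 3)]
    print([consider_codon(c) for c in codons])  # [True, False, True, False]
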
9 changes: 7 additions & 2 deletions netam/dasm.py
@@ -35,13 +35,17 @@ def update_neutral_probs(self):
                 multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be careful
             # with masking out these positions later. We do this below.
+            # TODO handle this some other way
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
             parent_len = len(nt_parent)
 
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
             nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
-            molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])
+            molevol.check_csps(
+                parent_idxs[nt_mask],
+                nt_csps[: len(nt_parent)][nt_mask]
+            )
 
             neutral_aa_probs = molevol.neutral_aa_probs(
                 parent_idxs.reshape(-1, 3),
@@ -139,7 +143,8 @@ def prediction_pair_of_batch(self, batch):
             raise ValueError(
                 f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
             )
-        log_selection_factors = self.model(aa_parents_idxs, mask)
+        keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs)
+        log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
         return log_neutral_aa_probs, log_selection_factors
 
     def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
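
A sketch of what the keep_token_mask composition above accomplishes: positions holding special tokens are excluded from the loss mask but must stay visible to the model. Here token_mask_of_aa_idxs is stood in by a simple threshold, purely as an assumption; its real definition in netam/sequences.py is not shown in this diff.

    import torch

    def token_mask_of_aa_idxs(aa_idxs):
        # Assumption: special tokens sit at indices above the 20 amino acids.
        return aa_idxs > 19

    aa_parents_idxs = torch.tensor([3, 21, 7])  # middle position is a token
    mask = torch.tensor([True, False, True])    # token position masked for the loss
    keep_token_mask = mask | token_mask_of_aa_idxs(aa_parents_idxs)
    print(keep_token_mask)  # tensor([True, True, True])
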
25 changes: 15 additions & 10 deletions netam/dxsm.py
@@ -15,7 +15,6 @@
 from tqdm import tqdm
 
 from netam.common import (
-    MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
     stack_heterogeneous,
     codon_mask_tensor_of,
@@ -28,6 +27,8 @@
     translate_sequences,
     apply_aa_mask_to_nt_sequence,
     nt_mutation_frequency,
+    MAX_AA_TOKEN_IDX,
+    TOKEN_REGEX,
 )
 
 
@@ -43,8 +44,9 @@ def __init__(
         branch_lengths: torch.Tensor,
         multihit_model=None,
     ):
-        self.nt_parents = nt_parents
-        self.nt_children = nt_children
+        # TODO test this replacement
+        self.nt_parents = nt_parents.str.replace(TOKEN_REGEX, "N", regex=True)
+        self.nt_children = nt_children.str.replace(TOKEN_REGEX, "N", regex=True)
         self.nt_ratess = nt_ratess
         self.nt_cspss = nt_cspss
         self.multihit_model = copy.deepcopy(multihit_model)
@@ -56,14 +58,16 @@ def __init__(
         assert len(self.nt_parents) == len(self.nt_children)
         pcp_count = len(self.nt_parents)
 
-        aa_parents = translate_sequences(self.nt_parents)
-        aa_children = translate_sequences(self.nt_children)
+        # Important to use the unmodified versions of nt_parents and
+        # nt_children so they still contain special tokens.
+        aa_parents = translate_sequences(nt_parents)
+        aa_children = translate_sequences(nt_children)
         self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
         # We have sequences of varying length, so we start with all tensors set
         # to the ambiguous amino acid, and then will fill in the actual values
         # below.
         self.aa_parents_idxss = torch.full(
-            (pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX
+            (pcp_count, self.max_aa_seq_len), MAX_AA_TOKEN_IDX
         )
         self.aa_children_idxss = self.aa_parents_idxss.clone()
         self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len))
@@ -90,7 +94,7 @@ def __init__(
         )
 
         assert torch.all(self.masks.sum(dim=1) > 0)
-        assert torch.max(self.aa_parents_idxss) <= MAX_AMBIG_AA_IDX
+        assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX
 
         self._branch_lengths = branch_lengths
         self.update_neutral_probs()
@@ -296,9 +300,10 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
-        # # The following can be used when one wants a better traceback.
-        # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
-        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        # TODO
+        # The following can be used when one wants a better traceback.
+        burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
+        return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
         our_optimize_branch_length = partial(
             worker_optimize_branch_length,
             self.__class__,
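
A sketch of the Series-level token replacement above. The TOKEN_REGEX pattern used here is an assumption; the real one lives in netam/sequences.py and is not shown in this diff.

    import pandas as pd

    TOKEN_REGEX = r"\^"  # assumption: matches each separator character

    nt_parents = pd.Series(["ACGTTT^^^CCG"])
    replaced = nt_parents.str.replace(TOKEN_REGEX, "N", regex=True)
    print(replaced[0])  # ACGTTTNNNCCG
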
75 changes: 68 additions & 7 deletions netam/framework.py
@@ -2,6 +2,7 @@
 import copy
 import os
 from time import time
+from warnings import warn
 
 import pandas as pd
 import numpy as np
@@ -352,31 +353,91 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     return trimmed_rates, trimmed_csps
 
 
+def join_chains(pcp_df):
+    """Join the parent and child chains in the pcp_df.
+
+    Make a parent column that is the parent_h + "^^^" + parent_l, and same for child.
+    If parent_h and parent_l are not present, then we assume that the parent is the heavy chain.
+    If only one of parent_h or parent_l is present, then we place the ^^^ padding to the right of
+    heavy, or to the left of light.
+    """
+    cols = pcp_df.columns
+    if "parent_h" in cols and "parent_l" in cols:
+        assert "child_h" in cols and "child_l" in cols, "child_h or child_l columns missing!"
+        pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
+        pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]
+    elif "parent_h" in cols and "parent_l" not in cols:
+        assert "child_h" in cols, "child_h column missing!"
+        pcp_df["parent"] = pcp_df["parent_h"] + "^^^"
+        pcp_df["child"] = pcp_df["child_h"] + "^^^"
+    elif "parent_h" not in cols and "parent_l" in cols:
+        if "parent" in cols:
+            warn("Both parent and parent_l columns found. Using only parent_l. "
+                 "To use parent as heavy chain, rename to parent_h.")
+        assert "child_l" in cols, "child_l column missing!"
+        pcp_df["parent"] = "^^^" + pcp_df["parent_l"]
+        pcp_df["child"] = "^^^" + pcp_df["child_l"]
+    elif "parent" in cols:
+        assert "child" in cols, "child column missing!"
+        # We assume that this is the heavy chain.
+        pcp_df["parent"] += "^^^"
+        pcp_df["child"] += "^^^"
+    else:
+        raise ValueError("Could not find parent and child columns.")
+    pcp_df.drop(columns=["parent_h", "parent_l", "child_h", "child_l"], inplace=True, errors="ignore")
+    return pcp_df
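
A quick usage sketch of join_chains as added above, on a toy frame carrying both chains (placeholder values, not real sequences):

    import pandas as pd

    toy = pd.DataFrame(
        {"parent_h": ["AAA"], "child_h": ["AAC"], "parent_l": ["GGG"], "child_l": ["GGC"]}
    )
    toy = join_chains(toy)
    print(toy["parent"][0], toy["child"][0])  # AAA^^^GGG AAC^^^GGC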


 def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
     """Load a PCP dataframe from a gzipped CSV file.
 
     `orig_pcp_idx` is the index column from the original file, even if we subset by
     sampling or by choosing V families.
 
+    We will join the heavy and light chain sequences into a single sequence
+    starting with the heavy chain, using a `^^^` separator. If only the heavy
+    or light chain sequence is present, this separator will be added to the
+    appropriate side of the available sequence.
     """
     pcp_df = (
         pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0)
         .reset_index()
         .rename(columns={"index": "orig_pcp_idx"})
     )
-    pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0]
-    if chosen_v_families is not None:
-        chosen_v_families = set(chosen_v_families)
-        pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)]
+    pcp_df = join_chains(pcp_df)
+    if not ("parent" in pcp_df.columns and "child" in pcp_df.columns):
+        if "parent_h" in pcp_df.columns and "parent_l" in pcp_df.columns:
+            pcp_df["parent"] = pcp_df["parent_h"]
+            pcp_df["child"] = pcp_df["child_h"]
+            pcp_df.drop(columns=["parent_h", "parent_l", "child_h", "child_l"], inplace=True, errors="ignore")
+        else:
+            raise ValueError(
+                "Could not find parent and child columns."
+            )
+
+    # TODO figure out what to do here: this is only needed for OE plotting, but
+    # the way it's set up, it will fail without a helpful message.
+    if "v_gene" in pcp_df.columns:
+        pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0]
+        if chosen_v_families is not None:
+            chosen_v_families = set(chosen_v_families)
+            pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)]
     if sample_count is not None:
         pcp_df = pcp_df.sample(sample_count)
     pcp_df.reset_index(drop=True, inplace=True)
     return pcp_df
 
 
 def add_shm_model_outputs_to_pcp_df(pcp_df, crepe):
-    rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
-    pcp_df["nt_rates"] = rates
-    pcp_df["nt_csps"] = csps
+    # TODO what happens when one of these is empty? or if there's no split?
+    split_parents = pcp_df["parent"].copy().str.split(pat="^^^", expand=True, regex=False)
+    h_parents = split_parents[0] + "NNN"
+    l_parents = split_parents[1]
+
+    h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents)
+    l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents)
+    pcp_df["nt_rates"] = [torch.cat([h_rate, l_rate], dim=0) for h_rate, l_rate in zip(h_rates, l_rates)]
+    pcp_df["nt_csps"] = [torch.cat([h_csp, l_csp], dim=0) for h_csp, l_csp in zip(h_csps, l_csps)]
     return pcp_df
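
On the TODO in add_shm_model_outputs_to_pcp_df above: a sketch of how the literal "^^^" split behaves in the edge cases the comment asks about (plain pandas, no netam needed).

    import pandas as pd

    parents = pd.Series(["AAA^^^GGG", "AAA^^^", "AAA"])
    split = parents.str.split(pat="^^^", expand=True, regex=False)
    print(split[0].tolist())  # ['AAA', 'AAA', 'AAA']
    print(split[1].tolist())  # ['GGG', '', None]

A row with no separator yields None in the second column, which would propagate into l_parents and presumably fail in the downstream crepe call; that appears to be the case the TODO is flagging.
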
4 changes: 2 additions & 2 deletions netam/models.py
@@ -10,8 +10,8 @@
 from torch import Tensor
 
 from netam.hit_class import apply_multihit_correction
+from netam.sequences import MAX_AA_TOKEN_IDX
 from netam.common import (
-    MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
     PositionalEncoding,
     generate_kmers,
@@ -622,7 +622,7 @@ def __init__(
         self.nhead = nhead
         self.dim_feedforward = dim_feedforward
         self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
-        self.amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model)
+        self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model)
         self.encoder_layer = nn.TransformerEncoderLayer(
             d_model=self.d_model,
             nhead=nhead,
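
The +1 matters because the first argument of nn.Embedding is the table size, so the largest valid index is num_embeddings - 1. A minimal check, with an assumed value for MAX_AA_TOKEN_IDX:

    import torch
    from torch import nn

    MAX_AA_TOKEN_IDX = 21  # assumed value, for illustration only
    embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, 8)
    idxs = torch.tensor([0, MAX_AA_TOKEN_IDX])  # both indices are in range
    print(embedding(idxs).shape)  # torch.Size([2, 8])
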
4 changes: 2 additions & 2 deletions netam/molevol.py
@@ -9,7 +9,7 @@
 import torch
 from torch import Tensor, optim
 
-from netam.sequences import CODON_AA_INDICATOR_MATRIX
+from netam.sequences import CODON_AA_INDICATOR_MATRIX, MAX_AA_TOKEN_IDX
 
 import netam.sequences as sequences
 
@@ -444,7 +444,7 @@ def mutsel_log_pcp_probability_of(
     """
 
     assert len(parent) % 3 == 0
-    assert sel_matrix.shape == (len(parent) // 3, 20)
+    assert sel_matrix.shape == (len(parent) // 3, MAX_AA_TOKEN_IDX + 1)
 
     parent_idxs = sequences.nt_idx_tensor_of_str(parent)
     child_idxs = sequences.nt_idx_tensor_of_str(child)
(The remaining 4 changed files in this commit are not shown here.)
