sequence handling WIP
willdumm committed Jan 17, 2025
1 parent 96025ab commit 3bfb072
Showing 7 changed files with 173 additions and 81 deletions.
82 changes: 50 additions & 32 deletions netam/common.py
@@ -89,38 +89,6 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
return mask


def _consider_codon(codon):
"""Return False if codon should be masked, True otherwise."""
if "N" in codon:
return False
elif codon in RESERVED_TOKEN_TRANSLATIONS:
return False
else:
return True


def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
"""Return a mask tensor indicating codons which contain at least one N.
Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
the "and" mask will be computed for all sequences. Codons containing marker tokens
are also masked.
"""
if aa_length is None:
aa_length = len(nt_parent) // 3
sequences = (nt_parent,) + other_nt_seqs
mask = [
all(_consider_codon(codon) for codon in codons)
for codons in zip(*(iter_codons(sequence) for sequence in sequences))
]
if len(mask) < aa_length:
mask += [False] * (aa_length - len(mask))
else:
mask = mask[:aa_length]
assert len(mask) == aa_length
return torch.tensor(mask, dtype=torch.bool)


def aa_strs_from_idx_tensor(idx_tensor):
"""Convert a tensor of amino acid indices back to a list of amino acid strings.
@@ -177,6 +145,39 @@ def aa_mask_tensor_of(*args, **kwargs):
return generic_mask_tensor_of("X", *args, **kwargs)


def _consider_codon(codon):
"""Return False if codon should be masked, True otherwise."""
if "N" in codon:
return False
elif codon in RESERVED_TOKEN_TRANSLATIONS:
return False
else:
return True


def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
"""Return a mask tensor indicating codons which contain at least one N.
Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
the "and" mask will be computed for all sequences. Codons containing marker tokens
are also masked.
"""
if aa_length is None:
aa_length = len(nt_parent) // 3
sequences = (nt_parent,) + other_nt_seqs
mask = [
all(_consider_codon(codon) for codon in codons)
for codons in zip(*(iter_codons(sequence) for sequence in sequences))
]
if len(mask) < aa_length:
mask += [False] * (aa_length - len(mask))
else:
mask = mask[:aa_length]
assert len(mask) == aa_length
return torch.tensor(mask, dtype=torch.bool)
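
# Illustrative usage sketch (toy inputs invented for this example, not taken
# from the test suite): the returned mask is True only for codons that are
# unambiguous in every provided sequence and that fall within aa_length.
#
#   >>> codon_mask_tensor_of("ATGGCN", "ATGGCA", aa_length=3)
#   tensor([ True, False, False])
#
# The first codon is clean in both sequences, the second contains an N, and
# the third lies beyond the sequence ends, so it is padded out with False.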



def informative_site_count(seq_str):
return sum(c != "N" for c in seq_str)

@@ -429,6 +430,23 @@ def chunked(iterable, n):
yield chunk


def assume_single_sequence_is_heavy_chain(function):
"""Wraps a function that takes a heavy/light sequence pair as its first argument
and returns a tuple of results.
The wrapped function will assume that if the first argument is a string, it is a
heavy chain sequence, and in that case will return only the heavy chain result."""
    @wraps(function)
    def wrapper(*args, **kwargs):
        seq = args[0]
        if isinstance(seq, str):
            seq = (seq, "")
            res = function(seq, *args[1:], **kwargs)
            return res[0]
        else:
            return function(*args, **kwargs)

    return wrapper
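
# Illustrative sketch of the intended use of the decorator above (the decorated
# function here is hypothetical and exists only for demonstration):
#
#   @assume_single_sequence_is_heavy_chain
#   def sequence_lengths(seq_pair):
#       heavy, light = seq_pair
#       return len(heavy), len(light)
#
#   sequence_lengths(("QVQL", "DIQM"))  # -> (4, 4)
#   sequence_lengths("QVQL")            # treated as ("QVQL", ""); returns 4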


def chunk_function(
first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None
):
14 changes: 5 additions & 9 deletions netam/dasm.py
@@ -139,10 +139,7 @@ def prediction_pair_of_batch(self, batch):
raise ValueError(
f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
)
# We need the model to see special tokens here. For every other purpose
# they are masked out.
keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs)
log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask)
return log_neutral_aa_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
@@ -207,18 +204,17 @@ def loss_of_batch(self, batch):
# TODO have a close look at these two functions, I'm feeling unsure about
# them
def build_selection_matrix_from_parent_aa(self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor):
"""Build a selection matrix from a parent amino acid sequence.
"""Build a selection matrix from a single parent amino acid sequence.
Inputs are expected to be as prepared in the Dataset constructor.
Values at ambiguous sites are meaningless.
"""
per_aa_selection_factors = self.model.forward(aa_parent_idxs, mask)

with torch.no_grad():
per_aa_selection_factors = self.selection_factors_of_aa_idxs(aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0).exp()

# TODO why 1.0?
return zap_predictions_along_diagonal(per_aa_selection_factors, aa_parent_idxs, fill=1.0)
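        # Sketch of what zap_predictions_along_diagonal(..., fill=1.0) is doing here
        # (informal description, not additional code): at every site i, the entry for
        # the parent (wildtype) amino acid aa_parent_idxs[i] is overwritten with 1.0,
        # so the wildtype amino acid gets a neutral selection factor while all other
        # entries keep the model's predicted factors.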

return per_aa_selection_factors

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent nucleotide sequence.
37 changes: 29 additions & 8 deletions netam/dnsm.py
@@ -127,7 +127,8 @@ def prediction_pair_of_batch(self, batch):
raise ValueError(
f"log_neutral_aa_mut_probs has non-finite values at relevant positions: {log_neutral_aa_mut_probs[mask]}"
)
log_selection_factors = self.model(aa_parents_idxs, mask)
# Right here is where model is evaluated!
log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask)
return log_neutral_aa_mut_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_mut_probs, log_selection_factors):
@@ -156,24 +157,44 @@ def loss_of_batch(self, batch):
predictions = self.predictions_of_batch(batch).masked_select(mask)
return self.bce_loss(predictions, aa_subs_indicator)


def _build_selection_matrix_from_selection_factors(self, selection_factors, aa_parent_idxs):
"""Build a selection matrix from a selection factor tensor for a single sequence.
upgrades the provided tensor containing a selection factor per site to a matrix
containing a selection factor per site and amino acid. The wildtype aa selection
factor is set ot 1, and the rest are set to the selection factor."""
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
        selection_matrix[torch.arange(len(aa_parent_idxs)), aa_parent_idxs] = 1.0
return selection_matrix
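        # Equivalent standalone sketch of the construction above (toy values, purely
        # for illustration):
        #   selection_factors = torch.tensor([2.0, 0.5, 3.0])
        #   aa_parent_idxs = torch.tensor([0, 4, 2])
        #   m = selection_factors[:, None].expand(-1, 20).clone()
        #   m[torch.arange(3), aa_parent_idxs] = 1.0
        # Row 0 is then 2.0 everywhere except column 0 (the wildtype aa), which is 1.0,
        # and similarly for rows 1 and 2.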

def build_selection_matrix_from_parent_aa(self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor):
"""Build a selection matrix from a single parent amino acid sequence.
Values at ambiguous sites are meaningless.
"""
with torch.no_grad():
per_aa_selection_factors = self.selection_factors_of_aa_idxs(aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0).exp()
        return self._build_selection_matrix_from_selection_factors(
            per_aa_selection_factors, aa_parent_idxs
        )

# TODO upgrade this to take pair of heavy and light sequences
def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a nucleotide sequence.
Values at ambiguous sites are meaningless.
"""
parent = sequences.translate_sequence(parent)
selection_factors = self.model.selection_factors_of_aa_str(parent)
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
parent = parent.replace("X", "A")
# Set "diagonal" elements to one.
parent_idxs = sequences.aa_idx_array_of_str(parent)
selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

return selection_matrix

return self._build_selection_matrix_from_selection_factors(
selection_factors, parent_idxs
)

class DNSMHyperBurrito(HyperBurrito):
# Note that we have to write the args out explicitly because we use some magic to filter kwargs in the optuna_objective method.
21 changes: 18 additions & 3 deletions netam/dxsm.py
@@ -29,14 +29,16 @@
nt_mutation_frequency,
strip_unrecognized_tokens_from_series,
dataset_inputs_of_pcp_df,
token_mask_of_aa_idxs,
MAX_AA_TOKEN_IDX,
RESERVED_TOKEN_REGEX,
AA_AMBIG_IDX,
)


class DXSMDataset(framework.BranchLengthDataset, ABC):
prefix = "dxsm"
# Not defining model_type here; instead defining it in subclasses.
# This will raise an error if we aren't using a subclass.

def __init__(
self,
@@ -271,20 +273,31 @@ class DXSMBurrito(framework.Burrito, ABC):
# Not defining model_type here; instead defining it in subclasses.
# This will raise an error if we aren't using a subclass.

def selection_factors_of_aa_idxs(self, aa_idxs, aa_mask):
"""Get the log selection factors for a batch of amino acid indices.
aa_idxs and aa_mask are expected to be as prepared in the Dataset constructor."""

# We need the model to see special tokens here. For every other purpose
# they are masked out.
        keep_token_mask = aa_mask | token_mask_of_aa_idxs(aa_idxs)
return self.model(aa_idxs, keep_token_mask)
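        # Sketch of the mask combination above with toy values (illustrative only):
        # if aa_mask marks ordinary unmasked sites and token_mask_of_aa_idxs flags
        # positions holding special tokens (e.g. a heavy/light separator), then
        #   aa_mask         = [True, True, False, True]
        #   token positions = [False, False, True, False]
        #   keep_token_mask = [True, True, True, True]
        # so the separator site is shown to the model even though it stays masked
        # for every other purpose in the pipeline.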


def _find_optimal_branch_length(
self,
parent,
child,
nt_rates,
nt_csps,
aa_mask,
aa_parents_indices,
starting_branch_length,
multihit_model,
**optimization_kwargs,
):
# TODO finish switching to build_selection_matrix_from_parent_aa
# thing...
sel_matrix = self.build_selection_matrix_from_parent(parent)
sel_matrix = self.build_selection_matrix_from_parent_aa(aa_parents_indices, aa_mask)
trimmed_aa_mask = aa_mask[: len(sel_matrix)]
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
sel_matrix[trimmed_aa_mask],
@@ -304,13 +317,14 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
optimal_lengths = []
failed_count = 0

for parent, child, nt_rates, nt_csps, aa_mask, starting_length in tqdm(
for parent, child, nt_rates, nt_csps, aa_mask, aa_parents_indices, starting_length in tqdm(
zip(
dataset.nt_parents,
dataset.nt_children,
dataset.nt_ratess,
dataset.nt_cspss,
dataset.masks,
dataset.aa_parents_idxss,
dataset.branch_lengths,
),
total=len(dataset.nt_parents),
@@ -322,6 +336,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
nt_rates[: len(parent)],
nt_csps[: len(parent), :],
aa_mask,
aa_parents_indices,
starting_length,
dataset.multihit_model,
**optimization_kwargs,
50 changes: 25 additions & 25 deletions netam/models.py
@@ -18,10 +18,13 @@
aa_mask_tensor_of,
encode_sequences,
chunk_function,
assume_single_sequence_is_heavy_chain,
)

from netam.sequences import set_wt_to_nan

from typing import Tuple

warnings.filterwarnings(
"ignore", category=UserWarning, module="torch.nn.modules.transformer"
)
@@ -580,7 +583,10 @@ def predictions_of_sequences(self, sequences, **kwargs):
def evaluate_sequences(self, sequences: list[str], **kwargs) -> Tensor:
return tuple(self.selection_factors_of_aa_str(seq) for seq in sequences)

def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
# TODO make sure that insertion of model tokens is actually done here...
# Also check if this is used anymore...
@assume_single_sequence_is_heavy_chain
def selection_factors_of_aa_str(self, aa_sequence: Tuple[str, str]) -> Tensor:
"""Do the forward method then exponentiation without gradients from an amino
acid string.
@@ -591,40 +597,34 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
Otherwise it should be a tuple, with the first element being the heavy chain and the second element being the light chain sequence.
Returns:
A numpy array of the same length as the input string representing
A tuple of numpy arrays of the same length as the input strings representing
the level of selection for each amino acid at each site.
If the input was a tuple of heavy/light chain sequences, the output will be a tuple of
numpy arrays.
"""

aa_str, added_indices = sequences.prepare_heavy_light_pair(*aa_sequence, self.hyperparameters["embedding_dim"])
aa_idxs = aa_idx_tensor_of_str_ambig(aa_str)
aa_idxs = aa_idxs.to(self.device)
# This makes the expected mask because of
# test_sequence.py::test_compare_mask_tensors.
# TODO write test that compares for all possible embedding_dim values
# the output of aa_mask_tensor_of and (codon_mask_tensor_of | token_mask_of_aa_idxs).
# (Here we expect those two to be the same)
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(self.device)

# Here we're ignoring sites containing tokens that have index greater
# than the embedding dimension. If extra tokens have been added since
# this model was defined, they are stripped out before feeding the
# sequence to the model, and the returned selection factors will be NaN
# at sites containing those unrecognized tokens.
model_valid_sites = aa_idxs < self.hyperparameters["embedding_dim"]
if self.hyperparameters["output_dim"] == 1:
result = torch.full((len(aa_str),), float("nan"), device=self.device)
else:
result = torch.full(
(len(aa_str), self.hyperparameters["output_dim"]),
float("nan"),
device=self.device,
)

with torch.no_grad():
model_out = self(
aa_idxs[model_valid_sites].unsqueeze(0),
mask[model_valid_sites].unsqueeze(0),
).squeeze(0)
result[model_valid_sites] = torch.exp(model_out)[: model_valid_sites.sum()]

return result
aa_idxs.unsqueeze(0),
mask.unsqueeze(0),
).squeeze(0).exp()

# Now split into heavy and light chain results:
sequence_mask = torch.ones(len(model_out), dtype=bool)
sequence_mask[added_indices] = False
masked_model_out = model_out[sequence_mask]
        heavy_chain = masked_model_out[: len(aa_sequence[0])]
        light_chain = masked_model_out[len(aa_sequence[0]) :]
        return heavy_chain, light_chain
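        # Sketch of the split performed above (toy shapes, for illustration only):
        # if the prepared string is heavy + separator + light with len(heavy) == 3,
        # len(light) == 2, and added_indices == [3] marking the separator, then
        # dropping added_indices from the length-6 model output leaves 5 sites; the
        # first 3 are returned as the heavy-chain selection factors and the last 2
        # as the light-chain selection factors. Callers that pass a single heavy-chain
        # string go through assume_single_sequence_is_heavy_chain and get just the
        # heavy-chain tensor back.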


class TransformerBinarySelectionModelLinAct(AbstractBinarySelectionModel):
26 changes: 24 additions & 2 deletions tests/test_dasm.py
@@ -14,10 +14,12 @@
DASMDataset,
zap_predictions_along_diagonal,
)
from netam.sequences import MAX_EMBEDDING_DIM
from netam.sequences import MAX_EMBEDDING_DIM, TOKEN_STR_SORTED


@pytest.fixture(scope="module")
# TODO verify that this loops through both pcp_dfs, even though one is named
# the same as the argument. If not, remember to fix in test_dnsm.py too.
@pytest.fixture(scope="module", params=["pcp_df", "pcp_df_paired"])
def dasm_burrito(pcp_df):
    """Fixture that returns the DASM Burrito object."""
    force_spawn()
@@ -94,3 +96,23 @@ def test_zap_diagonal(dasm_burrito):
zeroed_predictions[batch_idx, i, j]
== predictions[batch_idx, i, j]
)


# TODO this won't work until build_selection_matrix_from_parent is fixed
def test_build_selection_matrix_from_parent(dasm_burrito):
dataset_row = dasm_burrito.val_dataset[0]

parent = dasm_burrito.val_dataset.nt_parents[0]
parent_aa_idxs = dasm_burrito.val_dataset.aa_parents_idxss[0]
aa_mask = dasm_burrito.val_dataset.masks[0]
aa_parent = "".join(TOKEN_STR_SORTED[i] for i in parent)

    separator_idx = aa_parent.index("^") * 3
    heavy_chain_seq = parent[:separator_idx]
    light_chain_seq = parent[separator_idx + 3 :]

direct_val = dasm_burrito.build_selection_matrix_from_parent_aa(parent_aa_idxs, aa_mask)

    indirect_val = dasm_burrito.build_selection_matrix_from_parent((heavy_chain_seq, light_chain_seq))

assert torch.allclose(direct_val, indirect_val)