Fix for NaNs when training DASM with ambiguous sequences (#93)
A one-line fix for DASM training: apply a mask that should have been applied from the very beginning but previously didn't matter much, since the data contained no ambiguities.

Also adds parallelized neutral model application, moved to the CPU. This makes running some notebooks much faster and also speeds up setup for training Snakemake runs.
willdumm authored Dec 10, 2024
1 parent b59adbd commit 22c8873
Showing 4 changed files with 100 additions and 18 deletions.
103 changes: 92 additions & 11 deletions netam/common.py
@@ -5,7 +5,7 @@
 import subprocess
 from tqdm import tqdm
 from functools import wraps
-from itertools import islice
+from itertools import islice, repeat

 import numpy as np
 import torch
@@ -397,39 +397,46 @@ def chunked(iterable, n):
         yield chunk


-def chunk_method(default_chunk_size=2048, progress_bar_name=None):
-    """Decorator to chunk the input to a method.
+def chunk_function(
+    first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None
+):
+    """Decorator to chunk the input to a function.

     Expects that all positional arguments are iterables of the same length,
     and that outputs are tuples of tensors whose first dimension
     corresponds to the first dimension of the input iterables.

-    If method returns just one item, it must not be a tuple.
+    If function returns just one item, it must not be a tuple.

     Chunking is done along the first dimension of all inputs.

     Args:
-        default_chunk_size: The default chunk size. The decorated method can
+        default_chunk_size: The default chunk size. The decorated function can
             also automatically accept a `default_chunk_size` keyword argument.
         progress_bar_name: The name of the progress bar. If None, no progress bar is shown.
     """

-    def decorator(method):
-        @wraps(method)
-        def wrapper(self, *args, **kwargs):
+    def decorator(function):
+        @wraps(function)
+        def wrapper(*args, **kwargs):
             if "chunk_size" in kwargs:
                 chunk_size = kwargs.pop("chunk_size")
             else:
                 chunk_size = default_chunk_size
+            pre_chunk_args = args[:first_chunkable_idx]
+            chunkable_args = args[first_chunkable_idx:]

             results = []
             if progress_bar_name is None:
                 progargs = {"disable": True}
             else:
                 progargs = {"desc": progress_bar_name}
-            bar = tqdm(total=len(args[0]), delay=2.0, **progargs)
-            for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
+            bar = tqdm(total=len(chunkable_args[0]), delay=2.0, **progargs)
+            for chunked_args in zip(
+                *(chunked(arg, chunk_size) for arg in chunkable_args)
+            ):
                 bar.update(len(chunked_args[0]))
-                results.append(method(self, *chunked_args, **kwargs))
+                results.append(function(*pre_chunk_args, *chunked_args, **kwargs))
             if isinstance(results[0], tuple):
                 return tuple(torch.cat(tensors) for tensors in zip(*results))
             else:
@@ -438,3 +445,77 @@ def wrapper(self, *args, **kwargs):
         return wrapper

     return decorator
+
+
+def _apply_args_and_kwargs(func, pre_chunk_args, chunked_args, kwargs):
+    return func(*pre_chunk_args, *chunked_args, **kwargs)
+
+
+def parallelize_function(
+    function,
+    first_chunkable_idx=0,
+    max_workers=10,
+    min_chunk_size=1000,
+):
+    """Function to parallelize another function's application with multiprocessing.
+
+    This is intentionally not designed to be used with decorator syntax because it should only
+    be used when the function it is applied to will be run on the CPU.
+
+    Expects that all positional arguments are iterables of the same length,
+    and that outputs are tuples of tensors whose first dimension
+    corresponds to the first dimension of the input iterables.
+
+    If function returns just one item, it must not be a tuple.
+
+    Division between processes is done along the first dimension of all inputs.
+    The wrapped function will be endowed with the parallelize keyword
+    argument, so that parallelization can be turned on or off at each invocation.
+
+    Args:
+        function: The function to be parallelized.
+        first_chunkable_idx: The index of the first argument to be chunked.
+            All positional arguments after this index will be chunked.
+        max_workers: The maximum number of processes to use.
+        min_chunk_size: The minimum chunk size for input data. The number of
+            workers is adjusted to ensure that the chunk size is at least this.
+    """
+
+    max_worker_count = min(mp.cpu_count() // 2, max_workers)
+    if max_worker_count <= 1:
+        return function
+
+    @wraps(function)
+    def wrapper(*args, **kwargs):
+        if len(args) <= first_chunkable_idx:
+            raise ValueError(
+                f"Function {function.__name__} cannot be parallelized without chunkable arguments"
+            )
+        pre_chunk_args = args[:first_chunkable_idx]
+        chunkable_args = args[first_chunkable_idx:]
+        min_worker_count = len(chunkable_args[0]) // min_chunk_size
+
+        worker_count = min(min_worker_count, max_worker_count)
+        if worker_count <= 1:
+            return function(*args, **kwargs)
+
+        chunk_size = (len(chunkable_args[0]) // worker_count) + 1
+        chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in chunkable_args)))
+        with mp.Pool(worker_count) as pool:
+            results = pool.starmap(
+                _apply_args_and_kwargs,
+                list(
+                    zip(
+                        repeat(function),
+                        repeat(pre_chunk_args),
+                        chunked_args,
+                        repeat(kwargs),
+                    )
+                ),
+            )
+        if isinstance(results[0], tuple):
+            return tuple(torch.cat(tensors) for tensors in zip(*results))
+        else:
+            return torch.cat(results)
+
+    return wrapper
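For orientation, here is a minimal usage sketch of the two new helpers. The toy functions (sequence_lengths, gc_counts) and the data are hypothetical stand-ins for illustration; only chunk_function, parallelize_function, and their keyword arguments come from this commit.

import torch
from netam.common import chunk_function, parallelize_function


# Positional args after index 1 are chunked; each call sees one chunk (a tuple
# of items) and the decorator concatenates the returned tensors along dim 0.
@chunk_function(first_chunkable_idx=1, progress_bar_name="Toy scoring")
def sequence_lengths(scale, sequences):
    return scale * torch.tensor([float(len(s)) for s in sequences])


def gc_counts(sequences):
    return torch.tensor([s.count("G") + s.count("C") for s in sequences])


if __name__ == "__main__":
    seqs = ["ACGT" * (i % 5 + 1) for i in range(10_000)]

    # chunk_size is accepted automatically by the decorated function.
    lengths = sequence_lengths(2.0, seqs, chunk_size=1024)

    # parallelize_function wraps a CPU-bound callable at call time; it falls
    # back to a plain call when the input is small or few cores are available.
    counts = parallelize_function(gc_counts)(seqs)
    assert lengths.shape == counts.shape == (10_000,)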
4 changes: 2 additions & 2 deletions netam/dasm.py
@@ -180,6 +180,7 @@ def loss_of_batch(self, batch):
         # mut_pos_loss, and we mask out sites with no substitution for the CSP
         # loss. The latter class of sites also eliminates sites that have Xs in
         # the parent or child (see sequences.aa_subs_indicator_tensor_of).
+
         predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)

         # After zapping out the diagonal, we can effectively sum over the
@@ -195,11 +196,10 @@ def loss_of_batch(self, batch):
         # logit space, so we are set up for using the cross entropy loss.
         # However we have to mask out the sites that are not substituted, i.e.
         # the sites for which aa_subs_indicator is 0.
-        subs_mask = aa_subs_indicator == 1
+        subs_mask = (aa_subs_indicator == 1) & mask
         csp_pred = predictions[subs_mask]
         csp_targets = aa_children_idxs[subs_mask]
         csp_loss = self.xent_loss(csp_pred, csp_targets)
-
         return torch.stack([subs_pos_loss, csp_loss])

     def build_selection_matrix_from_parent(self, parent: str):
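The one-line dasm.py change intersects the substitution indicator with the validity mask before indexing into the predictions. A toy illustration of the failure mode follows; shapes and values are invented, and exactly how ambiguity produces bad logits in the real model is only suggested here.

import torch

xent = torch.nn.CrossEntropyLoss()

# One toy sequence, four sites, three amino-acid classes.
predictions = torch.tensor([[[0.5, 0.1, 0.2],
                             [0.3, 0.9, 0.4],
                             [float("nan"), float("nan"), float("nan")],  # ambiguous site
                             [0.2, 0.2, 0.2]]])
aa_subs_indicator = torch.tensor([[0.0, 1.0, 1.0, 0.0]])
aa_children_idxs = torch.tensor([[0, 1, 0, 2]])
mask = torch.tensor([[True, True, False, True]])  # False at the ambiguous site

# Previous behavior: the invalid logits at the masked site reach the loss.
old_mask = aa_subs_indicator == 1
print(xent(predictions[old_mask], aa_children_idxs[old_mask]))  # tensor(nan)

# Fixed behavior: masked sites are excluded, so the loss stays finite.
subs_mask = (aa_subs_indicator == 1) & mask
print(xent(predictions[subs_mask], aa_children_idxs[subs_mask]))  # finite value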
7 changes: 4 additions & 3 deletions netam/framework.py
@@ -26,6 +26,7 @@
     BIG,
     VRC01_NT_SEQ,
     encode_sequences,
+    parallelize_function,
 )
 from netam import models
 import netam.molevol as molevol
@@ -253,13 +254,13 @@ def __init__(self, encoder, model, training_hyperparameters={}):
         self.model = model
         self.training_hyperparameters = training_hyperparameters

-    def __call__(self, sequences):
+    def __call__(self, sequences, **kwargs):
         """Evaluate the model on a list of sequences."""
         if isinstance(sequences, str):
             raise ValueError(
                 "Expected a list of sequences for call on crepe, but got a single string instead."
             )
-        return self.model.evaluate_sequences(sequences, encoder=self.encoder)
+        return self.model.evaluate_sequences(sequences, encoder=self.encoder, **kwargs)

     @property
     def device(self):
@@ -338,7 +339,7 @@ def crepe_exists(prefix):

 def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     """Model outputs trimmed to the length of the parent sequences."""
-    rates, csp_logits = crepe(parents)
+    rates, csp_logits = parallelize_function(crepe)(parents)
     rates = rates.cpu().detach()
     csps = torch.softmax(csp_logits, dim=-1).cpu().detach()
     trimmed_rates = [rates[i, : len(parent)] for i, parent in enumerate(parents)]
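A hedged sketch of how the parallelized neutral-model application is exercised; load_crepe, the model path, and the exact return structure are assumptions about the surrounding netam API rather than part of this diff.

from netam.framework import load_crepe, trimmed_shm_model_outputs_of_crepe

# Hypothetical prefix pointing at a saved neutral (SHM) crepe on disk.
crepe = load_crepe("path/to/neutral_model")
parents = ["CAGGTGCAGCTGGTGGAG", "GAGGTGCAGCTGTTGGAG"]

# crepe(parents) is now routed through parallelize_function(crepe), so large
# parent lists are split across CPU worker processes before being trimmed.
rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, parents)
assert len(rates) == len(parents)  # assuming per-parent trimmed outputs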
4 changes: 2 additions & 2 deletions netam/models.py
Expand Up @@ -17,7 +17,7 @@
     generate_kmers,
     aa_mask_tensor_of,
     encode_sequences,
-    chunk_method,
+    chunk_function,
 )

 warnings.filterwarnings(
@@ -65,7 +65,7 @@ def unfreeze(self):
         for param in self.parameters():
             param.requires_grad = True

-    @chunk_method(progress_bar_name="Evaluating model")
+    @chunk_function(first_chunkable_idx=1, progress_bar_name="Evaluating model")
     def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
         if encoder is None:
             raise ValueError("An encoder must be provided.")
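Because Crepe.__call__ now forwards keyword arguments and evaluate_sequences is chunked with first_chunkable_idx=1 (so self is no longer treated as a chunkable argument), the chunk size can be tuned per call. A small sketch, assuming crepe is an already-loaded neutral-model crepe that returns rates and CSP logits:

sequences = ["CAGGTGCAGCTGGTGGAG"] * 5000

# chunk_size is popped by the chunk_function wrapper around evaluate_sequences;
# forwarding it through __call__ is what the new **kwargs plumbing enables.
rates, csp_logits = crepe(sequences, chunk_size=512)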
