From 22c8873e91e462e83f0bd9d0182db3f996cc461c Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Tue, 10 Dec 2024 11:46:30 -0800
Subject: [PATCH] Fix for NaNs when training DASM with ambiguous sequences
 (#93)

One-line fix for DASM training, applying a mask that should have been
applied from the very beginning but previously didn't matter much, since
the data didn't contain ambiguities.

Also adds parallelized neutral model application, moved to the CPU. This
makes running some notebooks much faster, and also speeds up setup for
training Snakemake runs.
---
 netam/common.py    | 108 ++++++++++++++++++++++++++++++++++++++++-----
 netam/dasm.py      |   4 +-
 netam/framework.py |   7 +--
 netam/models.py    |   4 +-
 4 files changed, 104 insertions(+), 19 deletions(-)

diff --git a/netam/common.py b/netam/common.py
index c50200a3..00979d47 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -5,7 +5,8 @@
 import subprocess
 from tqdm import tqdm
 from functools import wraps
-from itertools import islice
+from itertools import islice, repeat
+import multiprocessing as mp
 
 import numpy as np
 import torch
@@ -397,39 +398,48 @@ def chunked(iterable, n):
         yield chunk
 
 
-def chunk_method(default_chunk_size=2048, progress_bar_name=None):
-    """Decorator to chunk the input to a method.
+def chunk_function(
+    first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None
+):
+    """Decorator to chunk the input to a function.
 
     Expects that all positional arguments are iterables of the same length,
     and that outputs are tuples of tensors whose first dimension
     corresponds to the first dimension of the input iterables.
 
-    If method returns just one item, it must not be a tuple.
+    If the function returns just one item, it must not be a tuple.
 
     Chunking is done along the first dimension of all inputs.
 
     Args:
-        default_chunk_size: The default chunk size. The decorated method can
-            also automatically accept a `default_chunk_size` keyword argument.
+        first_chunkable_idx: The index of the first chunkable positional
+            argument. Arguments before this index are passed through unchanged.
+        default_chunk_size: The default chunk size. The decorated function also
+            accepts a `chunk_size` keyword argument that overrides this default.
         progress_bar_name: The name of the progress bar.
             If None, no progress bar is shown.
     """
 
-    def decorator(method):
-        @wraps(method)
-        def wrapper(self, *args, **kwargs):
+    def decorator(function):
+        @wraps(function)
+        def wrapper(*args, **kwargs):
             if "chunk_size" in kwargs:
                 chunk_size = kwargs.pop("chunk_size")
             else:
                 chunk_size = default_chunk_size
+            pre_chunk_args = args[:first_chunkable_idx]
+            chunkable_args = args[first_chunkable_idx:]
             results = []
             if progress_bar_name is None:
                 progargs = {"disable": True}
             else:
                 progargs = {"desc": progress_bar_name}
-            bar = tqdm(total=len(args[0]), delay=2.0, **progargs)
-            for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
+            bar = tqdm(total=len(chunkable_args[0]), delay=2.0, **progargs)
+            for chunked_args in zip(
+                *(chunked(arg, chunk_size) for arg in chunkable_args)
+            ):
                 bar.update(len(chunked_args[0]))
-                results.append(method(self, *chunked_args, **kwargs))
+                results.append(function(*pre_chunk_args, *chunked_args, **kwargs))
             if isinstance(results[0], tuple):
                 return tuple(torch.cat(tensors) for tensors in zip(*results))
             else:
@@ -438,3 +448,77 @@ def wrapper(self, *args, **kwargs):
         return wrapper
 
     return decorator
+
+
+def _apply_args_and_kwargs(func, pre_chunk_args, chunked_args, kwargs):
+    return func(*pre_chunk_args, *chunked_args, **kwargs)
+
+
+def parallelize_function(
+    function,
+    first_chunkable_idx=0,
+    max_workers=10,
+    min_chunk_size=1000,
+):
+    """Parallelize a function's application with multiprocessing.
+
+    This is intentionally not designed to be used with decorator syntax, because
+    it should only be applied to functions that will run on the CPU.
+
+    Expects that all positional arguments from first_chunkable_idx onward are
+    iterables of the same length, and that outputs are tuples of tensors whose
+    first dimension corresponds to the first dimension of the input iterables.
+
+    If the function returns just one item, it must not be a tuple.
+
+    Division between processes is done along the first dimension of all chunkable
+    inputs. If too few workers or too little data are available for multiprocessing
+    to pay off, the function is applied directly, without a process pool.
+
+    Args:
+        function: The function to be parallelized.
+        first_chunkable_idx: The index of the first argument to be chunked.
+            All positional arguments from this index on will be chunked.
+        max_workers: The maximum number of processes to use.
+        min_chunk_size: The minimum chunk size for input data. The number of
+            workers is adjusted to ensure that chunks are at least this long.
+    """
+
+    max_worker_count = min(mp.cpu_count() // 2, max_workers)
+    if max_worker_count <= 1:
+        return function
+
+    @wraps(function)
+    def wrapper(*args, **kwargs):
+        if len(args) <= first_chunkable_idx:
+            raise ValueError(
+                f"Function {getattr(function, '__name__', function)} cannot be parallelized without chunkable arguments"
+            )
+        pre_chunk_args = args[:first_chunkable_idx]
+        chunkable_args = args[first_chunkable_idx:]
+        # The largest worker count that keeps each chunk at least min_chunk_size long.
+        max_useful_worker_count = len(chunkable_args[0]) // min_chunk_size
+        worker_count = min(max_useful_worker_count, max_worker_count)
+        if worker_count <= 1:
+            return function(*args, **kwargs)
+
+        chunk_size = (len(chunkable_args[0]) // worker_count) + 1
+        chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in chunkable_args)))
+        with mp.Pool(worker_count) as pool:
+            results = pool.starmap(
+                _apply_args_and_kwargs,
+                list(
+                    zip(
+                        repeat(function),
+                        repeat(pre_chunk_args),
+                        chunked_args,
+                        repeat(kwargs),
+                    )
+                ),
+            )
+        if isinstance(results[0], tuple):
+            return tuple(torch.cat(tensors) for tensors in zip(*results))
+        else:
+            return torch.cat(results)
+
+    return wrapper
diff --git a/netam/dasm.py b/netam/dasm.py
index e3712fe2..0505311b 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -180,6 +180,7 @@ def loss_of_batch(self, batch):
         # mut_pos_loss, and we mask out sites with no substitution for the CSP
         # loss. The latter class of sites also eliminates sites that have Xs in
         # the parent or child (see sequences.aa_subs_indicator_tensor_of).
+
         predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
 
         # After zapping out the diagonal, we can effectively sum over the
@@ -195,11 +196,10 @@ def loss_of_batch(self, batch):
         # logit space, so we are set up for using the cross entropy loss.
         # However we have to mask out the sites that are not substituted, i.e.
         # the sites for which aa_subs_indicator is 0.
-        subs_mask = aa_subs_indicator == 1
+        subs_mask = (aa_subs_indicator == 1) & mask
         csp_pred = predictions[subs_mask]
         csp_targets = aa_children_idxs[subs_mask]
         csp_loss = self.xent_loss(csp_pred, csp_targets)
-
         return torch.stack([subs_pos_loss, csp_loss])
 
     def build_selection_matrix_from_parent(self, parent: str):
diff --git a/netam/framework.py b/netam/framework.py
index ad4748e2..bbd021e1 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -26,6 +26,7 @@
     BIG,
     VRC01_NT_SEQ,
     encode_sequences,
+    parallelize_function,
 )
 from netam import models
 import netam.molevol as molevol
@@ -253,13 +254,13 @@ def __init__(self, encoder, model, training_hyperparameters={}):
         self.model = model
         self.training_hyperparameters = training_hyperparameters
 
-    def __call__(self, sequences):
+    def __call__(self, sequences, **kwargs):
         """Evaluate the model on a list of sequences."""
         if isinstance(sequences, str):
             raise ValueError(
                 "Expected a list of sequences for call on crepe, but got a single string instead."
             )
-        return self.model.evaluate_sequences(sequences, encoder=self.encoder)
+        return self.model.evaluate_sequences(sequences, encoder=self.encoder, **kwargs)
 
     @property
     def device(self):
@@ -338,7 +339,7 @@ def crepe_exists(prefix):
 
 def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     """Model outputs trimmed to the length of the parent sequences."""
-    rates, csp_logits = crepe(parents)
+    rates, csp_logits = parallelize_function(crepe)(parents)
     rates = rates.cpu().detach()
     csps = torch.softmax(csp_logits, dim=-1).cpu().detach()
     trimmed_rates = [rates[i, : len(parent)] for i, parent in enumerate(parents)]
diff --git a/netam/models.py b/netam/models.py
index 93293b8e..9a67e6c8 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -17,7 +17,7 @@
     generate_kmers,
     aa_mask_tensor_of,
     encode_sequences,
-    chunk_method,
+    chunk_function,
 )
 
 warnings.filterwarnings(
@@ -65,7 +65,7 @@ def unfreeze(self):
         for param in self.parameters():
             param.requires_grad = True
 
-    @chunk_method(progress_bar_name="Evaluating model")
+    @chunk_function(first_chunkable_idx=1, progress_bar_name="Evaluating model")
     def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
        if encoder is None:
            raise ValueError("An encoder must be provided.")
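
Note on the dasm.py change (toy illustration with invented values, not part
of the patch): intersecting the substitution indicator with the validity mask
keeps ambiguous or padded sites out of the CSP loss.

    import torch

    mask = torch.tensor([True, True, False])            # site 3 is ambiguous/padded
    aa_subs_indicator = torch.tensor([1.0, 0.0, 1.0])   # site 3 naively looks substituted

    old_subs_mask = aa_subs_indicator == 1           # tensor([True, False, True])
    new_subs_mask = (aa_subs_indicator == 1) & mask  # tensor([True, False, False])

With the old mask, site 3 would select a prediction row at an invalid
position, which is how NaNs could reach the loss once training data
contained ambiguities.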
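Usage sketch for the parallelized neutral model application (the crepe path
and sequences below are invented, and load_crepe is assumed to be netam's
existing crepe loader):

    from netam.framework import load_crepe, trimmed_shm_model_outputs_of_crepe

    crepe = load_crepe("path/to/neutral_model")       # hypothetical saved crepe prefix
    parents = ["ATGGCCCAGGTGCAG", "ATGGCACAGGTGCAG"]  # toy parent nucleotide sequences

    # Internally this now calls parallelize_function(crepe)(parents), which
    # splits the sequence list across CPU worker processes when there is
    # enough data, and falls back to a plain call otherwise.
    trimmed_rates, trimmed_csps = trimmed_shm_model_outputs_of_crepe(crepe, parents)

    # Crepe.__call__ now forwards keyword arguments to evaluate_sequences, so
    # the chunk_function wrapper's chunk size can be overridden per call:
    rates, csp_logits = crepe(parents, chunk_size=512)

The crepe call returns padded, batched tensors; the trimmed helper slices
them back to each parent's length, as in the framework.py hunk above.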