diff --git a/netam/common.py b/netam/common.py
index c50200a3..00979d47 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -5,7 +5,8 @@
 import subprocess
 
 from tqdm import tqdm
 from functools import wraps
-from itertools import islice
+from itertools import islice, repeat
+import multiprocessing as mp
 import numpy as np
 import torch
@@ -397,39 +398,48 @@ def chunked(iterable, n):
         yield chunk
 
 
-def chunk_method(default_chunk_size=2048, progress_bar_name=None):
-    """Decorator to chunk the input to a method.
+def chunk_function(
+    first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None
+):
+    """Decorator to chunk the input to a function.
 
-    Expects that all positional arguments are iterables of the same length,
+    Expects that all positional arguments from `first_chunkable_idx` onward
+    are iterables of the same length,
     and that outputs are tuples of tensors whose first dimension
     corresponds to the first dimension of the input iterables.
 
-    If method returns just one item, it must not be a tuple.
+    If the function returns just one item, it must not be a tuple.
 
     Chunking is done along the first dimension of all inputs.
 
     Args:
-        default_chunk_size: The default chunk size. The decorated method can
-            also automatically accept a `default_chunk_size` keyword argument.
+        first_chunkable_idx: The index of the first positional argument to be
+            chunked; arguments before it are passed through unchunked.
+        default_chunk_size: The default chunk size. The decorated function
+            also accepts a `chunk_size` keyword argument that overrides this.
         progress_bar_name: The name of the progress bar. If None, no progress bar is shown.
     """
 
-    def decorator(method):
-        @wraps(method)
-        def wrapper(self, *args, **kwargs):
+    def decorator(function):
+        @wraps(function)
+        def wrapper(*args, **kwargs):
             if "chunk_size" in kwargs:
                 chunk_size = kwargs.pop("chunk_size")
             else:
                 chunk_size = default_chunk_size
+            pre_chunk_args = args[:first_chunkable_idx]
+            chunkable_args = args[first_chunkable_idx:]
             results = []
             if progress_bar_name is None:
                 progargs = {"disable": True}
             else:
                 progargs = {"desc": progress_bar_name}
-            bar = tqdm(total=len(args[0]), delay=2.0, **progargs)
-            for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
+            bar = tqdm(total=len(chunkable_args[0]), delay=2.0, **progargs)
+            for chunked_args in zip(
+                *(chunked(arg, chunk_size) for arg in chunkable_args)
+            ):
                 bar.update(len(chunked_args[0]))
-                results.append(method(self, *chunked_args, **kwargs))
+                results.append(function(*pre_chunk_args, *chunked_args, **kwargs))
             if isinstance(results[0], tuple):
                 return tuple(torch.cat(tensors) for tensors in zip(*results))
             else:
@@ -438,3 +448,78 @@ def wrapper(self, *args, **kwargs):
         return wrapper
 
     return decorator
+
+
+def _apply_args_and_kwargs(func, pre_chunk_args, chunked_args, kwargs):
+    return func(*pre_chunk_args, *chunked_args, **kwargs)
+
+
+def parallelize_function(
+    function,
+    first_chunkable_idx=0,
+    max_workers=10,
+    min_chunk_size=1000,
+):
+    """Parallelize the application of a function using multiprocessing.
+
+    This is deliberately not designed to be used with decorator syntax,
+    because it should be applied only when the function will run on the CPU.
+
+    Expects that all positional arguments from `first_chunkable_idx` onward
+    are iterables of the same length, and that outputs are tuples of tensors
+    whose first dimension corresponds to the first dimension of the input
+    iterables.
+
+    If the function returns just one item, it must not be a tuple.
+
+    Division between processes is done along the first dimension of all
+    inputs. If fewer than two workers are available, or the inputs are too
+    small to justify more than one worker, the function is applied directly.
+
+    Args:
+        function: The function to be parallelized.
+        first_chunkable_idx: The index of the first argument to be chunked.
+            All positional arguments after this index will be chunked.
+        max_workers: The maximum number of processes to use.
+        min_chunk_size: The minimum chunk size for input data. The number of
+            workers is adjusted to ensure that the chunk size is at least this.
+    """
+
+    max_worker_count = min(mp.cpu_count() // 2, max_workers)
+    if max_worker_count <= 1:
+        return function
+
+    @wraps(function)
+    def wrapper(*args, **kwargs):
+        if len(args) <= first_chunkable_idx:
+            raise ValueError(
+                f"Function {function.__name__} cannot be parallelized without chunkable arguments"
+            )
+        pre_chunk_args = args[:first_chunkable_idx]
+        chunkable_args = args[first_chunkable_idx:]
+        min_worker_count = len(chunkable_args[0]) // min_chunk_size
+
+        worker_count = min(min_worker_count, max_worker_count)
+        if worker_count <= 1:
+            return function(*args, **kwargs)
+
+        chunk_size = (len(chunkable_args[0]) // worker_count) + 1
+        chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in chunkable_args)))
+        with mp.Pool(worker_count) as pool:
+            results = pool.starmap(
+                _apply_args_and_kwargs,
+                list(
+                    zip(
+                        repeat(function),
+                        repeat(pre_chunk_args),
+                        chunked_args,
+                        repeat(kwargs),
+                    )
+                ),
+            )
+        if isinstance(results[0], tuple):
+            return tuple(torch.cat(tensors) for tensors in zip(*results))
+        else:
+            return torch.cat(results)
+
+    return wrapper
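For orientation, a minimal sketch (not part of the diff) of how the two helpers compose. The toy `score` function and its arguments are hypothetical; only `chunk_function` and `parallelize_function` come from the code above.

import torch
from netam.common import chunk_function, parallelize_function

@chunk_function(first_chunkable_idx=1, progress_bar_name="Scoring")
def score(scale, sequences, weights):
    # `scale` is passed through unchunked; `sequences` and `weights` are
    # iterated in chunks of at most `chunk_size` items each.
    return torch.tensor([scale * w * len(s) for s, w in zip(sequences, weights)])

if __name__ == "__main__":  # guard required for multiprocessing on spawn platforms
    seqs = ["ACGT", "AC", "ACG"] * 1000
    wts = [1.0] * len(seqs)
    out = score(2.0, seqs, wts, chunk_size=512)  # wrapper pops `chunk_size`

    # parallelize_function is applied at call time, never as a decorator:
    par_score = parallelize_function(score, first_chunkable_idx=1, min_chunk_size=500)
    assert torch.equal(out, par_score(2.0, seqs, wts))

Because the wrapped function is looked up by name when it is pickled for the worker pool, it must be defined at module level, as here.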
diff --git a/netam/dasm.py b/netam/dasm.py
index e3712fe2..0505311b 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -180,6 +180,7 @@ def loss_of_batch(self, batch):
         # mut_pos_loss, and we mask out sites with no substitution for the CSP
         # loss. The latter class of sites also eliminates sites that have Xs in
         # the parent or child (see sequences.aa_subs_indicator_tensor_of).
+
         predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)
 
         # After zapping out the diagonal, we can effectively sum over the
@@ -195,11 +196,10 @@ def loss_of_batch(self, batch):
         # logit space, so we are set up for using the cross entropy loss.
         # However we have to mask out the sites that are not substituted, i.e.
         # the sites for which aa_subs_indicator is 0.
-        subs_mask = aa_subs_indicator == 1
+        subs_mask = (aa_subs_indicator == 1) & mask
         csp_pred = predictions[subs_mask]
         csp_targets = aa_children_idxs[subs_mask]
         csp_loss = self.xent_loss(csp_pred, csp_targets)
-
         return torch.stack([subs_pos_loss, csp_loss])
 
     def build_selection_matrix_from_parent(self, parent: str):
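A toy illustration of the `subs_mask` fix, with made-up shapes (not netam code): ANDing the substitution indicator with the validity mask keeps sites outside the valid region, such as padding, out of the CSP loss even when their indicator is set.

import torch

# Hypothetical batch row: 5 sites, the last one is padding.
aa_subs_indicator = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0])
mask = torch.tensor([True, True, True, True, False])

old_subs_mask = aa_subs_indicator == 1
new_subs_mask = (aa_subs_indicator == 1) & mask

print(old_subs_mask.sum().item())  # 3: the padding site would enter the CSP loss
print(new_subs_mask.sum().item())  # 2: only valid substituted sites remain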
diff --git a/netam/framework.py b/netam/framework.py
index ad4748e2..bbd021e1 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -26,6 +26,7 @@
     BIG,
     VRC01_NT_SEQ,
     encode_sequences,
+    parallelize_function,
 )
 from netam import models
 import netam.molevol as molevol
@@ -253,13 +254,13 @@ def __init__(self, encoder, model, training_hyperparameters={}):
         self.model = model
         self.training_hyperparameters = training_hyperparameters
 
-    def __call__(self, sequences):
+    def __call__(self, sequences, **kwargs):
         """Evaluate the model on a list of sequences."""
         if isinstance(sequences, str):
             raise ValueError(
                 "Expected a list of sequences for call on crepe, but got a single string instead."
             )
-        return self.model.evaluate_sequences(sequences, encoder=self.encoder)
+        return self.model.evaluate_sequences(sequences, encoder=self.encoder, **kwargs)
 
     @property
     def device(self):
@@ -338,7 +339,7 @@ def crepe_exists(prefix):
 
 def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     """Model outputs trimmed to the length of the parent sequences."""
-    rates, csp_logits = crepe(parents)
+    rates, csp_logits = parallelize_function(crepe)(parents)
     rates = rates.cpu().detach()
     csps = torch.softmax(csp_logits, dim=-1).cpu().detach()
     trimmed_rates = [rates[i, : len(parent)] for i, parent in enumerate(parents)]
diff --git a/netam/models.py b/netam/models.py
index 93293b8e..9a67e6c8 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -17,7 +17,7 @@
     generate_kmers,
     aa_mask_tensor_of,
     encode_sequences,
-    chunk_method,
+    chunk_function,
 )
 
 warnings.filterwarnings(
@@ -65,7 +65,7 @@ def unfreeze(self):
         for param in self.parameters():
             param.requires_grad = True
 
-    @chunk_method(progress_bar_name="Evaluating model")
+    @chunk_function(first_chunkable_idx=1, progress_bar_name="Evaluating model")
     def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
         if encoder is None:
             raise ValueError("An encoder must be provided.")
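End to end, the changes let callers tune chunking per invocation (kwargs now flow through `Crepe.__call__` to `evaluate_sequences`, whose decorator treats `self` as the one unchunked argument) and opt into multiprocessing only where evaluation runs on the CPU. A usage sketch under those assumptions; the crepe prefix and parent sequences are hypothetical:

from netam.framework import load_crepe, trimmed_shm_model_outputs_of_crepe

crepe = load_crepe("path/to/crepe_prefix")  # hypothetical prefix
parents = ["ACGTACGT", "ACGTT"]

# Per-invocation chunk size, forwarded to evaluate_sequences:
rates, csp_logits = crepe(parents, chunk_size=512)

# Wraps the crepe with parallelize_function internally; with this few
# parents it simply falls back to a direct call:
trimmed_rates, trimmed_csps = trimmed_shm_model_outputs_of_crepe(crepe, parents)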