diff --git a/netam/common.py b/netam/common.py
index 0250aa83..c50200a3 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -3,6 +3,9 @@
 import itertools
 import resource
 import subprocess
+from tqdm import tqdm
+from functools import wraps
+from itertools import islice
 
 import numpy as np
 import torch
@@ -380,3 +383,75 @@ def encode_sequences(sequences, encoder):
         torch.stack(masks),
         torch.stack(wt_base_modifiers),
     )
+
+
+# from https://docs.python.org/3.11/library/itertools.html#itertools-recipes
+# avoiding walrus:
+def chunked(iterable, n):
+    """Yield successive lists of at most `n` items from `iterable`.
+
+    Every chunk holds `n` items except possibly the last, which holds the
+    remainder. An empty iterable yields no chunks.
+    """
+    it = iter(iterable)
+    while True:
+        chunk = list(islice(it, n))
+        if not chunk:
+            return
+        yield chunk
+
+
+def chunk_method(default_chunk_size=2048, progress_bar_name=None):
+    """Decorator to chunk the input to a method.
+
+    Expects that all positional arguments are iterables of the same length,
+    and that outputs are tuples of tensors whose first dimension
+    corresponds to the first dimension of the input iterables.
+
+    If method returns just one item, it must not be a tuple.
+
+    Chunking is done along the first dimension of all inputs. Keyword
+    arguments are passed unchanged to every chunked call.
+
+    If chunking produces no chunks (no positional arguments, or an empty
+    positional argument), the method is called once with the original
+    arguments so that it decides how to handle such input.
+
+    Args:
+        default_chunk_size: The default chunk size. The decorated method
+            also accepts a `chunk_size` keyword argument that overrides it
+            for a single call; the wrapper consumes that argument.
+        progress_bar_name: The name of the progress bar. If None, no progress bar is shown.
+    """
+
+    def decorator(method):
+        @wraps(method)
+        def wrapper(self, *args, **kwargs):
+            chunk_size = kwargs.pop("chunk_size", default_chunk_size)
+            if chunk_size is not None and chunk_size < 1:
+                raise ValueError(
+                    f"chunk_size must be a positive integer, got {chunk_size}."
+                )
+            if not args:
+                # Nothing positional to chunk: call the method unchunked.
+                return method(self, **kwargs)
+            if progress_bar_name is None:
+                progargs = {"disable": True}
+            else:
+                progargs = {"desc": progress_bar_name}
+            results = []
+            # The context manager closes the bar even if `method` raises.
+            with tqdm(total=len(args[0]), delay=2.0, **progargs) as bar:
+                for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
+                    results.append(method(self, *chunked_args, **kwargs))
+                    bar.update(len(chunked_args[0]))
+            if not results:
+                # Empty input yields no chunks; let the method handle it.
+                return method(self, *args, **kwargs)
+            if isinstance(results[0], tuple):
+                return tuple(torch.cat(tensors) for tensors in zip(*results))
+            return torch.cat(results)
+
+        return wrapper
+
+    return decorator
diff --git a/netam/framework.py b/netam/framework.py
index be7cf500..ad4748e2 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -143,7 +143,9 @@ def export_branch_lengths(self, out_csv_path):
         )
 
     def load_branch_lengths(self, in_csv_path):
-        self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values
+        self.branch_lengths = torch.Tensor(
+            pd.read_csv(in_csv_path)["branch_length"].values
+        )
 
     def __repr__(self):
         return f"{self.__class__.__name__}(Size: {len(self)}) on {self.branch_lengths.device}"
@@ -252,6 +254,11 @@ def __init__(self, encoder, model, training_hyperparameters={}):
         self.training_hyperparameters = training_hyperparameters
 
     def __call__(self, sequences):
+        """Evaluate the model on a list of sequences."""
+        if isinstance(sequences, str):
+            raise ValueError(
+                "Expected a list of sequences for call on crepe, but got a single string instead."
+            )
         return self.model.evaluate_sequences(sequences, encoder=self.encoder)
 
     @property
diff --git a/netam/models.py b/netam/models.py
index e5b13673..93293b8e 100644
--- a/netam/models.py
+++ b/netam/models.py
@@ -17,6 +17,7 @@
     generate_kmers,
     aa_mask_tensor_of,
     encode_sequences,
+    chunk_method,
 )
 
 warnings.filterwarnings(
@@ -64,7 +65,8 @@ def unfreeze(self):
         for param in self.parameters():
             param.requires_grad = True
 
-    def evaluate_sequences(self, sequences, encoder=None):
+    @chunk_method(progress_bar_name="Evaluating model")
+    def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
         if encoder is None:
             raise ValueError("An encoder must be provided.")
         encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)