Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunked model evaluation #91

Merged
merged 9 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import itertools
import resource
import subprocess
from tqdm import tqdm
from functools import wraps
from itertools import islice

import numpy as np
import torch
Expand Down Expand Up @@ -380,3 +383,58 @@ def encode_sequences(sequences, encoder):
torch.stack(masks),
torch.stack(wt_base_modifiers),
)


# from https://docs.python.org/3.11/library/itertools.html#itertools-recipes
# avoiding walrus:
def chunked(iterable, n):
"Chunk data into lists of length n. The last chunk may be shorter."
it = iter(iterable)
while True:
chunk = list(islice(it, n))
if not chunk:
return
yield chunk


def chunk_method(default_chunk_size=2048, progress_bar_name=None):
"""Decorator to chunk the input to a method.

Expects that all positional arguments are iterables of the same length,
and that outputs are tuples of tensors whose first dimension
corresponds to the first dimension of the input iterables.

If method returns just one item, it must not be a tuple.

Chunking is done along the first dimension of all inputs.

Args:
default_chunk_size: The default chunk size. The decorated method can
also automatically accept a `default_chunk_size` keyword argument.
progress_bar_name: The name of the progress bar. If None, no progress bar is shown.
"""

def decorator(method):
@wraps(method)
def wrapper(self, *args, **kwargs):
if "chunk_size" in kwargs:
chunk_size = kwargs.pop("chunk_size")
else:
chunk_size = default_chunk_size
results = []
if progress_bar_name is None:
progargs = {"disable": True}
else:
progargs = {"desc": progress_bar_name}
bar = tqdm(total=len(args[0]), delay=2.0, **progargs)
for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
bar.update(len(chunked_args[0]))
results.append(method(self, *chunked_args, **kwargs))
if isinstance(results[0], tuple):
return tuple(torch.cat(tensors) for tensors in zip(*results))
else:
return torch.cat(results)

return wrapper

return decorator
9 changes: 8 additions & 1 deletion netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def export_branch_lengths(self, out_csv_path):
)

def load_branch_lengths(self, in_csv_path):
self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values
self.branch_lengths = torch.Tensor(
pd.read_csv(in_csv_path)["branch_length"].values
)

def __repr__(self):
return f"{self.__class__.__name__}(Size: {len(self)}) on {self.branch_lengths.device}"
Expand Down Expand Up @@ -252,6 +254,11 @@ def __init__(self, encoder, model, training_hyperparameters={}):
self.training_hyperparameters = training_hyperparameters

def __call__(self, sequences):
"""Evaluate the model on a list of sequences."""
if isinstance(sequences, str):
raise ValueError(
"Expected a list of sequences for call on crepe, but got a single string instead."
)
return self.model.evaluate_sequences(sequences, encoder=self.encoder)

@property
Expand Down
4 changes: 3 additions & 1 deletion netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
generate_kmers,
aa_mask_tensor_of,
encode_sequences,
chunk_method,
)

warnings.filterwarnings(
Expand Down Expand Up @@ -64,7 +65,8 @@ def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

def evaluate_sequences(self, sequences, encoder=None):
@chunk_method(progress_bar_name="Evaluating model")
def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
if encoder is None:
raise ValueError("An encoder must be provided.")
encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)
Expand Down
Loading