Skip to content

Commit

Permalink
batch -> chunk
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Dec 4, 2024
1 parent a82f2be commit 867aa95
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
32 changes: 16 additions & 16 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,49 +387,49 @@ def encode_sequences(sequences, encoder):

# from https://docs.python.org/3.11/library/itertools.html#itertools-recipes
# avoiding walrus:
# from https://docs.python.org/3.11/library/itertools.html#itertools-recipes
# avoiding walrus:
def chunked(iterable, n):
    """Yield successive lists of up to ``n`` items from ``iterable``.

    Every chunk has exactly ``n`` items except possibly the final one,
    which may be shorter. Yields nothing for an empty iterable (or n == 0).
    """
    source = iter(iterable)
    # Prime the first chunk, then re-fill after each yield; an empty
    # slice signals exhaustion and ends the generator.
    chunk = list(islice(source, n))
    while chunk:
        yield chunk
        chunk = list(islice(source, n))


def batch_method(default_batch_size=2048, progress_bar_name=None):
"""Decorator to batch the input to a method.
def chunk_method(default_chunk_size=2048, progress_bar_name=None):
"""Decorator to chunk the input to a method.
Expects that all positional arguments are iterables of the same length,
and that outputs are tuples of tensors whose first dimension
corresponds to the first dimension of the input iterables.
If method returns just one item, it must not be a tuple.
Batching is done along the first dimension of all inputs.
Chunking is done along the first dimension of all inputs.
Args:
default_batch_size: The default batch size. The decorated method can
also automatically accept a `default_batch_size` keyword argument.
default_chunk_size: The default chunk size. The decorated method can
also automatically accept a `default_chunk_size` keyword argument.
progress_bar_name: The name of the progress bar. If None, no progress bar is shown.
"""

def decorator(method):
@wraps(method)
def wrapper(self, *args, **kwargs):
if "batch_size" in kwargs:
batch_size = kwargs.pop("batch_size")
if "chunk_size" in kwargs:
chunk_size = kwargs.pop("chunk_size")
else:
batch_size = default_batch_size
chunk_size = default_chunk_size
results = []
if progress_bar_name is None:
progargs = {"disable": True}
else:
progargs = {"desc": progress_bar_name}
bar = tqdm(total=len(args[0]), delay=2.0, **progargs)
for batched_args in zip(*(batched(arg, batch_size) for arg in args)):
bar.update(len(batched_args[0]))
results.append(method(self, *batched_args, **kwargs))
for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
bar.update(len(chunked_args[0]))
results.append(method(self, *chunked_args, **kwargs))
if isinstance(results[0], tuple):
return tuple(torch.cat(tensors) for tensors in zip(*results))
else:
Expand Down
6 changes: 3 additions & 3 deletions netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
generate_kmers,
aa_mask_tensor_of,
encode_sequences,
batch_method,
chunk_method,
)

warnings.filterwarnings(
Expand Down Expand Up @@ -65,8 +65,8 @@ def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

@batch_method(progress_bar_name="Evaluating model")
def evaluate_sequences(self, sequences, encoder=None, batch_size=2048):
@chunk_method(progress_bar_name="Evaluating model")
def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
if encoder is None:
raise ValueError("An encoder must be provided.")
encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)
Expand Down

0 comments on commit 867aa95

Please sign in to comment.