Fix for NaN's when training DASM with ambiguous sequences #93

Merged · 7 commits · Dec 10, 2024
Changes from 5 commits
97 changes: 86 additions & 11 deletions netam/common.py
@@ -5,7 +5,7 @@
import subprocess
from tqdm import tqdm
from functools import wraps
from itertools import islice
from itertools import islice, repeat

import numpy as np
import torch
@@ -397,39 +397,46 @@ def chunked(iterable, n):
yield chunk


def chunk_method(default_chunk_size=2048, progress_bar_name=None):
"""Decorator to chunk the input to a method.
def chunk_function(
first_chunkable_idx=0,
default_chunk_size=2048,
progress_bar_name=None
):
"""Decorator to chunk the input to a function.

Expects that all positional arguments are iterables of the same length,
and that outputs are tuples of tensors whose first dimension
corresponds to the first dimension of the input iterables.

If method returns just one item, it must not be a tuple.
If function returns just one item, it must not be a tuple.

Chunking is done along the first dimension of all inputs.

Args:
default_chunk_size: The default chunk size. The decorated method can
default_chunk_size: The default chunk size. The decorated function can
also automatically accept a `default_chunk_size` keyword argument.
progress_bar_name: The name of the progress bar. If None, no progress bar is shown.
"""

def decorator(method):
@wraps(method)
def wrapper(self, *args, **kwargs):
def decorator(function):
@wraps(function)
def wrapper(*args, **kwargs):
if "chunk_size" in kwargs:
chunk_size = kwargs.pop("chunk_size")
else:
chunk_size = default_chunk_size
pre_chunk_args = args[:first_chunkable_idx]
chunkable_args = args[first_chunkable_idx:]

results = []
if progress_bar_name is None:
progargs = {"disable": True}
else:
progargs = {"desc": progress_bar_name}
bar = tqdm(total=len(args[0]), delay=2.0, **progargs)
for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
bar = tqdm(total=len(chunkable_args[0]), delay=2.0, **progargs)
for chunked_args in zip(*(chunked(arg, chunk_size) for arg in chunkable_args)):
bar.update(len(chunked_args[0]))
results.append(method(self, *chunked_args, **kwargs))
results.append(function(*pre_chunk_args, *chunked_args, **kwargs))
if isinstance(results[0], tuple):
return tuple(torch.cat(tensors) for tensors in zip(*results))
else:
Expand All @@ -438,3 +445,71 @@ def wrapper(self, *args, **kwargs):
return wrapper

return decorator
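
For reviewers, a minimal sketch of how the new `chunk_function` decorator could be used on a standalone function; the toy function, sequences, and chunk size below are illustrative assumptions, not part of the PR:

```python
import torch
from netam.common import chunk_function  # decorator added in this PR

# Hypothetical toy function: the first positional argument is a scalar
# setting, so chunking starts at index 1.
@chunk_function(first_chunkable_idx=1, progress_bar_name="Toy eval")
def toy_eval(scale, sequences):
    # Return a tensor whose first dimension matches len(sequences).
    return torch.full((len(sequences),), float(scale))

out = toy_eval(2.0, ["ACGT"] * 5000, chunk_size=1024)
# The 5000 sequences are processed in chunks of 1024 and the per-chunk
# tensors are concatenated along dim 0 by the decorator.
assert out.shape == (5000,)
```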


def _apply_args_and_kwargs(func, pre_chunk_args, chunked_args, kwargs):
return func(*pre_chunk_args, *chunked_args, **kwargs)


def parallelize_function(
function,
first_chunkable_idx=0,
default_parallelize=True,
max_workers=10,
min_chunk_size=1000,
):
"""Decorator to parallelize function application with multiprocessing.

Expects that all positional arguments are iterables of the same length,
and that outputs are tuples of tensors whose first dimension
corresponds to the first dimension of the input iterables.

If function returns just one item, it must not be a tuple.

Division between processes is done along the first dimension of all inputs.
The wrapped function will be endowed with the parallelize keyword
argument, so that parallelization can be turned on or off at each invocation.

Args:
first_chunkable_idx: The index of the first argument to be chunked. All positional arguments after this index will be chunked.
default_parallelize: Whether to parallelize function calls by default, or require passing `parallelize=True` in invocation to parallelize.
max_workers: The maximum number of processes to use.
min_chunk_size: The minimum chunk size for input data. The number of
workers is adjusted to ensure that the chunk size is at least this.
"""

max_worker_count = min(mp.cpu_count() // 2, max_workers)
if max_worker_count <= 1:
return function

@wraps(function)
def wrapper(*args, **kwargs):
if "parallelize" in kwargs:
parallelize = kwargs.pop("parallelize")
else:
parallelize = default_parallelize
if parallelize:
if len(args) <= first_chunkable_idx:
raise ValueError(
f"Function {function.__name__} cannot be parallelized without chunkable arguments"
)
pre_chunk_args = args[:first_chunkable_idx]
chunkable_args = args[first_chunkable_idx:]
min_worker_count = (len(chunkable_args[0]) // min_chunk_size)

worker_count = min(min_worker_count, max_worker_count)
if worker_count <= 1:
return function(*args, **kwargs)

chunk_size = (len(chunkable_args[0]) // worker_count) + 1
chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in chunkable_args)))
with mp.Pool(worker_count) as pool:
results = pool.starmap(_apply_args_and_kwargs, list(zip(repeat(function), repeat(pre_chunk_args), chunked_args, repeat(kwargs))))
if isinstance(results[0], tuple):
return tuple(torch.cat(tensors) for tensors in zip(*results))
else:
return torch.cat(results)
else:
return function(*args, **kwargs)

return wrapper
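
Similarly, a rough usage sketch for `parallelize_function`; `embed`, the prefix, and the sequence list are hypothetical stand-ins for a CPU-bound model call:

```python
import torch
from netam.common import parallelize_function  # wrapper added in this PR

def embed(prefix, sequences):
    # Hypothetical CPU-bound function; the first dimension of the output
    # matches the number of input sequences.
    return torch.tensor([len(prefix + s) for s in sequences], dtype=torch.float32)

# Wrap at call time; chunking starts at positional index 1, so `prefix`
# is passed to every worker unchanged.
parallel_embed = parallelize_function(embed, first_chunkable_idx=1, min_chunk_size=500)

if __name__ == "__main__":  # required for multiprocessing on spawn platforms
    lengths = parallel_embed("QVQ", ["ACDEFG"] * 10_000)
    assert lengths.shape == (10_000,)
    # Per-call opt-out is available via parallelize=False when the wrapper is
    # active; note that on machines where mp.cpu_count() // 2 <= 1 the original
    # function is returned unwrapped and does not accept that keyword.
```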
4 changes: 2 additions & 2 deletions netam/dasm.py
@@ -180,6 +180,7 @@ def loss_of_batch(self, batch):
# mut_pos_loss, and we mask out sites with no substitution for the CSP
# loss. The latter class of sites also eliminates sites that have Xs in
# the parent or child (see sequences.aa_subs_indicator_tensor_of).

predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs)

# After zapping out the diagonal, we can effectively sum over the
@@ -195,11 +196,10 @@
# logit space, so we are set up for using the cross entropy loss.
# However we have to mask out the sites that are not substituted, i.e.
# the sites for which aa_subs_indicator is 0.
subs_mask = aa_subs_indicator == 1
subs_mask = (aa_subs_indicator == 1) & mask
Contributor Author: This is it, the whole NaN/inf fix! Could you have a look around here and make sure this looks reasonable to you, too?

Contributor: Nice sleuthing! It seems good to me... am I missing something?

Contributor Author: I don't think so, just wanted extra eyes on it! Thanks

csp_pred = predictions[subs_mask]
csp_targets = aa_children_idxs[subs_mask]
csp_loss = self.xent_loss(csp_pred, csp_targets)

return torch.stack([subs_pos_loss, csp_loss])

def build_selection_matrix_from_parent(self, parent: str):
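
To make the one-line fix concrete, here is a self-contained sketch (toy tensors, not the PR's actual batch shapes) of how a single invalid site can poison the CSP cross entropy with NaN, and how AND-ing the substitution indicator with the validity mask avoids it:

```python
import torch

# Two sites: the second is a padding/ambiguous site excluded by the mask,
# whose prediction row is degenerate (all -inf logits).
predictions = torch.tensor([[0.2, 1.3, -0.5],
                            [float("-inf"), float("-inf"), float("-inf")]])
aa_children_idxs = torch.tensor([1, 0])
aa_subs_indicator = torch.tensor([1.0, 1.0])  # naively flags both sites
mask = torch.tensor([True, False])            # only the first site is valid

xent = torch.nn.CrossEntropyLoss()

# Old behaviour: the invalid site slips into the CSP loss and yields NaN.
old_mask = aa_subs_indicator == 1
print(xent(predictions[old_mask], aa_children_idxs[old_mask]))    # tensor(nan)

# Fixed behaviour: AND-ing with the validity mask drops the invalid site.
subs_mask = (aa_subs_indicator == 1) & mask
print(xent(predictions[subs_mask], aa_children_idxs[subs_mask]))  # finite loss
```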
2 changes: 2 additions & 0 deletions netam/dxsm.py
@@ -255,6 +255,8 @@ def _find_optimal_branch_length(
return molevol.optimize_branch_length(
log_pcp_probability, starting_branch_length, **optimization_kwargs
)
# TODO for debugging
# return 0.026, False

def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
optimal_lengths = []
7 changes: 4 additions & 3 deletions netam/framework.py
@@ -26,6 +26,7 @@
BIG,
VRC01_NT_SEQ,
encode_sequences,
parallelize_function,
)
from netam import models
import netam.molevol as molevol
@@ -253,13 +254,13 @@ def __init__(self, encoder, model, training_hyperparameters={}):
self.model = model
self.training_hyperparameters = training_hyperparameters

def __call__(self, sequences):
def __call__(self, sequences, **kwargs):
"""Evaluate the model on a list of sequences."""
if isinstance(sequences, str):
raise ValueError(
"Expected a list of sequences for call on crepe, but got a single string instead."
)
return self.model.evaluate_sequences(sequences, encoder=self.encoder)
return self.model.evaluate_sequences(sequences, encoder=self.encoder, **kwargs)

@property
def device(self):
Expand Down Expand Up @@ -338,7 +339,7 @@ def crepe_exists(prefix):

def trimmed_shm_model_outputs_of_crepe(crepe, parents):
"""Model outputs trimmed to the length of the parent sequences."""
rates, csp_logits = crepe(parents)
rates, csp_logits = parallelize_function(crepe)(parents)
rates = rates.cpu().detach()
csps = torch.softmax(csp_logits, dim=-1).cpu().detach()
trimmed_rates = [rates[i, : len(parent)] for i, parent in enumerate(parents)]
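
A rough sketch of the updated call path in framework.py; the crepe prefix, the assumption of an SHM-style crepe returning (rates, csp_logits), and the toy parent sequences are all placeholders, not part of the PR:

```python
from netam.framework import load_crepe, trimmed_shm_model_outputs_of_crepe

# Hypothetical crepe and nucleotide parents; a real prefix would point at a
# saved model/encoder pair.
crepe = load_crepe("path/to/crepe_prefix")
parents = ["ATGAAC" * 10, "ATGCAG" * 10]

# kwargs now flow through Crepe.__call__ into evaluate_sequences, so the
# chunk size can be overridden per call.
rates, csp_logits = crepe(parents, chunk_size=512)

# trimmed_shm_model_outputs_of_crepe now wraps the crepe with
# parallelize_function, so outputs for many parents are computed in parallel
# and trimmed to each parent's length without callers changing anything.
trimmed_rates, trimmed_csps = trimmed_shm_model_outputs_of_crepe(crepe, parents)
```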
4 changes: 2 additions & 2 deletions netam/models.py
@@ -17,7 +17,7 @@
generate_kmers,
aa_mask_tensor_of,
encode_sequences,
chunk_method,
chunk_function,
)

warnings.filterwarnings(
@@ -65,7 +65,7 @@ def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

@chunk_method(progress_bar_name="Evaluating model")
@chunk_function(first_chunkable_idx=1, progress_bar_name="Evaluating model")
def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
if encoder is None:
raise ValueError("An encoder must be provided.")