From 6bc541a958123532c61e46ae4ad5f2c3eaa43f0b Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Thu, 5 Dec 2024 14:11:22 -0800 Subject: [PATCH 1/7] debugging --- netam/dxsm.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/netam/dxsm.py b/netam/dxsm.py index 9c7157e3..2b65820a 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -240,21 +240,23 @@ def _find_optimal_branch_length( multihit_model, **optimization_kwargs, ): - sel_matrix = self.build_selection_matrix_from_parent(parent) - trimmed_aa_mask = aa_mask[: len(sel_matrix)] - log_pcp_probability = molevol.mutsel_log_pcp_probability_of( - sel_matrix[trimmed_aa_mask], - apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask), - apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask), - nt_rates[trimmed_aa_mask.repeat_interleave(3)], - nt_csps[trimmed_aa_mask.repeat_interleave(3)], - multihit_model, - ) - if isinstance(starting_branch_length, torch.Tensor): - starting_branch_length = starting_branch_length.detach().item() - return molevol.optimize_branch_length( - log_pcp_probability, starting_branch_length, **optimization_kwargs - ) + # sel_matrix = self.build_selection_matrix_from_parent(parent) + # trimmed_aa_mask = aa_mask[: len(sel_matrix)] + # log_pcp_probability = molevol.mutsel_log_pcp_probability_of( + # sel_matrix[trimmed_aa_mask], + # apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask), + # apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask), + # nt_rates[trimmed_aa_mask.repeat_interleave(3)], + # nt_csps[trimmed_aa_mask.repeat_interleave(3)], + # multihit_model, + # ) + # if isinstance(starting_branch_length, torch.Tensor): + # starting_branch_length = starting_branch_length.detach().item() + # return molevol.optimize_branch_length( + # log_pcp_probability, starting_branch_length, **optimization_kwargs + # ) + # TODO for debugging + return 0.026, False def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): optimal_lengths = [] @@ -295,9 +297,10 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # # The following can be used when one wants a better traceback. - # burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) - # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + # The following can be used when one wants a better traceback. + burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) + return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + # TODO comment that^ our_optimize_branch_length = partial( worker_optimize_branch_length, self.__class__, From 2f98e4ab4abb2e605b2bbea3d4578ae25a620737 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Thu, 5 Dec 2024 15:50:33 -0800 Subject: [PATCH 2/7] parallelize function wrapper --- netam/common.py | 52 +++++++++++++++++++++++++++++++++++++++++++++- netam/dxsm.py | 39 +++++++++++++++++----------------- netam/framework.py | 3 ++- 3 files changed, 72 insertions(+), 22 deletions(-) diff --git a/netam/common.py b/netam/common.py index c50200a3..9c16568f 100644 --- a/netam/common.py +++ b/netam/common.py @@ -5,7 +5,7 @@ import subprocess from tqdm import tqdm from functools import wraps -from itertools import islice +from itertools import islice, repeat import numpy as np import torch @@ -438,3 +438,53 @@ def wrapper(self, *args, **kwargs): return wrapper return decorator + + +def apply_args_and_kwargs(func, args, kwargs): + return func(*args, **kwargs) + +def parallelize_function(max_workers=10, min_chunk_size=50): + """Decorator to parallelize function application with multiprocessing. + + Expects that all positional arguments are iterables of the same length, + and that outputs are tuples of tensors whose first dimension + corresponds to the first dimension of the input iterables. + + If function returns just one item, it must not be a tuple. + + Division between processes is done along the first dimension of all inputs. + + Args: + max_workers: The maximum number of processes to use. + min_chunk_size: The minimum chunk size for input data. The number of + workers is adjusted to ensure that the chunk size is at least this. + """ + + def decorator(function): + max_worker_count = min(mp.cpu_count() // 2, max_workers) + if max_worker_count <= 1: + return function + + @wraps(function) + def wrapper(*args, **kwargs): + min_worker_count = (len(args[0]) // min_chunk_size) + + worker_count = min(min_worker_count, max_worker_count) + if worker_count <= 1: + return function(*args, **kwargs) + + def worker_func(*chunked_args): + return function(*chunked_args, **kwargs) + + chunk_size = (len(args[0]) // worker_count) + 1 + chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in args))) + with mp.Pool(worker_count) as pool: + results = pool.starmap(apply_args_and_kwargs, list(zip(repeat(function), chunked_args, repeat(kwargs)))) + if isinstance(results[0], tuple): + return tuple(torch.cat(tensors) for tensors in zip(*results)) + else: + return torch.cat(results) + + return wrapper + + return decorator diff --git a/netam/dxsm.py b/netam/dxsm.py index 2b65820a..38a67de2 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -240,23 +240,23 @@ def _find_optimal_branch_length( multihit_model, **optimization_kwargs, ): - # sel_matrix = self.build_selection_matrix_from_parent(parent) - # trimmed_aa_mask = aa_mask[: len(sel_matrix)] - # log_pcp_probability = molevol.mutsel_log_pcp_probability_of( - # sel_matrix[trimmed_aa_mask], - # apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask), - # apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask), - # nt_rates[trimmed_aa_mask.repeat_interleave(3)], - # nt_csps[trimmed_aa_mask.repeat_interleave(3)], - # multihit_model, - # ) - # if isinstance(starting_branch_length, torch.Tensor): - # starting_branch_length = starting_branch_length.detach().item() - # return molevol.optimize_branch_length( - # log_pcp_probability, starting_branch_length, **optimization_kwargs - # ) + sel_matrix = self.build_selection_matrix_from_parent(parent) + trimmed_aa_mask = aa_mask[: len(sel_matrix)] + log_pcp_probability = molevol.mutsel_log_pcp_probability_of( + sel_matrix[trimmed_aa_mask], + apply_aa_mask_to_nt_sequence(parent, trimmed_aa_mask), + apply_aa_mask_to_nt_sequence(child, trimmed_aa_mask), + nt_rates[trimmed_aa_mask.repeat_interleave(3)], + nt_csps[trimmed_aa_mask.repeat_interleave(3)], + multihit_model, + ) + if isinstance(starting_branch_length, torch.Tensor): + starting_branch_length = starting_branch_length.detach().item() + return molevol.optimize_branch_length( + log_pcp_probability, starting_branch_length, **optimization_kwargs + ) # TODO for debugging - return 0.026, False + # return 0.026, False def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): optimal_lengths = [] @@ -297,10 +297,9 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # The following can be used when one wants a better traceback. - burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) - return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) - # TODO comment that^ + # # The following can be used when one wants a better traceback. + # burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) + # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) our_optimize_branch_length = partial( worker_optimize_branch_length, self.__class__, diff --git a/netam/framework.py b/netam/framework.py index ad4748e2..5402ea2d 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -26,6 +26,7 @@ BIG, VRC01_NT_SEQ, encode_sequences, + parallelize_function, ) from netam import models import netam.molevol as molevol @@ -338,7 +339,7 @@ def crepe_exists(prefix): def trimmed_shm_model_outputs_of_crepe(crepe, parents): """Model outputs trimmed to the length of the parent sequences.""" - rates, csp_logits = crepe(parents) + rates, csp_logits = parallelize_function()(crepe)(parents) rates = rates.cpu().detach() csps = torch.softmax(csp_logits, dim=-1).cpu().detach() trimmed_rates = [rates[i, : len(parent)] for i, parent in enumerate(parents)] From 58fe0275643eb9308805cde6e7963b879ea6f31e Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Sun, 8 Dec 2024 21:13:13 -0800 Subject: [PATCH 3/7] updated parallelization --- netam/common.py | 87 ++++++++++++++++++++++++++++++---------------- netam/framework.py | 7 ++-- netam/models.py | 4 +-- 3 files changed, 62 insertions(+), 36 deletions(-) diff --git a/netam/common.py b/netam/common.py index 9c16568f..f72f7527 100644 --- a/netam/common.py +++ b/netam/common.py @@ -397,39 +397,47 @@ def chunked(iterable, n): yield chunk -def chunk_method(default_chunk_size=2048, progress_bar_name=None): - """Decorator to chunk the input to a method. +def chunk_function( + first_chunkable_idx=0, + default_chunk_size=2048, + progress_bar_name=None +): + """Decorator to chunk the input to a function. Expects that all positional arguments are iterables of the same length, and that outputs are tuples of tensors whose first dimension corresponds to the first dimension of the input iterables. - If method returns just one item, it must not be a tuple. + If function returns just one item, it must not be a tuple. Chunking is done along the first dimension of all inputs. Args: - default_chunk_size: The default chunk size. The decorated method can + default_chunk_size: The default chunk size. The decorated function can also automatically accept a `default_chunk_size` keyword argument. progress_bar_name: The name of the progress bar. If None, no progress bar is shown. """ - def decorator(method): - @wraps(method) - def wrapper(self, *args, **kwargs): + def decorator(function): + @wraps(function) + def wrapper(*args, **kwargs): if "chunk_size" in kwargs: chunk_size = kwargs.pop("chunk_size") else: chunk_size = default_chunk_size + pre_chunk_args = args[:first_chunkable_idx] + chunkable_args = args[first_chunkable_idx:] + print("chunk got args len ", len(chunkable_args[0])) + results = [] if progress_bar_name is None: progargs = {"disable": True} else: progargs = {"desc": progress_bar_name} - bar = tqdm(total=len(args[0]), delay=2.0, **progargs) - for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)): + bar = tqdm(total=len(chunkable_args[0]), delay=2.0, **progargs) + for chunked_args in zip(*(chunked(arg, chunk_size) for arg in chunkable_args)): bar.update(len(chunked_args[0])) - results.append(method(self, *chunked_args, **kwargs)) + results.append(function(*pre_chunk_args, *chunked_args, **kwargs)) if isinstance(results[0], tuple): return tuple(torch.cat(tensors) for tensors in zip(*results)) else: @@ -440,10 +448,17 @@ def wrapper(self, *args, **kwargs): return decorator -def apply_args_and_kwargs(func, args, kwargs): - return func(*args, **kwargs) +def _apply_args_and_kwargs(func, pre_chunk_args, chunked_args, kwargs): + return func(*pre_chunk_args, *chunked_args, **kwargs) + -def parallelize_function(max_workers=10, min_chunk_size=50): +def parallelize_function( + function, + first_chunkable_idx=0, + default_parallelize=True, + max_workers=10, + min_chunk_size=1000, +): """Decorator to parallelize function application with multiprocessing. Expects that all positional arguments are iterables of the same length, @@ -453,38 +468,50 @@ def parallelize_function(max_workers=10, min_chunk_size=50): If function returns just one item, it must not be a tuple. Division between processes is done along the first dimension of all inputs. + The wrapped function will be endowed with the parallelize keyword + argument, so that parallelization can be turned on or off at each invocation. Args: + first_chunkable_idx: The index of the first argument to be chunked. All positional arguments after this index will be chunked. + default_parallelize: Whether to parallelize function calls by default, or require passing `parallelize=True` in invocation to parallelize. max_workers: The maximum number of processes to use. min_chunk_size: The minimum chunk size for input data. The number of workers is adjusted to ensure that the chunk size is at least this. """ - def decorator(function): - max_worker_count = min(mp.cpu_count() // 2, max_workers) - if max_worker_count <= 1: - return function - - @wraps(function) - def wrapper(*args, **kwargs): - min_worker_count = (len(args[0]) // min_chunk_size) + max_worker_count = min(mp.cpu_count() // 2, max_workers) + if max_worker_count <= 1: + return function + + @wraps(function) + def wrapper(*args, **kwargs): + if "parallelize" in kwargs: + parallelize = kwargs.pop("parallelize") + else: + parallelize = default_parallelize + if parallelize: + if len(args) <= first_chunkable_idx: + raise ValueError( + f"Function {function.__name__} cannot be parallelized without chunkable arguments" + ) + pre_chunk_args = args[:first_chunkable_idx] + chunkable_args = args[first_chunkable_idx:] + print("parallelize got args len ", len(chunkable_args[0])) + min_worker_count = (len(chunkable_args[0]) // min_chunk_size) worker_count = min(min_worker_count, max_worker_count) if worker_count <= 1: return function(*args, **kwargs) - def worker_func(*chunked_args): - return function(*chunked_args, **kwargs) - - chunk_size = (len(args[0]) // worker_count) + 1 - chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in args))) + chunk_size = (len(chunkable_args[0]) // worker_count) + 1 + chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in chunkable_args))) with mp.Pool(worker_count) as pool: - results = pool.starmap(apply_args_and_kwargs, list(zip(repeat(function), chunked_args, repeat(kwargs)))) + results = pool.starmap(_apply_args_and_kwargs, list(zip(repeat(function), repeat(pre_chunk_args), chunked_args, repeat(kwargs)))) if isinstance(results[0], tuple): return tuple(torch.cat(tensors) for tensors in zip(*results)) else: return torch.cat(results) + else: + return function(*args, **kwargs) - return wrapper - - return decorator + return wrapper diff --git a/netam/framework.py b/netam/framework.py index 5402ea2d..b4d9d396 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -26,7 +26,6 @@ BIG, VRC01_NT_SEQ, encode_sequences, - parallelize_function, ) from netam import models import netam.molevol as molevol @@ -254,13 +253,13 @@ def __init__(self, encoder, model, training_hyperparameters={}): self.model = model self.training_hyperparameters = training_hyperparameters - def __call__(self, sequences): + def __call__(self, sequences, **kwargs): """Evaluate the model on a list of sequences.""" if isinstance(sequences, str): raise ValueError( "Expected a list of sequences for call on crepe, but got a single string instead." ) - return self.model.evaluate_sequences(sequences, encoder=self.encoder) + return self.model.evaluate_sequences(sequences, encoder=self.encoder, **kwargs) @property def device(self): @@ -339,7 +338,7 @@ def crepe_exists(prefix): def trimmed_shm_model_outputs_of_crepe(crepe, parents): """Model outputs trimmed to the length of the parent sequences.""" - rates, csp_logits = parallelize_function()(crepe)(parents) + rates, csp_logits = parallelize_function(crepe)(parents) rates = rates.cpu().detach() csps = torch.softmax(csp_logits, dim=-1).cpu().detach() trimmed_rates = [rates[i, : len(parent)] for i, parent in enumerate(parents)] diff --git a/netam/models.py b/netam/models.py index 93293b8e..9a67e6c8 100644 --- a/netam/models.py +++ b/netam/models.py @@ -17,7 +17,7 @@ generate_kmers, aa_mask_tensor_of, encode_sequences, - chunk_method, + chunk_function, ) warnings.filterwarnings( @@ -65,7 +65,7 @@ def unfreeze(self): for param in self.parameters(): param.requires_grad = True - @chunk_method(progress_bar_name="Evaluating model") + @chunk_function(first_chunkable_idx=1, progress_bar_name="Evaluating model") def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048): if encoder is None: raise ValueError("An encoder must be provided.") From 96d0480950998d0699b2ffb2af5e70dd160ad3b9 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 10 Dec 2024 10:42:35 -0800 Subject: [PATCH 4/7] tweaks --- netam/common.py | 2 -- netam/framework.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/netam/common.py b/netam/common.py index f72f7527..7b870891 100644 --- a/netam/common.py +++ b/netam/common.py @@ -427,7 +427,6 @@ def wrapper(*args, **kwargs): chunk_size = default_chunk_size pre_chunk_args = args[:first_chunkable_idx] chunkable_args = args[first_chunkable_idx:] - print("chunk got args len ", len(chunkable_args[0])) results = [] if progress_bar_name is None: @@ -496,7 +495,6 @@ def wrapper(*args, **kwargs): ) pre_chunk_args = args[:first_chunkable_idx] chunkable_args = args[first_chunkable_idx:] - print("parallelize got args len ", len(chunkable_args[0])) min_worker_count = (len(chunkable_args[0]) // min_chunk_size) worker_count = min(min_worker_count, max_worker_count) diff --git a/netam/framework.py b/netam/framework.py index b4d9d396..bbd021e1 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -26,6 +26,7 @@ BIG, VRC01_NT_SEQ, encode_sequences, + parallelize_function, ) from netam import models import netam.molevol as molevol From 87a0008912db2b3258caa580b53f3b5b478a5990 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 10 Dec 2024 10:43:46 -0800 Subject: [PATCH 5/7] fix NaN issue --- netam/dasm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index e3712fe2..0505311b 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -180,6 +180,7 @@ def loss_of_batch(self, batch): # mut_pos_loss, and we mask out sites with no substitution for the CSP # loss. The latter class of sites also eliminates sites that have Xs in # the parent or child (see sequences.aa_subs_indicator_tensor_of). + predictions = zap_predictions_along_diagonal(predictions, aa_parents_idxs) # After zapping out the diagonal, we can effectively sum over the @@ -195,11 +196,10 @@ def loss_of_batch(self, batch): # logit space, so we are set up for using the cross entropy loss. # However we have to mask out the sites that are not substituted, i.e. # the sites for which aa_subs_indicator is 0. - subs_mask = aa_subs_indicator == 1 + subs_mask = (aa_subs_indicator == 1) & mask csp_pred = predictions[subs_mask] csp_targets = aa_children_idxs[subs_mask] csp_loss = self.xent_loss(csp_pred, csp_targets) - return torch.stack([subs_pos_loss, csp_loss]) def build_selection_matrix_from_parent(self, parent: str): From e9c7823dca701199ad543b44ef4d9010c50bad96 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 10 Dec 2024 11:00:19 -0800 Subject: [PATCH 6/7] cleanup --- netam/common.py | 52 +++++++++++++++++++++++-------------------------- netam/dxsm.py | 8 +++----- 2 files changed, 27 insertions(+), 33 deletions(-) diff --git a/netam/common.py b/netam/common.py index 7b870891..5c1eb36a 100644 --- a/netam/common.py +++ b/netam/common.py @@ -454,11 +454,13 @@ def _apply_args_and_kwargs(func, pre_chunk_args, chunked_args, kwargs): def parallelize_function( function, first_chunkable_idx=0, - default_parallelize=True, max_workers=10, min_chunk_size=1000, ): - """Decorator to parallelize function application with multiprocessing. + """Function to parallelize another function's application with multiprocessing. + + This is intentionally not designed to be used with decorator syntax because it should only + be used when the function it is applied to will be run on the CPU. Expects that all positional arguments are iterables of the same length, and that outputs are tuples of tensors whose first dimension @@ -471,8 +473,9 @@ def parallelize_function( argument, so that parallelization can be turned on or off at each invocation. Args: - first_chunkable_idx: The index of the first argument to be chunked. All positional arguments after this index will be chunked. - default_parallelize: Whether to parallelize function calls by default, or require passing `parallelize=True` in invocation to parallelize. + function: The function to be parallelized. + first_chunkable_idx: The index of the first argument to be chunked. + All positional arguments after this index will be chunked. max_workers: The maximum number of processes to use. min_chunk_size: The minimum chunk size for input data. The number of workers is adjusted to ensure that the chunk size is at least this. @@ -484,32 +487,25 @@ def parallelize_function( @wraps(function) def wrapper(*args, **kwargs): - if "parallelize" in kwargs: - parallelize = kwargs.pop("parallelize") - else: - parallelize = default_parallelize - if parallelize: - if len(args) <= first_chunkable_idx: - raise ValueError( - f"Function {function.__name__} cannot be parallelized without chunkable arguments" - ) - pre_chunk_args = args[:first_chunkable_idx] - chunkable_args = args[first_chunkable_idx:] - min_worker_count = (len(chunkable_args[0]) // min_chunk_size) + if len(args) <= first_chunkable_idx: + raise ValueError( + f"Function {function.__name__} cannot be parallelized without chunkable arguments" + ) + pre_chunk_args = args[:first_chunkable_idx] + chunkable_args = args[first_chunkable_idx:] + min_worker_count = (len(chunkable_args[0]) // min_chunk_size) - worker_count = min(min_worker_count, max_worker_count) - if worker_count <= 1: - return function(*args, **kwargs) + worker_count = min(min_worker_count, max_worker_count) + if worker_count <= 1: + return function(*args, **kwargs) - chunk_size = (len(chunkable_args[0]) // worker_count) + 1 - chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in chunkable_args))) - with mp.Pool(worker_count) as pool: - results = pool.starmap(_apply_args_and_kwargs, list(zip(repeat(function), repeat(pre_chunk_args), chunked_args, repeat(kwargs)))) - if isinstance(results[0], tuple): - return tuple(torch.cat(tensors) for tensors in zip(*results)) - else: - return torch.cat(results) + chunk_size = (len(chunkable_args[0]) // worker_count) + 1 + chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in chunkable_args))) + with mp.Pool(worker_count) as pool: + results = pool.starmap(_apply_args_and_kwargs, list(zip(repeat(function), repeat(pre_chunk_args), chunked_args, repeat(kwargs)))) + if isinstance(results[0], tuple): + return tuple(torch.cat(tensors) for tensors in zip(*results)) else: - return function(*args, **kwargs) + return torch.cat(results) return wrapper diff --git a/netam/dxsm.py b/netam/dxsm.py index 38a67de2..0ebd3c56 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -255,8 +255,6 @@ def _find_optimal_branch_length( return molevol.optimize_branch_length( log_pcp_probability, starting_branch_length, **optimization_kwargs ) - # TODO for debugging - # return 0.026, False def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): optimal_lengths = [] @@ -297,9 +295,9 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # # The following can be used when one wants a better traceback. - # burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) - # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + # The following can be used when one wants a better traceback. + burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) + return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) our_optimize_branch_length = partial( worker_optimize_branch_length, self.__class__, From 6a24dcd22866a9183a2ebf35296f5a129dfa2301 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Tue, 10 Dec 2024 11:01:19 -0800 Subject: [PATCH 7/7] more cleanup and format --- netam/common.py | 22 ++++++++++++++++------ netam/dxsm.py | 6 +++--- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/netam/common.py b/netam/common.py index 5c1eb36a..00979d47 100644 --- a/netam/common.py +++ b/netam/common.py @@ -398,9 +398,7 @@ def chunked(iterable, n): def chunk_function( - first_chunkable_idx=0, - default_chunk_size=2048, - progress_bar_name=None + first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None ): """Decorator to chunk the input to a function. @@ -434,7 +432,9 @@ def wrapper(*args, **kwargs): else: progargs = {"desc": progress_bar_name} bar = tqdm(total=len(chunkable_args[0]), delay=2.0, **progargs) - for chunked_args in zip(*(chunked(arg, chunk_size) for arg in chunkable_args)): + for chunked_args in zip( + *(chunked(arg, chunk_size) for arg in chunkable_args) + ): bar.update(len(chunked_args[0])) results.append(function(*pre_chunk_args, *chunked_args, **kwargs)) if isinstance(results[0], tuple): @@ -493,7 +493,7 @@ def wrapper(*args, **kwargs): ) pre_chunk_args = args[:first_chunkable_idx] chunkable_args = args[first_chunkable_idx:] - min_worker_count = (len(chunkable_args[0]) // min_chunk_size) + min_worker_count = len(chunkable_args[0]) // min_chunk_size worker_count = min(min_worker_count, max_worker_count) if worker_count <= 1: @@ -502,7 +502,17 @@ def wrapper(*args, **kwargs): chunk_size = (len(chunkable_args[0]) // worker_count) + 1 chunked_args = list(zip(*(chunked(arg, chunk_size) for arg in chunkable_args))) with mp.Pool(worker_count) as pool: - results = pool.starmap(_apply_args_and_kwargs, list(zip(repeat(function), repeat(pre_chunk_args), chunked_args, repeat(kwargs)))) + results = pool.starmap( + _apply_args_and_kwargs, + list( + zip( + repeat(function), + repeat(pre_chunk_args), + chunked_args, + repeat(kwargs), + ) + ), + ) if isinstance(results[0], tuple): return tuple(torch.cat(tensors) for tensors in zip(*results)) else: diff --git a/netam/dxsm.py b/netam/dxsm.py index 0ebd3c56..9c7157e3 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -295,9 +295,9 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # The following can be used when one wants a better traceback. - burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) - return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + # # The following can be used when one wants a better traceback. + # burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) + # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) our_optimize_branch_length = partial( worker_optimize_branch_length, self.__class__,