diff --git a/netam/dasm.py b/netam/dasm.py
index 734d3d1b..1e3f3f52 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -206,24 +206,4 @@ def build_selection_matrix_from_parent(self, parent: str):
         parent_idxs = sequences.aa_idx_array_of_str(parent)
         selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
 
-        return selection_factors
-
-    # We need to repeat this so that we use this worker_optimize_branch_length below.
-    def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
-        worker_count = min(mp.cpu_count() // 2, 10)
-        # The following can be used when one wants a better traceback.
-        # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
-        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
-        with mp.Pool(worker_count) as pool:
-            splits = dataset.split(worker_count)
-            results = pool.starmap(
-                worker_optimize_branch_length,
-                [(self.model, split, optimization_kwargs) for split in splits],
-            )
-        return torch.cat(results)
-
-
-def worker_optimize_branch_length(model, dataset, optimization_kwargs):
-    """The worker used for parallel branch length optimization."""
-    burrito = DASMBurrito(None, dataset, copy.deepcopy(model))
-    return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        return selection_factors
\ No newline at end of file
diff --git a/netam/dnsm.py b/netam/dnsm.py
index bd7d55b4..23539bd9 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -9,6 +9,7 @@
 
 import copy
 import multiprocessing as mp
+from functools import partial
 
 import torch
 from torch.utils.data import Dataset
@@ -414,10 +415,13 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         # The following can be used when one wants a better traceback.
         # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
         # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        our_optimize_branch_length = partial(
+            worker_optimize_branch_length, self.__class__
+        )
         with mp.Pool(worker_count) as pool:
             splits = dataset.split(worker_count)
             results = pool.starmap(
-                worker_optimize_branch_length,
+                our_optimize_branch_length,
                 [(self.model, split, optimization_kwargs) for split in splits],
             )
         return torch.cat(results)
@@ -437,9 +441,9 @@ def to_crepe(self):
         return framework.Crepe(encoder, self.model, training_hyperparameters)
 
 
-def worker_optimize_branch_length(model, dataset, optimization_kwargs):
+def worker_optimize_branch_length(burrito_class, model, dataset, optimization_kwargs):
     """The worker used for parallel branch length optimization."""
-    burrito = DNSMBurrito(None, dataset, copy.deepcopy(model))
+    burrito = burrito_class(None, dataset, copy.deepcopy(model))
     return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
 
 
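Note on the pattern used above: the DASM copy of `find_optimal_branch_lengths` and its
`worker_optimize_branch_length` are deleted, and the shared DNSM worker instead takes the
burrito class as an explicit first argument, bound via
`partial(worker_optimize_branch_length, self.__class__)` so that each subclass (e.g.
`DASMBurrito`) reuses the one implementation. `partial` matters here because
`multiprocessing.Pool` pickles the callable it dispatches to workers: a partial of a
module-level function over a module-level class pickles fine, whereas a lambda or a
locally defined closure would not. The snippet below is a minimal, self-contained sketch
of the same pattern, not part of the patch; the names (`worker`, `Doubler`) are
hypothetical, and only the `partial` + `Pool.starmap` combination mirrors the actual
change.

    import multiprocessing as mp
    from functools import partial


    def worker(cls, payload):
        """Module-level worker: picklable, so usable with mp.Pool."""
        return cls(payload).run()


    class Doubler:
        def __init__(self, x):
            self.x = x

        def run(self):
            return 2 * self.x


    if __name__ == "__main__":
        # Bind the class as the first argument; a partial of a module-level
        # function over a module-level class is still picklable.
        our_worker = partial(worker, Doubler)
        with mp.Pool(2) as pool:
            results = pool.starmap(our_worker, [(1,), (2,), (3,)])
        print(results)  # [2, 4, 6]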