Commit bf44d5e: generalizing branch length optimization
matsen committed Oct 3, 2024 (1 parent: 8cff29f)
Showing 2 changed files with 8 additions and 24 deletions.
netam/dasm.py (1 addition, 21 deletions)

@@ -206,24 +206,4 @@ def build_selection_matrix_from_parent(self, parent: str):
         parent_idxs = sequences.aa_idx_array_of_str(parent)
         selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
 
-        return selection_factors
-
-    # We need to repeat this so that we use this worker_optimize_branch_length below.
-    def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
-        worker_count = min(mp.cpu_count() // 2, 10)
-        # The following can be used when one wants a better traceback.
-        # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
-        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
-        with mp.Pool(worker_count) as pool:
-            splits = dataset.split(worker_count)
-            results = pool.starmap(
-                worker_optimize_branch_length,
-                [(self.model, split, optimization_kwargs) for split in splits],
-            )
-        return torch.cat(results)
-
-
-def worker_optimize_branch_length(model, dataset, optimization_kwargs):
-    """The worker used for parallel branch length optimization."""
-    burrito = DASMBurrito(None, dataset, copy.deepcopy(model))
-    return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        return selection_factors
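
With this deletion, DASMBurrito no longer carries its own copy of the parallel machinery: it presumably relies on the find_optimal_branch_lengths it inherits, which, as the dnsm.py diff below shows, now passes the concrete burrito class to a single shared worker instead of hard-coding DASMBurrito or DNSMBurrito.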
netam/dnsm.py (7 additions, 3 deletions)

@@ -9,6 +9,7 @@
 
 import copy
 import multiprocessing as mp
+from functools import partial
 
 import torch
 from torch.utils.data import Dataset
@@ -414,10 +415,13 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         # The following can be used when one wants a better traceback.
         # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
         # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        our_optimize_branch_length = partial(
+            worker_optimize_branch_length,
+            self.__class__,)
         with mp.Pool(worker_count) as pool:
             splits = dataset.split(worker_count)
             results = pool.starmap(
-                worker_optimize_branch_length,
+                our_optimize_branch_length,
                 [(self.model, split, optimization_kwargs) for split in splits],
             )
         return torch.cat(results)
@@ -437,9 +441,9 @@ def to_crepe(self):
         return framework.Crepe(encoder, self.model, training_hyperparameters)
 
 
-def worker_optimize_branch_length(model, dataset, optimization_kwargs):
+def worker_optimize_branch_length(burrito_class, model, dataset, optimization_kwargs):
     """The worker used for parallel branch length optimization."""
-    burrito = DNSMBurrito(None, dataset, copy.deepcopy(model))
+    burrito = burrito_class(None, dataset, copy.deepcopy(model))
     return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
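
To make the dispatch concrete, here is a minimal, self-contained sketch of the pattern the dnsm.py change adopts: functools.partial binds self.__class__ as the first argument of a picklable module-level worker, so one worker can construct whichever subclass invoked it. Greeter, ShoutyGreeter, and greet_worker are hypothetical stand-ins for the burrito classes and worker_optimize_branch_length, not netam code.

import multiprocessing as mp
from functools import partial


class Greeter:
    def __init__(self, name):
        self.name = name

    def greet(self, punctuation):
        return f"hello, {self.name}{punctuation}"

    def greet_all(self, names):
        # Bind the concrete subclass, as find_optimal_branch_lengths does.
        worker = partial(greet_worker, self.__class__)
        with mp.Pool(2) as pool:
            return pool.starmap(worker, [(name, "!") for name in names])


class ShoutyGreeter(Greeter):
    def greet(self, punctuation):
        return super().greet(punctuation).upper()


def greet_worker(greeter_class, name, punctuation):
    """Module-level worker: picklable, parameterized by class."""
    return greeter_class(name).greet(punctuation)


if __name__ == "__main__":
    print(ShoutyGreeter("demo").greet_all(["ada", "grace"]))
    # ['HELLO, ADA!', 'HELLO, GRACE!']

The worker must stay at module level because multiprocessing pickles it by name; binding the class through partial (rather than a lambda or closure, which would not pickle) keeps the callable usable under the spawn start method.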
