Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate the multihit model into the DNSM framework #60

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,16 @@ def __init__(
all_rates: torch.Tensor,
all_subs_probs: torch.Tensor,
branch_lengths: torch.Tensor,
multihit_model=None,
):
self.nt_parents = nt_parents
self.nt_children = nt_children
self.all_rates = all_rates
self.all_subs_probs = all_subs_probs
self.multihit_model = copy.deepcopy(multihit_model)
if multihit_model is not None:
# We want these parameters to act like fixed data
self.multihit_model.values.requires_grad_(False)

assert len(self.nt_parents) == len(self.nt_children)
pcp_count = len(self.nt_parents)
Expand Down Expand Up @@ -99,6 +104,7 @@ def of_seriess(
all_rates_series: pd.Series,
all_subs_probs_series: pd.Series,
branch_length_multiplier=5.0,
multihit_model=None,
):
"""Alternative constructor that takes the raw data and calculates the initial
branch lengths.
Expand All @@ -119,10 +125,11 @@ def of_seriess(
stack_heterogeneous(all_rates_series.reset_index(drop=True)),
stack_heterogeneous(all_subs_probs_series.reset_index(drop=True)),
initial_branch_lengths,
multihit_model=multihit_model,
)

@classmethod
def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
"""Alternative constructor that takes in a pcp_df and calculates the initial
branch lengths."""
assert "rates" in pcp_df.columns, "pcp_df must have a neutral rates column"
Expand All @@ -132,6 +139,7 @@ def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0):
pcp_df["rates"],
pcp_df["subs_probs"],
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)

def clone(self):
Expand All @@ -142,6 +150,7 @@ def clone(self):
self.all_rates.copy(),
self.all_subs_probs.copy(),
self._branch_lengths.copy(),
multihit_model=copy.deepcopy(self.multihit_model),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This deepcopy and the one below are no longer necessary now that you are copying in the init, right?

)
return new_dataset

Expand All @@ -158,6 +167,7 @@ def subset_via_indices(self, indices):
self.all_rates[indices],
self.all_subs_probs[indices],
self._branch_lengths[indices],
multihit_model=copy.deepcopy(self.multihit_model),
)
return new_dataset

Expand Down Expand Up @@ -220,6 +230,7 @@ def update_neutral_aa_mut_probs(self):
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
normed_subs_probs.reshape(-1, 3, 4),
multihit_model=self.multihit_model,
)

if not torch.isfinite(neutral_aa_mut_prob).all():
Expand Down Expand Up @@ -272,21 +283,27 @@ def to(self, device):
self.all_subs_probs = self.all_subs_probs.to(device)


def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
def train_val_datasets_of_pcp_df(
pcp_df, branch_length_multiplier=5.0, multihit_model=None
):
"""Perform a train-val split based on a "in_train" column.

Stays here so it can be used in tests.
"""
train_df = pcp_df[pcp_df["in_train"]].reset_index(drop=True)
val_df = pcp_df[~pcp_df["in_train"]].reset_index(drop=True)
val_dataset = DNSMDataset.of_pcp_df(
val_df, branch_length_multiplier=branch_length_multiplier
val_df,
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)
if len(train_df) == 0:
return None, val_dataset
# else:
train_dataset = DNSMDataset.of_pcp_df(
train_df, branch_length_multiplier=branch_length_multiplier
train_df,
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)
return train_dataset, val_dataset

Expand Down Expand Up @@ -360,13 +377,14 @@ def _find_optimal_branch_length(
rates,
subs_probs,
starting_branch_length,
multihit_model,
**optimization_kwargs,
):
if parent == child:
return 0.0
sel_matrix = self.build_selection_matrix_from_parent(parent)
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
sel_matrix, parent, child, rates, subs_probs
sel_matrix, parent, child, rates, subs_probs, multihit_model
)
if type(starting_branch_length) == torch.Tensor:
starting_branch_length = starting_branch_length.detach().item()
Expand Down Expand Up @@ -395,6 +413,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
rates[: len(parent)],
subs_probs[: len(parent), :],
starting_length,
dataset.multihit_model,
**optimization_kwargs,
)

Expand All @@ -410,7 +429,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
worker_count = min(mp.cpu_count() // 2, 10)
# The following can be used when one wants a better traceback.
# # The following can be used when one wants a better traceback.
# burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
# return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
with mp.Pool(worker_count) as pool:
Expand Down
25 changes: 21 additions & 4 deletions netam/molevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def build_codon_mutsel(
codon_mut_probs: Tensor,
codon_sub_probs: Tensor,
aa_sel_matrices: Tensor,
multihit_model=None,
) -> Tensor:
"""Build a sequence of codon mutation-selection matrices for codons along a
sequence.
Expand All @@ -301,6 +302,9 @@ def build_codon_mutsel(
)
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

if multihit_model is not None:
codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp()

# Calculate the codon selection matrix for each sequence via Einstein
# summation, in which we sum over the repeated indices.
# So, for each site (s) and codon (c), sum over amino acids (a):
Expand Down Expand Up @@ -339,6 +343,7 @@ def neutral_aa_probs(
parent_codon_idxs: Tensor,
codon_mut_probs: Tensor,
codon_sub_probs: Tensor,
multihit_model=None,
) -> Tensor:
"""For every site, what is the probability that the amino acid will mutate to every
amino acid?
Expand All @@ -356,10 +361,13 @@ def neutral_aa_probs(
mut_matrices = build_mutation_matrices(
parent_codon_idxs, codon_mut_probs, codon_sub_probs
)
codon_probs = codon_probs_of_mutation_matrices(mut_matrices).view(-1, 64)
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

if multihit_model is not None:
codon_probs = multihit_model(parent_codon_idxs, codon_probs.log()).exp()

# Get the probability of mutating to each amino acid.
aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX
aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX

return aa_probs

Expand Down Expand Up @@ -396,6 +404,7 @@ def neutral_aa_mut_probs(
parent_codon_idxs: Tensor,
codon_mut_probs: Tensor,
codon_sub_probs: Tensor,
multihit_model=None,
) -> Tensor:
"""For every site, what is the probability that the amino acid will have a
substution or mutate to a stop under neutral evolution?
Expand All @@ -414,12 +423,19 @@ def neutral_aa_mut_probs(
Shape: (codon_count,)
"""

aa_probs = neutral_aa_probs(parent_codon_idxs, codon_mut_probs, codon_sub_probs)
aa_probs = neutral_aa_probs(
parent_codon_idxs,
codon_mut_probs,
codon_sub_probs,
multihit_model=multihit_model,
)
mut_probs = mut_probs_of_aa_probs(parent_codon_idxs, aa_probs)
return mut_probs


def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs):
def mutsel_log_pcp_probability_of(
sel_matrix, parent, child, rates, sub_probs, multihit_model=None
):
"""Constructs the log_pcp_probability function specific to given rates and
sub_probs.

Expand All @@ -442,6 +458,7 @@ def log_pcp_probability(log_branch_length: torch.Tensor):
mut_probs.reshape(-1, 3),
sub_probs.reshape(-1, 3, 4),
sel_matrix,
multihit_model=multihit_model,
)

# This is a diagnostic generating data for netam issue #7.
Expand Down
12 changes: 12 additions & 0 deletions tests/test_dnsm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import multiprocessing as mp
import os

import torch
Expand All @@ -14,6 +15,15 @@
from netam.dnsm import DNSMBurrito, train_val_datasets_of_pcp_df


def force_spawn():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were going to move this into common, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(and remove the other instance)

"""Force the spawn start method for multiprocessing.

This is necessary to avoid conflicts with the internal OpenMP-based thread pool in
PyTorch.
"""
mp.set_start_method("spawn", force=True)


def test_aa_idx_tensor_of_str_ambig():
input_seq = "ACX"
expected_output = torch.tensor([0, 1, MAX_AMBIG_AA_IDX], dtype=torch.int)
Expand All @@ -36,6 +46,7 @@ def pcp_df():
@pytest.fixture
def dnsm_burrito(pcp_df):
"""Fixture that returns the DNSM Burrito object."""
force_spawn()
pcp_df["in_train"] = True
pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
train_dataset, val_dataset = train_val_datasets_of_pcp_df(pcp_df)
Expand All @@ -57,6 +68,7 @@ def dnsm_burrito(pcp_df):


def test_parallel_branch_length_optimization(dnsm_burrito):
force_spawn()
dataset = dnsm_burrito.val_dataset
parallel_branch_lengths = dnsm_burrito.find_optimal_branch_lengths(dataset)
branch_lengths = dnsm_burrito.serial_find_optimal_branch_lengths(dataset)
Expand Down