Docstrings; multihit device fix #73

Merged: 14 commits, Oct 28, 2024
6 changes: 3 additions & 3 deletions netam/dasm.py
@@ -203,8 +203,8 @@ def build_selection_matrix_from_parent(self, parent: str):
         # matrix directly. Note that selection_factors_of_aa_str does the exponentiation
         # so this indeed gives us the selection factors, not the log selection factors.
         parent = translate_sequence(parent)
-        selection_factors = self.model.selection_factors_of_aa_str(parent)
+        per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)
         parent_idxs = sequences.aa_idx_array_of_str(parent)
-        selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
+        per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
 
-        return selection_factors
+        return per_aa_selection_factors
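
The renamed variable is filled in with PyTorch advanced indexing: pairing `torch.arange` over rows with a per-row column index selects exactly one entry per site, here pinning the parent amino acid's selection factor to 1.0. A minimal sketch with toy shapes and hypothetical values (the real matrix is sites x 20 amino acids):

```python
import torch

# Toy per-site selection factors: 4 sites x 5 "amino acids".
per_aa_selection_factors = torch.rand(4, 5)

# Hypothetical parent amino acid index at each site.
parent_idxs = torch.tensor([0, 3, 1, 4])

# Paired indexing hits entries (0,0), (1,3), (2,1), (3,4): the parent
# amino acid at each site gets a neutral selection factor of 1.0.
per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

assert torch.all(per_aa_selection_factors[torch.arange(4), parent_idxs] == 1.0)
```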
16 changes: 11 additions & 5 deletions netam/dnsm.py
@@ -235,6 +235,8 @@ def update_neutral_probs(self):
         the DNSM (in which case it's neutral probabilities of any nonsynonymous
         mutation) and the DASM (in which case it's the neutral probabilities of mutation
         to the various amino acids).
+
+        This method covers the DNSM case; the DASM overrides it.
         """
         neutral_aa_mut_prob_l = []
 
@@ -333,8 +335,8 @@ def load_branch_lengths(self, in_csv_prefix):
         self.val_dataset.load_branch_lengths(in_csv_prefix + ".val_branch_lengths.csv")
 
     def prediction_pair_of_batch(self, batch):
-        """Get log neutral mutation probabilities and log selection factors for a batch
-        of data."""
+        """Get log neutral amino acid substitution probabilities and log selection
+        factors for a batch of data."""
         aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
         mask = batch["mask"].to(self.device)
         log_neutral_aa_mut_probs = batch["log_neutral_aa_mut_probs"].to(self.device)
@@ -346,7 +348,8 @@ def prediction_pair_of_batch(self, batch):
         return log_neutral_aa_mut_probs, log_selection_factors
 
     def predictions_of_pair(self, log_neutral_aa_mut_probs, log_selection_factors):
-        # Take the product of the neutral mutation probabilities and the selection factors.
+        """Obtain the predictions for a pair consisting of the log neutral amino acid
+        substitution probabilities and the log selection factors."""
         predictions = torch.exp(log_neutral_aa_mut_probs + log_selection_factors)
         assert torch.isfinite(predictions).all()
         predictions = clamp_probability(predictions)
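
Since exp(log p + log s) = p * s, the prediction is just the product of the neutral substitution probability and the selection factor, computed in log space; clamping then keeps the result strictly inside (0, 1). A minimal sketch with a hypothetical stand-in for netam's `clamp_probability` (its exact bounds may differ):

```python
import torch


# Hypothetical stand-in for netam's clamp_probability helper.
def clamp_probability(p: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    return p.clamp(min=eps, max=1.0 - eps)


log_neutral_aa_mut_probs = torch.log(torch.tensor([0.01, 0.20]))
log_selection_factors = torch.log(torch.tensor([2.0, 0.5]))

# exp(log p + log s) == p * s: tensor([0.0200, 0.1000])
predictions = clamp_probability(torch.exp(log_neutral_aa_mut_probs + log_selection_factors))
```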
@@ -409,6 +412,9 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         optimal_lengths = []
         failed_count = 0
 
+        if dataset.multihit_model is not None:
+            multihit_model = copy.deepcopy(dataset.multihit_model).to("cpu")
+
         for parent, child, rates, subs_probs, starting_length in tqdm(
             zip(
                 dataset.nt_parents,
@@ -426,7 +432,7 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
                 rates[: len(parent)],
                 subs_probs[: len(parent), :],
                 starting_length,
-                dataset.multihit_model,
+                multihit_model,
                 **optimization_kwargs,
             )
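
This is the device fix named in the PR title: take a CPU copy of the multihit model once, up front, so CPU-side branch-length optimization never receives a GPU-resident model, while the dataset's own copy stays on its original device. Note that as written, `multihit_model` is only bound when `dataset.multihit_model` is not None; the None case presumably relies on code outside this hunk. A minimal sketch of the copy-then-move idiom, with a toy module standing in for the multihit model:

```python
import copy

import torch

# Toy stand-in for a multihit model; any nn.Module shows the same behavior.
multihit_model = torch.nn.Linear(4, 4)
if torch.cuda.is_available():
    multihit_model = multihit_model.to("cuda")

# deepcopy first, then move: the copy lands on the CPU while the original
# keeps whatever device it was on.
cpu_copy = copy.deepcopy(multihit_model).to("cpu")

assert next(cpu_copy.parameters()).device.type == "cpu"
```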

@@ -443,7 +449,7 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
         # # The following can be used when one wants a better traceback.
-        # burrito = DNSMBurrito(None, dataset, copy.deepcopy(self.model))
+        # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
         # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
         our_optimize_branch_length = partial(
             worker_optimize_branch_length,
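
Swapping the hard-coded `DNSMBurrito` for `self.__class__` makes the commented-out debugging path subclass-aware: a subclass inheriting this method would construct an instance of itself rather than of `DNSMBurrito`. A minimal sketch of the pattern (toy class names):

```python
class Base:
    def clone(self):
        # self.__class__ resolves to the runtime class, so subclasses
        # get instances of themselves rather than of Base.
        return self.__class__()


class Derived(Base):
    pass


assert type(Derived().clone()) is Derived
```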
6 changes: 3 additions & 3 deletions netam/framework.py
@@ -461,7 +461,7 @@ def multi_train(self, epochs, max_tries=3):
         and resume training.
         """
         for i in range(max_tries):
-            train_history = self.train(epochs)
+            train_history = self.simple_train(epochs)
             if self.optimizer.param_groups[0]["lr"] < self.min_learning_rate:
                 return train_history
             else:
@@ -571,7 +571,7 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None
         self.write_loss("Validation loss", average_loss, self.global_epoch)
         return loss_reduction(average_loss)
 
-    def train(self, epochs, out_prefix=None):
+    def simple_train(self, epochs, out_prefix=None):
         """Train the model for the given number of epochs.
 
         If out_prefix is provided, then a crepe will be saved to that location.
@@ -760,7 +760,7 @@ def joint_train(
                 weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate)
             )
             self.reset_optimization(new_lr)
-            loss_history_l.append(self.train(epochs, out_prefix=out_prefix))
+            loss_history_l.append(self.simple_train(epochs, out_prefix=out_prefix))
             # We standardize and optimize the branch lengths after each cycle, even the last one.
             optimize_branch_lengths()
             self.mark_branch_lengths_optimized(cycle + 1)
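
The `np.exp(... np.log ...)` expression above interpolates between learning rates geometrically rather than arithmetically, which is the natural choice for quantities that vary over orders of magnitude. A small worked example (the helper name is hypothetical):

```python
import numpy as np


def interpolate_lr(current_lr: float, base_lr: float, weight: float) -> float:
    # Log-space (geometric) interpolation: weight=1 returns current_lr,
    # weight=0 returns base_lr.
    return float(np.exp(weight * np.log(current_lr) + (1 - weight) * np.log(base_lr)))


# The midpoint of 1e-5 and 1e-3 is their geometric mean 1e-4,
# not the arithmetic mean 5.05e-4.
print(interpolate_lr(1e-5, 1e-3, 0.5))
```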
12 changes: 9 additions & 3 deletions netam/molevol.py
@@ -186,8 +186,10 @@ def aaprob_of_mut_and_sub(
     parent_codon_idxs: Tensor, mut_probs: Tensor, sub_probs: Tensor
 ) -> Tensor:
     """For a sequence of parent codons and given nucleotide mutability and substitution
-    probabilities, compute the amino acid substitution probabilities for each codon
-    along the sequence.
+    probabilities, compute the probability of a substitution to each amino acid for each
+    codon along the sequence.
+
+    Stop codons don't appear as part of this calculation.
 
     Args:
         parent_codon_idxs (torch.Tensor): A 2D tensor where each row contains indices representing
@@ -288,6 +290,8 @@ def build_codon_mutsel(
     """Build a sequence of codon mutation-selection matrices for codons along a
     sequence.
 
+    These matrices assign zero probability to mutating to a stop codon.
+
     Args:
         parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
         codon_mut_probs (torch.Tensor): The mutation probabilities for each site in each codon. Shape: (codon_count, 3)
@@ -310,6 +314,7 @@ def build_codon_mutsel(
     # So, for each site (s) and codon (c), sum over amino acids (a):
     # codon_sel_matrices[s, c] = sum_a(CODON_AA_INDICATOR_MATRIX[c, a] * aa_sel_matrices[s, a])
     # Resulting shape is (S, C) where S is the number of sites and C is the number of codons.
+    # Stop codons don't appear in this sum, so columns for stop codons will be zero.
     codon_sel_matrices = torch.einsum(
         "ca,sa->sc", CODON_AA_INDICATOR_MATRIX, aa_sel_matrices
     )
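
The einsum projects per-amino-acid selection factors onto codons through an indicator matrix whose (c, a) entry is 1 exactly when codon c encodes amino acid a. Stop codons encode no amino acid, so their indicator rows are all zeros and their output columns come out zero, which is what the new comment records. A toy sketch with 3 "codons" and 2 "amino acids" (the real matrix is 64 x 20):

```python
import torch

# Toy indicator: codon 0 -> aa 0, codon 1 -> aa 1, codon 2 is a "stop"
# codon with an all-zero row.
indicator = torch.tensor([[1.0, 0.0],
                          [0.0, 1.0],
                          [0.0, 0.0]])

# Selection factors for 2 sites x 2 amino acids.
aa_sel = torch.tensor([[2.0, 0.5],
                       [1.0, 3.0]])

# (codon, aa) x (site, aa) -> (site, codon); the stop-codon column is zero:
# tensor([[2., 0.5, 0.], [1., 3., 0.]])
codon_sel = torch.einsum("ca,sa->sc", indicator, aa_sel)
```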
@@ -322,7 +327,8 @@ def build_codon_mutsel(
 
     # Now we need to recalculate the probability of staying in the same codon.
     # In our setup, this is the probability of nothing happening.
-    # To calculate this, we zero out the previously calculated probabilities...
+    # To calculate this, we zero out the probabilities of mutating to the parent
+    # codon...
     codon_count = parent_codon_idxs.shape[0]
     codon_mutsel[(torch.arange(codon_count), *parent_codon_idxs.T)] = 0.0
     # sum together their probabilities...
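
The stay-probability recalculation zeroes the parent-codon entry in each site's distribution and then (in the lines that follow this hunk) refills it with the leftover probability mass so each distribution sums to one. A flattened toy version of the idiom; the real code indexes each codon by its three nucleotide positions via `*parent_codon_idxs.T`:

```python
import torch

# Toy mutation-selection probabilities: 2 sites x 4 "codons".
codon_mutsel = torch.tensor([[0.70, 0.10, 0.05, 0.05],
                             [0.10, 0.60, 0.10, 0.10]])
parent_codon = torch.tensor([0, 1])  # parent codon index at each site

site_idx = torch.arange(codon_mutsel.shape[0])
# Zero out the stale parent entries...
codon_mutsel[site_idx, parent_codon] = 0.0
# ...then set them to the leftover mass so each row sums to one.
codon_mutsel[site_idx, parent_codon] = 1.0 - codon_mutsel.sum(dim=1)

assert torch.allclose(codon_mutsel.sum(dim=1), torch.ones(2))
```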
2 changes: 1 addition & 1 deletion tests/test_netam.py
@@ -31,7 +31,7 @@ def tiny_model():
 @pytest.fixture
 def tiny_burrito(tiny_dataset, tiny_val_dataset, tiny_model):
     burrito = SHMBurrito(tiny_dataset, tiny_val_dataset, tiny_model)
-    burrito.train(epochs=5)
+    burrito.simple_train(epochs=5)
     return burrito

