Bring over epam code to avoid circular dependency #34
Conversation
Removing a legacy module:
"""
A silly amino acid prediction model.
We'll use these conventions:
* B is the batch size
* L is the max sequence length
"""
import math
import os
import pandas as pd
import torch
import torch.optim as optim
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from tensorboardX import SummaryWriter
from epam.torch_common import pick_device, PositionalEncoding
import epam.sequences as sequences
class AAPCPDataset(Dataset):
def __init__(self, aa_parents, aa_children):
assert len(aa_parents) == len(aa_children)
pcp_count = len(aa_parents)
for parent, child in zip(aa_parents, aa_children):
if parent == child:
raise ValueError(
f"Found an identical parent and child sequence: {parent}"
)
self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
self.aa_parents_onehot = torch.zeros((pcp_count, self.max_aa_seq_len, 20))
self.aa_subs_indicator_tensor = torch.zeros((pcp_count, self.max_aa_seq_len))
# padding_mask is True for padding positions.
self.padding_mask = torch.ones(
(pcp_count, self.max_aa_seq_len), dtype=torch.bool
)
for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)):
aa_indices_parent = sequences.aa_idx_array_of_str(aa_parent)
aa_seq_len = len(aa_parent)
self.aa_parents_onehot[i, torch.arange(aa_seq_len), aa_indices_parent] = 1
self.aa_subs_indicator_tensor[i, :aa_seq_len] = torch.tensor(
[p != c for p, c in zip(aa_parent, aa_child)], dtype=torch.float
)
self.padding_mask[i, :aa_seq_len] = False
def __len__(self):
return len(self.aa_parents_onehot)
def __getitem__(self, idx):
return {
"aa_onehot": self.aa_parents_onehot[idx],
"subs_indicator": self.aa_subs_indicator_tensor[idx],
"padding_mask": self.padding_mask[idx],
}
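# Note (added for clarity, not part of the original module): each item from
# AAPCPDataset is a dict of per-sequence tensors: "aa_onehot" of shape
# (max_aa_seq_len, 20), "subs_indicator" of shape (max_aa_seq_len,), and
# "padding_mask" of shape (max_aa_seq_len,); a DataLoader stacks these into the
# (B, L, ...) batches consumed by the model below.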
class TransformerBinaryModel(nn.Module):
"""A transformer-based model for binary selection.
This is a model that takes in a batch of one-hot encoded sequences and outputs a binary selection matrix.
See forward() for details.
"""
def __init__(
self,
nhead: int,
dim_feedforward: int,
layer_count: int,
d_model: int = 20,
dropout: float = 0.5,
):
super().__init__()
self.device = pick_device()
self.d_model = d_model
self.pos_encoder = PositionalEncoding(self.d_model, dropout)
self.encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=True,
)
self.encoder = nn.TransformerEncoder(self.encoder_layer, layer_count)
self.linear = nn.Linear(self.d_model, 1)
self.to(self.device)
self.init_weights()
def init_weights(self) -> None:
initrange = 0.1
self.linear.bias.data.zero_()
self.linear.weight.data.uniform_(-initrange, initrange)
def forward(self, parent_onehots: Tensor, padding_mask: Tensor) -> Tensor:
"""Build a binary log selection matrix from a one-hot encoded parent sequence.
Because we're predicting the log of the selection factor, we don't use an activation function after the transformer.
Parameters:
parent_onehots: A tensor of shape (B, L, 20) representing the one-hot encoding of parent sequences.
padding_mask: A tensor of shape (B, L) representing the padding mask for the sequence.
Returns:
A tensor of shape (B, L) representing the log level of selection for each amino acid site.
"""
parent_onehots = parent_onehots * math.sqrt(self.d_model)
# Have to do the permutation because the positional encoding expects the
# sequence length to be the first dimension.
parent_onehots = self.pos_encoder(parent_onehots.permute(1, 0, 2)).permute(
1, 0, 2
)
# NOTE: not masking due to MPS bug
out = self.encoder(parent_onehots) # , src_key_padding_mask=padding_mask)
out = self.linear(out)
return out.squeeze(-1)
def prediction_of_aa_str(self, aa_str: str):
"""Do the forward method without gradients from an amino acid string and convert to numpy.
Parameters:
aa_str: A string of amino acids.
Returns:
A numpy array of the same length as the input string representing
the level of selection for each amino acid site.
"""
aa_onehot = sequences.aa_onehot_tensor_of_str(aa_str)
# Create a padding mask with False values (i.e., no padding)
padding_mask = torch.zeros(len(aa_str), dtype=torch.bool).to(self.device)
with torch.no_grad():
aa_onehot = aa_onehot.to(self.device)
model_out = self(aa_onehot.unsqueeze(0), padding_mask.unsqueeze(0)).squeeze(
0
)
final_out = torch.exp(model_out)
final_out = torch.clamp(final_out, min=0.0, max=0.999)
return final_out.cpu().numpy()
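# Illustrative usage (comment only, not part of the original module): given a
# trained model,
#     selection = model.prediction_of_aa_str("QVQLVQ")
# returns a numpy array of length 6, one selection level per site, computed as
# exp(model output) and clamped to at most 0.999.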
def train_model(
pcp_df,
nhead,
dim_feedforward,
layer_count,
batch_size=32,
num_epochs=10,
learning_rate=0.001,
checkpoint_dir="./_checkpoints",
log_dir="./_logs",
):
print("preparing data...")
parents = pcp_df["aa_parent"]
children = pcp_df["aa_child"]
train_len = int(0.8 * len(parents))
train_parents, val_parents = parents[:train_len], parents[train_len:]
train_children, val_children = children[:train_len], children[train_len:]
# It's important to make separate PCPDatasets for training and validation
# because the maximum sequence length can differ between those two.
train_set = AAPCPDataset(train_parents, train_children)
val_set = AAPCPDataset(val_parents, val_children)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
model = TransformerBinaryModel(
nhead=nhead, dim_feedforward=dim_feedforward, layer_count=layer_count
)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
device = model.device
if not os.path.exists(log_dir):
os.makedirs(log_dir)
writer = SummaryWriter(log_dir=log_dir)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
bce_loss = nn.BCELoss()
def complete_loss_fn(log_aa_mut_probs, aa_subs_indicator, padding_mask):
predictions = torch.exp(log_aa_mut_probs)
predictions = predictions.masked_select(~padding_mask)
aa_subs_indicator = aa_subs_indicator.masked_select(~padding_mask)
# In the early stages of training, we can get probabilities > 1.0 because
# of bad parameter initialization. We clamp the predictions to be between
# 0 and 0.999 to avoid this: out of range predictions can make NaNs
# downstream.
out_of_range_prediction_count = torch.sum(predictions > 1.0)
# if out_of_range_prediction_count > 0:
# print(f"{out_of_range_prediction_count}\tpredictions out of range.")
predictions = torch.clamp(predictions, min=0.0, max=0.999)
return bce_loss(predictions, aa_subs_indicator)
def loss_of_batch(batch):
aa_onehot = batch["aa_onehot"].to(device)
aa_subs_indicator = batch["subs_indicator"].to(device)
padding_mask = batch["padding_mask"].to(device)
log_aa_mut_probs = model(aa_onehot, padding_mask)
return complete_loss_fn(
log_aa_mut_probs,
aa_subs_indicator,
padding_mask,
)
def compute_avg_loss(data_loader):
total_loss = 0
with torch.no_grad():
for batch in data_loader:
total_loss += loss_of_batch(batch).item()
return total_loss / len(data_loader)
# Record epoch 0
model.eval()
avg_train_loss_epoch_zero = compute_avg_loss(train_loader)
avg_val_loss_epoch_zero = compute_avg_loss(val_loader)
writer.add_scalar("Training Loss", avg_train_loss_epoch_zero, 0)
writer.add_scalar("Validation Loss", avg_val_loss_epoch_zero, 0)
print(
f"Epoch [0/{num_epochs}], Training Loss: {avg_train_loss_epoch_zero}, Validation Loss: {avg_val_loss_epoch_zero}"
)
loss_records = []
for epoch in range(num_epochs):
model.train()
for i, batch in enumerate(train_loader):
optimizer.zero_grad()
loss = loss_of_batch(batch)
loss.backward()
optimizer.step()
writer.add_scalar(
"Training Loss", loss.item(), epoch * len(train_loader) + i
)
# Validation Loop
model.eval()
val_loss = 0
with torch.no_grad():
for batch in val_loader:
val_loss += loss_of_batch(batch).item()
avg_val_loss = val_loss / len(val_loader)
writer.add_scalar("Validation Loss", avg_val_loss, epoch)
# Save model checkpoint
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": avg_val_loss,
},
f"{checkpoint_dir}/model_epoch_{epoch}.pth",
)
print(
f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {loss.item()}, Validation Loss: {avg_val_loss}"
)
loss_records.append(
{
"Epoch": epoch + 1,
"Training Loss": loss.item(),
"Validation Loss": avg_val_loss,
}
)
writer.close()
loss_df = pd.DataFrame(loss_records)
loss_df.to_csv("training_validation_loss.csv", index=False)
return model
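For reference, here is a minimal sketch of how this legacy module could be exercised. The toy DataFrame, its column names (matching what train_model reads above), and the hyperparameter values are illustrative placeholders, not values taken from the epam pipeline.

import pandas as pd

# Toy parent/child pairs; every child differs from its parent at one site.
toy_pcp_df = pd.DataFrame(
    {
        "aa_parent": ["ACDEFY", "GHIKLM"],
        "aa_child": ["ACDEFW", "GHIKLV"],
    }
)

# Small hyperparameters so the sketch runs quickly; nhead must divide d_model (20).
model = train_model(
    toy_pcp_df,
    nhead=2,
    dim_feedforward=64,
    layer_count=2,
    batch_size=2,
    num_epochs=1,
)

# Per-site selection levels for a new sequence, as a numpy array of length 6.
selection = model.prediction_of_aa_str("ACDEFY")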
Removing """
These are functions for simulating amino acid mutations in a protein sequence.
So, this is not for simulating mutation-selection processes.
It corresponds to the inference happening in toy_dnsm.py.
"""
import random
import pandas as pd
from tqdm import tqdm
from epam.sequences import AA_STR_SORTED
def mimic_mutations(sequence_mutator_fn, parents, sub_counts):
"""
Mimics mutations for a series of parent sequences.
Parameters
----------
sequence_mutator_fn : function
Function that takes a string sequence and an integer, returns a mutated sequence with that many mutations.
parents : pd.Series
Series containing parent sequences as strings.
sub_counts : pd.Series
Series containing the number of substitutions for each parent sequence.
Returns
-------
pd.Series
Series containing mutated sequences as strings.
"""
mutated_sequences = []
for seq, sub_count in tqdm(
zip(parents, sub_counts), total=len(parents), desc="Mutating sequences"
):
mutated_seq = sequence_mutator_fn(seq, sub_count)
mutated_sequences.append(mutated_seq)
return pd.Series(mutated_sequences)
def general_mutator(aa_seq, sub_count, mut_criterion):
"""
General function to mutate an amino acid sequence based on a criterion function.
The function first identifies positions in the sequence that satisfy the criterion
specified by `mut_criterion`. If the number of such positions is less than or equal
to the number of mutations needed (`sub_count`), then mutations are made at those positions.
If `sub_count` is greater than the number of positions satisfying the criterion, the function
mutates all those positions and then randomly selects additional positions to reach `sub_count`
total mutations. All mutations change the amino acid to a randomly selected new amino acid,
avoiding a mutation to the same type.
Parameters
----------
aa_seq : str
Original amino acid sequence.
sub_count : int
Number of substitutions to make.
mut_criterion : function
Function that takes a sequence and a position, returns True if position should be mutated.
Returns
-------
str
Mutated amino acid sequence.
"""
def draw_new_aa_for_pos(pos):
return random.choice([aa for aa in AA_STR_SORTED if aa != aa_seq_list[pos]])
aa_seq_list = list(aa_seq)
# find all positions that satisfy the mutation criterion
mut_positions = [
pos for pos, aa in enumerate(aa_seq_list) if mut_criterion(aa_seq, pos)
]
# if fewer criterion-satisfying positions than required mutations, sample
# additional distinct positions at random (without replacement, so we don't
# double-count a position)
if len(mut_positions) < sub_count:
extra_positions = random.sample(
[pos for pos in range(len(aa_seq_list)) if pos not in mut_positions],
k=sub_count - len(mut_positions),
)
mut_positions += extra_positions
# if more criterion-satisfying positions than required mutations, randomly remove some
elif len(mut_positions) > sub_count:
mut_positions = random.sample(mut_positions, sub_count)
# perform mutations
for pos in mut_positions:
aa_seq_list[pos] = draw_new_aa_for_pos(pos)
return "".join(aa_seq_list)
# Criterion functions
def tyrosine_mut_criterion(aa_seq, pos):
return aa_seq[pos] == "Y"
def hydrophobic_mut_criterion(aa_seq, pos):
hydrophobic_aa = set("AILMFVWY")
return aa_seq[pos] in hydrophobic_aa
def hydrophobic_neighbor_mut_criterion(aa_seq, pos):
"""
Criterion function that returns True if either immediate neighbor of the
position is a hydrophobic amino acid.
"""
hydrophobic_aa = set("AILMFVWY")
positions_to_check = []
if pos > 0:
positions_to_check.append(pos - 1)
if pos < len(aa_seq) - 1:
positions_to_check.append(pos + 1)
return any(aa_seq[i] in hydrophobic_aa for i in positions_to_check)
[tyrosine_mutator, hydrophobic_mutator, hydrophobic_neighbor_mutator] = [
lambda aa_seq, sub_count, crit=crit: general_mutator(aa_seq, sub_count, crit)
for crit in [
tyrosine_mut_criterion,
hydrophobic_mut_criterion,
hydrophobic_neighbor_mut_criterion,
]
]
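For illustration, a short hedged sketch of how these pieces compose; the sequences and substitution counts below are made up, not drawn from any dataset.

import pandas as pd

# Hypothetical parents and per-sequence substitution counts.
parents = pd.Series(["ACDYEF", "GHYKLM"])
sub_counts = pd.Series([2, 1])

# Apply the tyrosine-preferring mutator to every parent; each child carries the
# requested number of substitutions.
children = mimic_mutations(tyrosine_mutator, parents, sub_counts)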
""" | ||
|
||
# Assert that sub_probs are within the range [0, 1] modulo rounding error | ||
assert torch.all( |
@ksung25 it turns out that I had these locally committed. They are new checks, so they may make epam break.
return 1.0 - p_staying_same


def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs):
@ksung25 I had to bring the content of this over as a free function. So we'll want to use this in the epam models.py.
- uses: mamba-org/setup-micromamba@v1
with:
micromamba-version: '1.5.6-0' # any version from https://github.com/mamba-org/micromamba-releases
environment-name: epam
environment-name: netam
@willdumm here are the changes I mentioned-- let me know if I should change this at all.
@@ -304,7 +302,6 @@ def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
class DNSMBurrito(framework.Burrito):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.wrapped_model = WrappedBinaryMutSel(self.model, weights_directory=None)
@ksung25 note that WrappedBinaryMutSel is totally gone now and can be dropped from epam.
@@ -0,0 +1,6 @@
black | |||
pulp==2.7.0 # see issue https://github.com/snakemake/snakemake/issues/2607 |
As far as I could tell, this was still a good idea. Let me know if not (@ksung25).
toy_dnsm.py and toy_simulation.py