
Bring over epam code to avoid circular dependency #34

Merged: 7 commits into main, Jun 11, 2024

Conversation

@matsen (Contributor) commented on Jun 10, 2024:

  • Removing legacy toy_dnsm.py and toy_simulation.py

@matsen (Contributor, Author) commented on Jun 10, 2024:

Removing the legacy toy_dnsm.py:

"""
A silly amino acid prediction model.

We'll use these conventions:

* B is the batch size
* L is the max sequence length

"""

import math
import os

import pandas as pd

import torch
import torch.optim as optim
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

from tensorboardX import SummaryWriter

from epam.torch_common import pick_device, PositionalEncoding
import epam.sequences as sequences


class AAPCPDataset(Dataset):
    def __init__(self, aa_parents, aa_children):
        assert len(aa_parents) == len(aa_children)
        pcp_count = len(aa_parents)

        for parent, child in zip(aa_parents, aa_children):
            if parent == child:
                raise ValueError(
                    f"Found an identical parent and child sequence: {parent}"
                )

        self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
        self.aa_parents_onehot = torch.zeros((pcp_count, self.max_aa_seq_len, 20))
        self.aa_subs_indicator_tensor = torch.zeros((pcp_count, self.max_aa_seq_len))

        # padding_mask is True for padding positions.
        self.padding_mask = torch.ones(
            (pcp_count, self.max_aa_seq_len), dtype=torch.bool
        )

        for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)):
            aa_indices_parent = sequences.aa_idx_array_of_str(aa_parent)
            aa_seq_len = len(aa_parent)
            self.aa_parents_onehot[i, torch.arange(aa_seq_len), aa_indices_parent] = 1
            self.aa_subs_indicator_tensor[i, :aa_seq_len] = torch.tensor(
                [p != c for p, c in zip(aa_parent, aa_child)], dtype=torch.float
            )
            self.padding_mask[i, :aa_seq_len] = False

    def __len__(self):
        return len(self.aa_parents_onehot)

    def __getitem__(self, idx):
        return {
            "aa_onehot": self.aa_parents_onehot[idx],
            "subs_indicator": self.aa_subs_indicator_tensor[idx],
            "padding_mask": self.padding_mask[idx],
        }


class TransformerBinaryModel(nn.Module):
    """A transformer-based model for binary selection.

    This is a model that takes in a batch of one-hot encoded sequences and outputs a binary selection matrix.

    See forward() for details.
    """

    def __init__(
        self,
        nhead: int,
        dim_feedforward: int,
        layer_count: int,
        d_model: int = 20,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.device = pick_device()
        self.d_model = d_model
        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(self.encoder_layer, layer_count)
        self.linear = nn.Linear(self.d_model, 1)

        self.to(self.device)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, parent_onehots: Tensor, padding_mask: Tensor) -> Tensor:
        """Build a binary log selection matrix from a one-hot encoded parent sequence.

        Because we're predicting the log of the selection factor, we don't use an
        activation function after the transformer.

        Parameters:
            parent_onehots: A tensor of shape (B, L, 20) representing the one-hot encoding of parent sequences.
            padding_mask: A tensor of shape (B, L) representing the padding mask for the sequence.

        Returns:
            A tensor of shape (B, L, 1) representing the log level of selection
            for each amino acid site.
        """

        parent_onehots = parent_onehots * math.sqrt(self.d_model)
        # Have to do the permutation because the positional encoding expects the
        # sequence length to be the first dimension.
        parent_onehots = self.pos_encoder(parent_onehots.permute(1, 0, 2)).permute(
            1, 0, 2
        )

        # NOTE: not masking due to MPS bug
        out = self.encoder(parent_onehots)  # , src_key_padding_mask=padding_mask)
        out = self.linear(out)
        return out.squeeze(-1)

    def prediction_of_aa_str(self, aa_str: str):
        """Do the forward method without gradients from an amino acid string and convert to numpy.

        Parameters:
            aa_str: A string of amino acids.

        Returns:
            A numpy array of the same length as the input string representing
            the level of selection for each amino acid site.
        """
        aa_onehot = sequences.aa_onehot_tensor_of_str(aa_str)

        # Create a padding mask with False values (i.e., no padding)
        padding_mask = torch.zeros(len(aa_str), dtype=torch.bool).to(self.device)

        with torch.no_grad():
            aa_onehot = aa_onehot.to(self.device)
            model_out = self(aa_onehot.unsqueeze(0), padding_mask.unsqueeze(0)).squeeze(
                0
            )
            final_out = torch.exp(model_out)
            final_out = torch.clamp(final_out, min=0.0, max=0.999)

        return final_out.cpu().numpy()


def train_model(
    pcp_df,
    nhead,
    dim_feedforward,
    layer_count,
    batch_size=32,
    num_epochs=10,
    learning_rate=0.001,
    checkpoint_dir="./_checkpoints",
    log_dir="./_logs",
):
    print("preparing data...")
    parents = pcp_df["aa_parent"]
    children = pcp_df["aa_child"]

    train_len = int(0.8 * len(parents))
    train_parents, val_parents = parents[:train_len], parents[train_len:]
    train_children, val_children = children[:train_len], children[train_len:]

    # It's important to make separate PCPDatasets for training and validation
    # because the maximum sequence length can differ between those two.
    train_set = AAPCPDataset(train_parents, train_children)
    val_set = AAPCPDataset(val_parents, val_children)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    model = TransformerBinaryModel(
        nhead=nhead, dim_feedforward=dim_feedforward, layer_count=layer_count
    )
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    device = model.device
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    writer = SummaryWriter(log_dir=log_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    bce_loss = nn.BCELoss()

    def complete_loss_fn(log_aa_mut_probs, aa_subs_indicator, padding_mask):
        predictions = torch.exp(log_aa_mut_probs)

        predictions = predictions.masked_select(~padding_mask)
        aa_subs_indicator = aa_subs_indicator.masked_select(~padding_mask)

        # In the early stages of training, we can get probabilities > 1.0 because
        # of bad parameter initialization. We clamp the predictions to be between
        # 0 and 0.999 to avoid this: out of range predictions can make NaNs
        # downstream.
        out_of_range_prediction_count = torch.sum(predictions > 1.0)
        # if out_of_range_prediction_count > 0:
        #     print(f"{out_of_range_prediction_count}\tpredictions out of range.")
        predictions = torch.clamp(predictions, min=0.0, max=0.999)

        return bce_loss(predictions, aa_subs_indicator)

    def loss_of_batch(batch):
        aa_onehot = batch["aa_onehot"].to(device)
        aa_subs_indicator = batch["subs_indicator"].to(device)
        padding_mask = batch["padding_mask"].to(device)
        log_aa_mut_probs = model(aa_onehot, padding_mask)
        return complete_loss_fn(
            log_aa_mut_probs,
            aa_subs_indicator,
            padding_mask,
        )

    def compute_avg_loss(data_loader):
        total_loss = 0
        with torch.no_grad():
            for batch in data_loader:
                total_loss += loss_of_batch(batch).item()
        return total_loss / len(data_loader)

    # Record epoch 0
    model.eval()
    avg_train_loss_epoch_zero = compute_avg_loss(train_loader)
    avg_val_loss_epoch_zero = compute_avg_loss(val_loader)
    writer.add_scalar("Training Loss", avg_train_loss_epoch_zero, 0)
    writer.add_scalar("Validation Loss", avg_val_loss_epoch_zero, 0)
    print(
        f"Epoch [0/{num_epochs}], Training Loss: {avg_train_loss_epoch_zero}, Validation Loss: {avg_val_loss_epoch_zero}"
    )

    loss_records = []

    for epoch in range(num_epochs):
        model.train()
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()
            loss = loss_of_batch(batch)
            loss.backward()
            optimizer.step()
            writer.add_scalar(
                "Training Loss", loss.item(), epoch * len(train_loader) + i
            )

        # Validation Loop
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                val_loss += loss_of_batch(batch).item()

            avg_val_loss = val_loss / len(val_loader)
            writer.add_scalar("Validation Loss", avg_val_loss, epoch)

            # Save model checkpoint
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": avg_val_loss,
                },
                f"{checkpoint_dir}/model_epoch_{epoch}.pth",
            )

        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {loss.item()}, Validation Loss: {avg_val_loss}"
        )

        loss_records.append(
            {
                "Epoch": epoch + 1,
                "Training Loss": loss.item(),
                "Validation Loss": avg_val_loss,
            }
        )

    writer.close()
    loss_df = pd.DataFrame(loss_records)
    loss_df.to_csv("training_validation_loss.csv", index=False)

    return model
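
For reference, here is a minimal usage sketch of the trainer above (not part of the PR or the original file): the toy DataFrame, sequences, and hyperparameters are illustrative only, and the epam modules imported at the top of the file must be available.

import pandas as pd

toy_pcp_df = pd.DataFrame(
    {
        "aa_parent": ["QVQLVQSGAEVKKPG", "EVQLLESGGGLVQPG"],
        "aa_child": ["QVQLVQSGTEVKKPG", "EVQLLESGGGLVKPG"],
    }
)

# Train briefly on the toy data; train_model does its own 80/20 train/validation split.
toy_model = train_model(
    toy_pcp_df,
    nhead=2,
    dim_feedforward=64,
    layer_count=2,
    batch_size=2,
    num_epochs=1,
)

# Per-site selection predictions for a single amino acid sequence.
print(toy_model.prediction_of_aa_str("QVQLVQSGAEVKKPG"))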

@matsen (Contributor, Author) commented on Jun 10, 2024:

Removing toy_simulation.py:

"""
These are functions for simulating amino acid mutations in a protein sequence.

So, this is not for simulating mutation-selection processes.

It corresponds to the inference happening in toy_dnsm.py.
"""

import random

import pandas as pd
from tqdm import tqdm

from epam.sequences import AA_STR_SORTED


def mimic_mutations(sequence_mutator_fn, parents, sub_counts):
    """
    Mimics mutations for a series of parent sequences.

    Parameters
    ----------
    sequence_mutator_fn : function
        Function that takes a string sequence and an integer, returns a mutated sequence with that many mutations.
    parents : pd.Series
        Series containing parent sequences as strings.
    sub_counts : pd.Series
        Series containing the number of substitutions for each parent sequence.

    Returns
    -------
    pd.Series
        Series containing mutated sequences as strings.
    """

    mutated_sequences = []

    for seq, sub_count in tqdm(
        zip(parents, sub_counts), total=len(parents), desc="Mutating sequences"
    ):
        mutated_seq = sequence_mutator_fn(seq, sub_count)
        mutated_sequences.append(mutated_seq)

    return pd.Series(mutated_sequences)


def general_mutator(aa_seq, sub_count, mut_criterion):
    """
    General function to mutate an amino acid sequence based on a criterion function.
    The function first identifies positions in the sequence that satisfy the criterion
    specified by `mut_criterion`. If the number of such positions is less than or equal
    to the number of mutations needed (`sub_count`), then mutations are made at those positions.
    If `sub_count` is greater than the number of positions satisfying the criterion, the function
    mutates all those positions and then randomly selects additional positions to reach `sub_count`
    total mutations. All mutations change the amino acid to a randomly selected new amino acid,
    avoiding a mutation to the same type.

    Parameters
    ----------
    aa_seq : str
        Original amino acid sequence.
    sub_count : int
        Number of substitutions to make.
    mut_criterion : function
        Function that takes a sequence and a position, returns True if position should be mutated.

    Returns
    -------
    str
        Mutated amino acid sequence.
    """

    def draw_new_aa_for_pos(pos):
        return random.choice([aa for aa in AA_STR_SORTED if aa != aa_seq_list[pos]])

    aa_seq_list = list(aa_seq)

    # find all positions that satisfy the mutation criterion
    mut_positions = [
        pos for pos, aa in enumerate(aa_seq_list) if mut_criterion(aa_seq, pos)
    ]

    # if fewer criterion-satisfying positions than required mutations, randomly add more
    if len(mut_positions) < sub_count:
        # use random.sample so the extra positions are distinct, giving exactly
        # sub_count mutated positions
        extra_positions = random.sample(
            [pos for pos in range(len(aa_seq_list)) if pos not in mut_positions],
            k=sub_count - len(mut_positions),
        )
        mut_positions += extra_positions

    # if more criterion-satisfying positions than required mutations, randomly remove some
    elif len(mut_positions) > sub_count:
        mut_positions = random.sample(mut_positions, sub_count)

    # perform mutations
    for pos in mut_positions:
        aa_seq_list[pos] = draw_new_aa_for_pos(pos)

    return "".join(aa_seq_list)


# Criterion functions
def tyrosine_mut_criterion(aa_seq, pos):
    return aa_seq[pos] == "Y"


def hydrophobic_mut_criterion(aa_seq, pos):
    hydrophobic_aa = set("AILMFVWY")
    return aa_seq[pos] in hydrophobic_aa


def hydrophobic_neighbor_mut_criterion(aa_seq, pos):
    """
    Criterion function that returns True if either amino acid at immediate
    neighbors are hydrophobic.
    """

    hydrophobic_aa = set("AILMFVWY")

    positions_to_check = []
    if pos > 0:
        positions_to_check.append(pos - 1)
    if pos < len(aa_seq) - 1:
        positions_to_check.append(pos + 1)

    return any(aa_seq[i] in hydrophobic_aa for i in positions_to_check)


[tyrosine_mutator, hydrophobic_mutator, hydrophobic_neighbor_mutator] = [
    lambda aa_seq, sub_count, crit=crit: general_mutator(aa_seq, sub_count, crit)
    for crit in [
        tyrosine_mut_criterion,
        hydrophobic_mut_criterion,
        hydrophobic_neighbor_mut_criterion,
    ]
]
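
Illustrative usage of the mutators above (not from the PR); the parent sequences and substitution counts below are made up.

import pandas as pd

parents = pd.Series(["ACDYYGHIK", "MNPQRYSTV"])
sub_counts = pd.Series([2, 1])

# Mutate each parent, preferring tyrosine positions when enough are available.
children = mimic_mutations(tyrosine_mutator, parents, sub_counts)
print(children.tolist())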

@matsen marked this pull request as ready for review on June 10, 2024 at 21:54.
"""

# Assert that sub_probs are within the range [0, 1] modulo rounding error
assert torch.all(

@ksung25 it turns out I had these committed locally. They are new checks, so they may break epam.
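
For context, a hypothetical form such a range check could take; the function name, tolerance, and exact expression below are assumptions, not the committed code.

import torch

def assert_sub_probs_in_unit_interval(sub_probs: torch.Tensor, tol: float = 1e-6) -> None:
    # sub_probs are probabilities, so values outside [0, 1] beyond rounding
    # error indicate an upstream problem.
    assert torch.all(sub_probs >= -tol), "sub_probs below 0"
    assert torch.all(sub_probs <= 1.0 + tol), "sub_probs above 1"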

return 1.0 - p_staying_same


def mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs):

@ksung25 I had to bring the content of this over as a free function, so we'll want to use it in epam's models.py.
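
As an illustration of that hand-off, a minimal sketch of how epam's models.py might call the free function; the import path and wrapper below are guesses, not code from this PR.

from netam.molevol import mutsel_log_pcp_probability_of  # module path is a guess

def log_pcp_probability(sel_matrix, parent, child, rates, sub_probs):
    # Delegate to the free function rather than the old WrappedBinaryMutSel method.
    return mutsel_log_pcp_probability_of(sel_matrix, parent, child, rates, sub_probs)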

  - uses: mamba-org/setup-micromamba@v1
    with:
      micromamba-version: '1.5.6-0' # any version from https://github.com/mamba-org/micromamba-releases
-     environment-name: epam
+     environment-name: netam

@willdumm here are the changes I mentioned; let me know if I should change this at all.

@@ -304,7 +302,6 @@ def train_val_datasets_of_pcp_df(pcp_df, branch_length_multiplier=5.0):
 class DNSMBurrito(framework.Burrito):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.wrapped_model = WrappedBinaryMutSel(self.model, weights_directory=None)

@ksung25 note that WrappedBinaryMutSel is totally gone now and can be dropped from epam.

@@ -0,0 +1,6 @@
black
pulp==2.7.0 # see issue https://github.com/snakemake/snakemake/issues/2607

As far as I could tell, this was still a good idea. Let me know if not (@ksung25).

@matsen merged commit 29859d9 into main on Jun 11, 2024 (1 check passed).