Merge pull request #73 from BIMSBbioinfo/finetuning
Implement a finetuning feature
borauyar authored May 1, 2024
2 parents 53e9b3a + 02a0906 commit 33f28f5
Showing 7 changed files with 194 additions and 112 deletions.
32 changes: 26 additions & 6 deletions flexynesis/__main__.py
@@ -1,17 +1,15 @@
from lightning import seed_everything
# Set the seed for all the possible random number generators.
seed_everything(42, workers=True)
import argparse
import lightning as pl
from typing import NamedTuple
import os
import os, yaml, torch, time, random, warnings, argparse
os.environ["OMP_NUM_THREADS"] = "1"
import yaml
import torch
import pandas as pd
import flexynesis
from flexynesis.models import *
import warnings
import time
from lightning.pytorch.callbacks import EarlyStopping


def main():
parser = argparse.ArgumentParser(description="Flexynesis - Your PyTorch model training interface",
@@ -34,6 +32,7 @@ def main():
parser.add_argument('--config_path', type=str, default=None, help='Optional path to an external hyperparameter configuration file in YAML format.')
parser.add_argument("--fusion_type", help="How to fuse the omics layers", type=str, choices=["early", "intermediate"], default = 'intermediate')
parser.add_argument("--hpo_iter", help="Number of iterations for hyperparameter optimisation", type=int, default = 5)
parser.add_argument("--finetuning_samples", help="Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning", type=int, default = 0)
parser.add_argument("--correlation_threshold", help="Correlation threshold to drop highly redundant features (default: 0.8; set to 1 for no redundancy filtering)", type=float, default = 0.8)
parser.add_argument("--restrict_to_features", help="Restrict the analyis to the list of features provided by the user (default: None)", type = str, default = None)
parser.add_argument("--subsample", help="Downsample training set to randomly drawn N samples for training. Disabled when set to 0", type=int, default = 0)
@@ -218,6 +217,27 @@ class AvailableModels(NamedTuple):
# do a hyperparameter search training multiple models and get the best_configuration
model, best_params = tuner.perform_tuning()

# if fine-tuning is enabled, fine-tune the model on a portion of the test samples
if args.finetuning_samples > 0:
finetuneSampleN = args.finetuning_samples
print("[INFO] Finetuning the model on ",finetuneSampleN,"test samples")
# split test dataset into finetuning and holdout datasets
all_indices = range(len(test_dataset))
finetune_indices = random.sample(all_indices, finetuneSampleN)
holdout_indices = list(set(all_indices) - set(finetune_indices))
finetune_dataset = test_dataset.subset(finetune_indices)
holdout_dataset = test_dataset.subset(holdout_indices)

# fine tune on the finetuning dataset; freeze the encoders
finetuner = flexynesis.FineTuner(model,
finetune_dataset)
finetuner.run_experiments()

# update the model to finetuned model
model = finetuner.model
# update the test dataset to exclude finetuning samples
test_dataset = holdout_dataset

# evaluate predictions
print("[INFO] Computing model evaluation metrics")
metrics_df = flexynesis.evaluate_wrapper(model.predict(test_dataset), test_dataset,
17 changes: 17 additions & 0 deletions flexynesis/data.py
@@ -623,6 +623,23 @@ def __len__ (self):
"""
return len(self.samples)

def subset(self, indices):
"""Create a new dataset object containing only the specified indices.
Args:
indices (list of int): The indices of the samples to include in the subset.
Returns:
MultiomicDataset: A new dataset object with the same structure but only containing the selected samples.
"""
subset_dat = {x: self.dat[x][indices] for x in self.dat.keys()}
subset_ann = {x: self.ann[x][indices] for x in self.ann.keys()}
subset_samples = [self.samples[idx] for idx in indices]

# Create a new dataset object
return MultiomicDataset(subset_dat, subset_ann, self.variable_types, self.features,
subset_samples, self.label_mappings, self.feature_ann)

def get_feature_subset(self, feature_df):
"""Get a subset of data matrices corresponding to specified features and concatenate them into a pandas DataFrame.
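
The new subset() method is what the __main__.py diff above uses to carve the test set into a fine-tuning portion and a holdout portion. A minimal sketch of that pattern, reusing the variable names from the CLI code above (the sample count of 20 is illustrative):

import random

all_indices = range(len(test_dataset))                            # test_dataset is a MultiomicDataset
finetune_indices = random.sample(all_indices, 20)                 # e.g. --finetuning_samples 20
holdout_indices = list(set(all_indices) - set(finetune_indices))  # everything not used for fine-tuning
finetune_dataset = test_dataset.subset(finetune_indices)          # passed to FineTuner
holdout_dataset = test_dataset.subset(holdout_indices)            # kept back for evaluation
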
105 changes: 105 additions & 0 deletions flexynesis/main.py
@@ -14,6 +14,7 @@
from skopt import Optimizer
from skopt.utils import use_named_args
from .config import search_spaces
from .data import TripletMultiOmicDataset

import numpy as np

@@ -213,6 +214,110 @@ def load_and_convert_config(self, config_path):
return search_space_user


from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import KFold
import numpy as np
import random, copy, logging

class FineTuner(pl.LightningModule):
def __init__(self, model, dataset, n_splits=5, batch_size=32, learning_rates=None, max_epoch = 50, freeze_configs = None):
super().__init__()
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
self.original_model = model
self.dataset = dataset # Use the entire dataset
self.n_splits = n_splits
self.batch_size = batch_size
self.kfold = KFold(n_splits=self.n_splits, shuffle=True)
self.learning_rates = learning_rates if learning_rates else [model.config['lr'], model.config['lr']/10, model.config['lr']/100]
self.folds_data = list(self.kfold.split(np.arange(len(self.dataset))))
self.max_epoch = max_epoch
self.freeze_configs = freeze_configs if freeze_configs else [
{'encoders': True, 'supervisors': False},
{'encoders': False, 'supervisors': True},
{'encoders': False, 'supervisors': False}
]

def apply_freeze_config(self, config):
# Freeze or unfreeze encoders
for encoder in self.model.encoders:
for param in encoder.parameters():
param.requires_grad = not config['encoders']

# Freeze or unfreeze supervisors
for mlp in self.model.MLPs.values():
for param in mlp.parameters():
param.requires_grad = not config['supervisors']

def train_dataloader(self):
# Override to load data for the current fold
train_idx, val_idx = self.folds_data[self.current_fold]
train_subset = torch.utils.data.Subset(self.dataset, train_idx)
return DataLoader(train_subset, batch_size=self.batch_size, shuffle=True)

def val_dataloader(self):
# Override to load validation data for the current fold
train_idx, val_idx = self.folds_data[self.current_fold]
val_subset = torch.utils.data.Subset(self.dataset, val_idx)
return DataLoader(val_subset, batch_size=self.batch_size)

def training_step(self, batch, batch_idx):
return self.model.training_step(batch, batch_idx, log=False)

def validation_step(self, batch, batch_idx):
# Call the model's validation step without logging
val_loss = self.model.validation_step(batch, batch_idx, log=False) # Assuming you can disable logging
self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)
return val_loss

def configure_optimizers(self):
return torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.learning_rate)

def run_experiments(self):
val_loss_results = []
for lr in self.learning_rates:
for config in self.freeze_configs:
fold_losses = []
epochs = [] # record how many epochs the training happened
for fold in range(self.n_splits):
model_copy = copy.deepcopy(self.original_model) # Deep copy the model for each fold
self.model = model_copy
self.apply_freeze_config(config) # try freezing different components
self.current_fold = fold
self.learning_rate = lr
early_stopping = EarlyStopping(
monitor='val_loss',
patience=3,
verbose=False,
mode='min'
)
trainer = pl.Trainer(max_epochs=self.max_epoch, devices=1, accelerator='auto', logger=False, enable_checkpointing=False,
enable_progress_bar = False, enable_model_summary=False, callbacks=[early_stopping])
trainer.fit(self)
stopped_epoch = early_stopping.stopped_epoch
val_loss = trainer.validate(self.model, verbose = False)
fold_losses.append(val_loss[0]['val_loss']) # Adjust based on your validation output format
epochs.append(stopped_epoch)
#print(f"[INFO] Finetuning ... training fold: {fold}, learning rate: {lr}, val_loss: {val_loss}, freeze {config}")
avg_val_loss = np.mean(fold_losses)
avg_epochs = int(np.mean(epochs))
print(f"[INFO] average 5-fold cross-validation loss {avg_val_loss} for learning rate: {lr} freeze {config}, average epochs {avg_epochs}")
val_loss_results.append({'learning_rate': lr, 'average_val_loss': avg_val_loss, 'freeze': config, 'epochs': avg_epochs})

# Find the best configuration based on validation loss
best_config = min(val_loss_results, key=lambda x: x['average_val_loss'])
print(f"Best learning rate: {best_config['learning_rate']} and freeze {best_config['freeze']}",
f"with average validation loss: {best_config['average_val_loss']} and average epochs: {best_config['epochs']}")

# build a final model using the best setup on all samples
final_model = copy.deepcopy(self.model)
self.model = final_model
self.learning_rate = best_config['learning_rate']
self.apply_freeze_config(best_config['freeze'])
dl = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
final_trainer = pl.Trainer(max_epochs=best_config['epochs'], devices=1, accelerator='auto', logger=False, enable_checkpointing=False)
final_trainer.fit(self, train_dataloaders=dl)


import matplotlib.pyplot as plt
from IPython.display import clear_output
from lightning import Callback
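
For context, a minimal usage sketch of the new class, mirroring the call sites in the __main__.py diff above (the keyword values shown are the constructor defaults; `model` is the model returned by tuner.perform_tuning()):

finetuner = flexynesis.FineTuner(model,              # a trained flexynesis model
                                 finetune_dataset,   # MultiomicDataset of samples reserved for fine-tuning
                                 n_splits=5,         # cross-validation folds
                                 batch_size=32,
                                 max_epoch=50)
finetuner.run_experiments()   # grid over learning rates x freeze configs, then refit on the full fine-tuning set
model = finetuner.model       # fine-tuned model, ready for model.predict(holdout_dataset)
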
10 changes: 6 additions & 4 deletions flexynesis/models/direct_pred.py
@@ -129,7 +129,7 @@ def compute_total_loss(self, losses):
total_loss = sum(losses.values())
return total_loss

def training_step(self, train_batch, batch_idx):
def training_step(self, train_batch, batch_idx, log = True):
"""
Perform a single training step.
Args:
@@ -159,10 +159,11 @@ def training_step(self, train_batch, batch_idx):
total_loss = self.compute_total_loss(losses)
# add train loss for logging
losses['train_loss'] = total_loss
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
if log:
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
return total_loss

def validation_step(self, val_batch, batch_idx):
def validation_step(self, val_batch, batch_idx, log = True):
"""
Perform a single validation step.
@@ -191,7 +192,8 @@ def validation_step(self, val_batch, batch_idx):
losses[var] = loss
total_loss = sum(losses.values())
losses['val_loss'] = total_loss
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
if log:
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
return total_loss

def prepare_data_loaders(self, dataset):
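
The new `log` flag on training_step and validation_step lets FineTuner (flexynesis/main.py above) call the model's step methods directly during cross-validation while handling logging itself. An excerpt of that pattern, as it appears in FineTuner.validation_step in the main.py diff above:

val_loss = self.model.validation_step(batch, batch_idx, log=False)  # compute the loss, skip the model's log_dict
self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)        # log the single metric that EarlyStopping monitors

The same change is applied to supervised_vae.py below.
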
10 changes: 6 additions & 4 deletions flexynesis/models/supervised_vae.py
@@ -215,7 +215,7 @@ def compute_total_loss(self, losses):
return total_loss


def training_step(self, train_batch, batch_idx):
def training_step(self, train_batch, batch_idx, log = True):
dat, y_dict = train_batch
layers = dat.keys()
x_list = [dat[x] for x in layers]
@@ -244,10 +244,11 @@ def training_step(self, train_batch, batch_idx):
total_loss = self.compute_total_loss(losses)
# add total loss for logging
losses['train_loss'] = total_loss
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
if log:
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
return total_loss

def validation_step(self, val_batch, batch_idx):
def validation_step(self, val_batch, batch_idx, log = True):
dat, y_dict = val_batch
layers = dat.keys()
x_list = [dat[x] for x in layers]
@@ -274,7 +275,8 @@ def validation_step(self, val_batch, batch_idx):

total_loss = sum(losses.values())
losses['val_loss'] = total_loss
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
if log:
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
return total_loss

def prepare_data(self):
(diffs for the remaining 2 changed files not shown)
