Skip to content

Commit

Permalink
Move data loaders outside of GCNN model; use torch_geometric data loader conditionally when GNNs are used
Browse files Browse the repository at this point in the history
  • Loading branch information
borauyar committed May 14, 2024
1 parent 1e1fc7c commit f796c47
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 26 deletions.
14 changes: 10 additions & 4 deletions flexynesis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
seed_everything(42, workers=True)
import torch
from torch.utils.data import DataLoader, random_split
import torch_geometric

import lightning as pl
from lightning.pytorch.callbacks import RichProgressBar
Expand Down Expand Up @@ -81,9 +82,14 @@ def __init__(self, dataset, model_class, config_name, target_variables,
self.input_layers = input_layers
self.output_layers = output_layers

self.DataLoader = DataLoader # use torch data loader by default

if self.model_class.__name__ == 'MultiTripletNetwork':
self.loader_dataset = TripletMultiOmicDataset(self.dataset, self.target_variables[0])

if self.model_class.__name__ == 'DirectPredGCNN':
# use torch_geometric data loader for GCNN class
self.DataLoader = torch_geometric.loader.DataLoader

# If config_path is provided, use it
if config_path:
external_config = self.load_and_convert_config(config_path)
Expand Down Expand Up @@ -153,7 +159,7 @@ def objective(self, params, current_step, total_steps, full_train = False):

if full_train:
# Train on the full dataset
full_loader = DataLoader(self.loader_dataset, batch_size=int(params['batch_size']),
full_loader = self.DataLoader(self.loader_dataset, batch_size=int(params['batch_size']),
shuffle=True, pin_memory=True, drop_last=True)
model = self.model_class(**model_args)
trainer, _ = self.setup_trainer(params, current_step, total_steps, full_train = True)
Expand All @@ -180,8 +186,8 @@ def objective(self, params, current_step, total_steps, full_train = False):
print(f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}")
train_subset = torch.utils.data.Subset(self.loader_dataset, train_index)
val_subset = torch.utils.data.Subset(self.loader_dataset, val_index)
train_loader = DataLoader(train_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=True, drop_last=True)
val_loader = DataLoader(val_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=False)
train_loader = self.DataLoader(train_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=True, drop_last=True)
val_loader = self.DataLoader(val_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=False)

model = self.model_class(**model_args)
trainer, early_stop_callback = self.setup_trainer(params, current_step, total_steps)
Expand Down
23 changes: 1 addition & 22 deletions flexynesis/models/direct_pred_gcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(
batch_variables=None,
surv_event_var=None,
surv_time_var=None,
val_size=0.2,
use_loss_weighting=True,
device_type = None,
gnn_conv_type = None
Expand All @@ -43,16 +42,13 @@ def __init__(
self.variables = self.target_variables + self.batch_variables if self.batch_variables else self.target_variables
self.variable_types = dataset.variable_types
self.ann = dataset.ann
self.val_size = val_size

self.feature_importances = {}
self.use_loss_weighting = use_loss_weighting

self.device_type = device_type
self.gnn_conv_type = gnn_conv_type

self.prepare_data_loaders(dataset)


if self.use_loss_weighting:
# Initialize log variance parameters for uncertainty weighting
self.log_vars = nn.ParameterDict()
Expand Down Expand Up @@ -192,23 +188,6 @@ def validation_step(self, val_batch, batch_idx):
self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True, batch_size=int(x_list[0].batch_size))
return total_loss

def prepare_data_loaders(self, dataset):
    """Split *dataset* into train/validation subsets and build their loaders.

    The validation fraction comes from ``self.val_size`` and the batch size
    from ``self.config['batch_size']``. The resulting loaders are stored on
    ``self.train_loader`` and ``self.val_loader``.
    """
    # Seeded generator keeps the split reproducible across runs.
    n_total = len(dataset)
    n_train = int(n_total * (1 - self.val_size))
    n_val = n_total - n_train
    dat_train, dat_val = random_split(
        dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)
    )

    batch_size = int(self.config['batch_size'])
    # Training loader shuffles and drops a trailing partial batch.
    self.train_loader = DataLoader(dat_train, batch_size=batch_size,
                                   num_workers=0, pin_memory=True,
                                   shuffle=True, drop_last=True)
    # Validation loader preserves order and keeps every sample.
    self.val_loader = DataLoader(dat_val, batch_size=batch_size,
                                 num_workers=0, pin_memory=True,
                                 shuffle=False)

def train_dataloader(self):
    """Lightning hook: hand back the pre-built training loader."""
    loader = self.train_loader
    return loader

def val_dataloader(self):
    """Lightning hook: hand back the pre-built validation loader."""
    loader = self.val_loader
    return loader

def predict(self, dataset):
self.eval()
Expand Down

0 comments on commit f796c47

Please sign in to comment.