# light_external_train_test.py
import argparse
from datetime import datetime
from pathlib import Path
from warnings import filterwarnings

import lightning.pytorch as pl
import pandas as pd
import torch
import yaml
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from dataset import TCGA_Program_Dataset
from datasets_manager import TCGA_Balanced_Datasets_Manager, TCGA_Datasets_Manager
from external_lightningdatamodule import ExternalDataModule
from lit_models import LitFullModel
from model import Classifier, Feature_Extractor, Graph_And_Clinical_Feature_Extractor, Task_Classifier
from utils import config_add_subdict_key, get_logger, override_n_genes, set_random_seed, setup_logging

SEED = 1126
set_random_seed(SEED)
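
# For reference, a minimal sketch of what set_random_seed is assumed to do; the
# actual implementation lives in utils.py and may differ:
#
#     def set_random_seed(seed: int) -> None:
#         import random
#         import numpy as np
#         random.seed(seed)                 # Python's built-in RNG
#         np.random.seed(seed)              # NumPy's RNG
#         torch.manual_seed(seed)           # PyTorch CPU RNG
#         torch.cuda.manual_seed_all(seed)  # PyTorch GPU RNGs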


def main():
    # Select a config file.
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, help='Path to the config file.', required=True)
    args = parser.parse_args()
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    override_n_genes(config)  # For multi-task graph models.
    config_name = Path(args.config).stem
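
    # For illustration, this script reads roughly the following top-level keys
    # from the config; the values shown here are placeholders, not shipped defaults:
    #
    #     external_datasets: {...}   # kwargs for ExternalDataModule
    #     cross_validation: true
    #     max_epochs: 50
    #     bootstrap_repeats: 100
    #     models: {...}              # see create_models_and_optimizers()
    #     optimizers: {...}          # see create_models_and_optimizers()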

    # Setup logging.
    setup_logging(log_path := f'Logs/{config_name}/{datetime.now():%Y-%m-%dT%H:%M:%S}/')
    logger = get_logger(config_name)
    logger.info(f'Using Random Seed {SEED} for this experiment')
    get_logger('lightning.pytorch.accelerators.cuda', log_level='WARNING')  # Disable CUDA logging.
    filterwarnings('ignore', r'.*Skipping val loop.*')  # Disable the "Skipping val loop" warning.

    # Create the data module. A Lightning data module for the external data
    # replaces the TCGA dataset manager used during training:
    # data = {'TCGA_BLC': TCGA_Program_Dataset(**config['datasets'])}
    # The same config as for the MTL model trained on TCGA is used, but testing
    # is done on the external SCLC data.
    external_testing_data = ExternalDataModule(**config['external_datasets'])
    external_testing_data.setup(only_test=False)  # Prepare both the train and the test split.
    train = external_testing_data.train_dataloader()
    test = external_testing_data.test_dataloader()
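
    # For reference, ExternalDataModule is assumed to follow the standard
    # LightningDataModule surface, with only_test as a custom extension of the
    # usual setup(stage) hook (the actual class lives in
    # external_lightningdatamodule.py and may differ):
    #
    #     class ExternalDataModule(pl.LightningDataModule):
    #         def setup(self, only_test: bool = False) -> None: ...
    #         def train_dataloader(self) -> DataLoader: ...
    #         def test_dataloader(self) -> DataLoader: ...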

    # Cross-validation, adapted for testing on the external dataset batches.
    # TODO: the proper shuffling of the data still needs to be implemented, so for
    # now this branch trains once on the external train split instead of folding.
    if config['cross_validation']:
        models, optimizers = create_models_and_optimizers(config)
        lit_model = LitFullModel(models, optimizers, config)
        trainer = pl.Trainer(  # Creates sub-folders for each fold.
            default_root_dir=log_path,
            max_epochs=config['max_epochs'],
            log_every_n_steps=1,
            enable_model_summary=False,
            enable_checkpointing=False,
        )
        # trainer.fit(lit_model, train_dataloaders=values['train'], val_dataloaders=values['valid'])  # The validation loop is failing.
        trainer.fit(lit_model, train_dataloaders=train)
        trainer.test(lit_model, dataloaders=test, verbose=True)

    # Train the final model.
    models, optimizers = create_models_and_optimizers(config)
    lit_model = LitFullModel(models, optimizers, config)
    trainer = pl.Trainer(
        default_root_dir=log_path,
        max_epochs=config['max_epochs'],
        enable_progress_bar=False,
        log_every_n_steps=1,
        logger=False,
    )
    trainer.fit(lit_model, train_dataloaders=train)

    # Test the final model with bootstrapping.
    bootstrap_results = []
    for _ in tqdm(range(config['bootstrap_repeats']), desc='Bootstrapping'):
        # trainer.test() returns one metrics dict per dataloader; keep the first.
        bootstrap_results.append(trainer.test(lit_model, dataloaders=test, verbose=False)[0])
    bootstrap_results = pd.DataFrame.from_records(bootstrap_results)
    for key, value in bootstrap_results.describe().loc[['mean', 'std']].to_dict().items():
        logger.info(f'| {key.ljust(10).upper()} | {value["mean"]:.5f} ± {value["std"]:.5f} |')
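
    # Worked example with made-up numbers: if bootstrap_results held
    # [{'test_acc': 0.80}, {'test_acc': 0.84}], then describe().loc[['mean', 'std']]
    # .to_dict() gives {'test_acc': {'mean': 0.82, 'std': 0.0283}}, logged as:
    # | TEST_ACC   | 0.82000 ± 0.02828 |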


def create_models_and_optimizers(config: dict):
    models: dict[str, torch.nn.Module] = {}
    optimizers: dict[str, torch.optim.Optimizer] = {}

    # Set up models. getattr() is deliberately avoided here for better IDE support.
    for model_name, kwargs in config['models'].items():
        if model_name == 'Graph_And_Clinical_Feature_Extractor':
            models['feat_ext'] = Graph_And_Clinical_Feature_Extractor(**kwargs)
        elif model_name == 'Feature_Extractor':
            models['feat_ext'] = Feature_Extractor(**kwargs)
        elif model_name == 'Task_Classifier':
            models['clf'] = Task_Classifier(**kwargs)
        elif model_name == 'Classifier':
            models['clf'] = Classifier(**kwargs)
        else:
            raise ValueError(f'Unknown model type: {model_name}')

    # Set up optimizers. If the key is 'all', one optimizer is applied to the
    # parameters of every model; otherwise the key selects a single model by name.
    for key, optim_dict in config['optimizers'].items():
        opt_name = next(iter(optim_dict))
        if key == 'all':
            params = [param for model in models.values() for param in model.parameters()]
            optimizers[key] = getattr(torch.optim, opt_name)(params, **optim_dict[opt_name])
        else:
            optimizers[key] = getattr(torch.optim, opt_name)(models[key].parameters(), **optim_dict[opt_name])
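
    # For illustration, config['optimizers'] is expected to look roughly like one
    # of the following; optimizer names must exist in torch.optim, and the
    # hyperparameter values here are placeholders:
    #
    #     optimizers:
    #       all:
    #         Adam: {lr: 1.0e-4}
    #
    # or, one optimizer per model key:
    #
    #     optimizers:
    #       feat_ext:
    #         Adam: {lr: 1.0e-4}
    #       clf:
    #         SGD: {lr: 1.0e-3, momentum: 0.9}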

    # Add the models' structure to the config for logging. TODO: Prettify.
    for model_name, torch_model in models.items():
        config[f'model.{model_name}'] = str(torch_model)
    return models, optimizers


if __name__ == '__main__':
    main()
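
# Typical invocation (the config path is illustrative):
#     python light_external_train_test.py --config Configs/external_test.yaml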