-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlight.py
133 lines (110 loc) · 5.69 KB
/
light.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
from datetime import datetime
from pathlib import Path
from warnings import filterwarnings
import lightning.pytorch as pl
import pandas as pd
import torch
import yaml
from tqdm import tqdm
from dataset import TCGA_Program_Dataset
from datasets_manager import TCGA_Balanced_Datasets_Manager, TCGA_Datasets_Manager
from lit_models import LitFullModel
from model import Classifier, Feature_Extractor, Graph_And_Clinical_Feature_Extractor, Task_Classifier
from utils import config_add_subdict_key, get_logger, override_n_genes, set_random_seed, setup_logging
SEED = 1126
set_random_seed(SEED)
def main():
# Select a config file.
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, help='Path to the config file.', required=True)
args = parser.parse_args()
with open(args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
override_n_genes(config) # For multi-task graph models.
config_name = Path(args.config).stem
# Setup logging.
setup_logging(log_path := f'Logs/{config_name}/{datetime.now():%Y-%m-%dT%H:%M:%S}/')
logger = get_logger(config_name)
logger.info(f'Using Random Seed {SEED} for this experiment')
get_logger('lightning.pytorch.accelerators.cuda', log_level='WARNING') # Disable cuda logging.
filterwarnings('ignore', r'.*Skipping val loop.*') # Disable val loop warning.
# Create dataset manager.
#here use torch lightning DS
data = {'TCGA_BLC': TCGA_Program_Dataset(**config['datasets'])}
if 'TCGA_Balanced_Datasets_Manager' == config['datasets_manager']['type']:
manager = TCGA_Balanced_Datasets_Manager(datasets=data, config=config_add_subdict_key(config))
else:
manager = TCGA_Datasets_Manager(datasets=data, config=config_add_subdict_key(config))
valid_results = []
# Cross validation.
for key, values in manager['TCGA_BLC']['dataloaders'].items():
if isinstance(key, int) and config['cross_validation']:
models, optimizers = create_models_and_optimizers(config)
lit_model = LitFullModel(models, optimizers, config)
trainer = pl.Trainer( # Create sub-folders for each fold.
default_root_dir=log_path,
max_epochs=config['max_epochs'],
log_every_n_steps=1,
enable_model_summary=False,
enable_checkpointing=False,
)
trainer.fit(lit_model, train_dataloaders=values['train'], val_dataloaders=values['valid'])
trainer.test(lit_model, dataloaders=values['valid'], verbose=False) #verbose true prints the validation results for each fold
# print validation results
valid_results.append(trainer.test(lit_model, dataloaders=values['valid'], verbose=False)[0])
elif key == 'train':
train = values
elif key == 'test':
test = values
# Print validation results.
valid_results = pd.DataFrame.from_records(valid_results)
for key, value in valid_results.describe().loc[['mean', 'std']].to_dict().items():
logger.info(f'| {key.ljust(10).upper()} | {value["mean"]:.5f} ± {value["std"]:.5f} |')
# Train the final model.
models, optimizers = create_models_and_optimizers(config)
lit_model = LitFullModel(models, optimizers, config)
trainer = pl.Trainer(
default_root_dir=log_path,
max_epochs=config['max_epochs'],
enable_progress_bar=False,
log_every_n_steps=1,
logger=True,
)
trainer.fit(lit_model, train_dataloaders=train)
# Test the final model.
bootstrap_results = []
for _ in tqdm(range(config['bootstrap_repeats']), desc='Bootstrapping'):
bootstrap_results.append(trainer.test(lit_model, dataloaders=test, verbose=False)[0])
bootstrap_results = pd.DataFrame.from_records(bootstrap_results)
for key, value in bootstrap_results.describe().loc[['mean', 'std']].to_dict().items():
logger.info(f'| {key.ljust(10).upper()} | {value["mean"]:.5f} ± {value["std"]:.5f} |')
def create_models_and_optimizers(config: dict):
models: dict[str, torch.nn.Module] = {}
optimizers: dict[str, torch.optim.Optimizer] = {}
# Setup models. Do not use getattr() for better IDE support.
for model_name, kargs in config['models'].items():
if model_name == 'Graph_And_Clinical_Feature_Extractor':
models['feat_ext'] = Graph_And_Clinical_Feature_Extractor(**kargs)
elif model_name == 'Feature_Extractor':
models['feat_ext'] = Feature_Extractor(**kargs)
elif model_name == 'Task_Classifier':
models['clf'] = Task_Classifier(**kargs)
elif model_name == 'Classifier':
models['clf'] = Classifier(**kargs)
else:
raise ValueError(f'Unknown model type: {model_name}')
# Setup optimizers. If the key is 'all', the optimizer will be applied to all models.
for key, optim_dict in config['optimizers'].items():
opt_name = next(iter(optim_dict))
if key == 'all':
params = [param for model in models.values() for param in model.parameters()]
optimizers[key] = getattr(torch.optim, opt_name)(params, **optim_dict[opt_name])
else:
optimizers[key] = getattr(torch.optim, opt_name)(models[key].parameters(), **optim_dict[opt_name])
# Add models' structure to config for logging. TODO: Prettify.
for model_name, torch_model in models.items():
config[f'model.{model_name}'] = str(torch_model)
return models, optimizers
if __name__ == '__main__':
main()