-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
215 lines (181 loc) · 9.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 17 11:52:02 2022
@author: AmayaGS
"""
import os
import os.path
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from collections import Counter
import pickle
import argparse
from PIL import Image
from PIL import ImageFile
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
# MUSTANG functions
from loaders import Loaders
from embedding_net import VGG_embedding
from create_store_graphs import create_embeddings_graphs
from graph_train_loop import train_graph_multi_wsi, test_graph_multi_wsi
from auxiliary_functions import seed_everything
from Graph_model import GAT_SAGPool
from plotting_results import plot_train_results, plot_test_results
def _parse_args():
    """Parse the command-line options that configure a run.

    Returns:
        argparse.Namespace: the parsed options; anything not supplied on the
        command line takes the default declared below.
    """
    parser = argparse.ArgumentParser(description="Multi-stain self-attention graph multiple instance learning for Whole Slide Image set classification at the patient level")

    parser.add_argument("--dataset_name", type=str, default="RA", help="Dataset name")
    parser.add_argument("--PATH_patches", type=str, default="df_labels.csv", help="CSV file with patch file location")
    parser.add_argument("--embedding_vector_size", type=int, default=1024, help="Embedding vector size")
    parser.add_argument("--learning_rate", type=float, default=0.0001, help="Learning rate")
    parser.add_argument("--pooling_ratio", type=float, default=0.7, help="Pooling ratio")
    parser.add_argument("--heads", type=int, default=4, help="Number of GAT heads")
    parser.add_argument("--K", type=int, default=5, help="Number of nearest neighbours in k-NNG created from WSI embeddings")
    parser.add_argument("--train_fraction", type=float, default=0.7, help="Train fraction")
    # TODO: expose --slide_batch once the slide Dataset no longer needs a batch > 1 (see main()).
    parser.add_argument("--num_epochs", type=int, default=60, help="Number of training epochs")
    parser.add_argument("--n_classes", type=int, default=2, help="Number of classes")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading")
    parser.add_argument("--graph_batch_size", type=int, default=1, help="Graph batch size for training")
    # action="store_false": the value defaults to True and *passing* the flag
    # switches checkpointing off, so the help text has to say exactly that.
    parser.add_argument("--checkpoint", action="store_false", default=True, help="Checkpointing of GNN weights is enabled by default. Pass this flag to disable it.")

    return parser.parse_args()


def _collate_fn_none(batch):
    """Collate a batch after dropping the entries that are ``None``.

    Kept at module level rather than nested in ``main`` so that it can be
    pickled when DataLoader workers are started with the ``spawn`` method
    (``num_workers > 0`` on Windows/macOS).
    """
    kept = [item for item in batch if item is not None]
    return torch.utils.data.dataloader.default_collate(kept)


def _dump_pickle(obj, path):
    """Serialise ``obj`` to ``path`` with pickle."""
    with open(path, "wb") as file:
        pickle.dump(obj, file)


def _load_pickle(path):
    """Load and return the object pickled at ``path``.

    SECURITY: unpickling executes arbitrary code. Only point this at files
    produced by this pipeline, never at files from an untrusted source.
    """
    with open(path, "rb") as file:
        return pickle.load(file)


def _build_balanced_sampler(graph_dict):
    """Create a sampler that oversamples the minority classes.

    Every sample is weighted by ``1 / (number of samples of its class)``.
    The weight is looked up by the label value itself; it must not be looked
    up by position in ``Counter.values()``, whose order is first-seen order
    and does not have to match the class indices.

    Args:
        graph_dict: mapping whose values carry the label at index 1 as an
            object exposing ``.item()`` (assumed to be a one-element tensor).

    Returns:
        torch.utils.data.WeightedRandomSampler drawing ``len(graph_dict)``
        indices with replacement.

    Raises:
        ValueError: if ``graph_dict`` is empty.
    """
    labels = [entry[1].item() for entry in graph_dict.values()]
    if not labels:
        raise ValueError("Cannot build a weighted sampler: the training graph dictionary is empty.")

    per_class = Counter(labels)
    sample_weights = np.array([1.0 / per_class[lbl] for lbl in labels], dtype=np.float64)

    return torch.utils.data.sampler.WeightedRandomSampler(
        torch.from_numpy(sample_weights).double(),
        num_samples=len(sample_weights),
        replacement=True,
    )


def main():
    """Run the multi-stain graph pipeline end to end.

    Steps, in order: parse the command line, optionally build and pickle the
    k-NN graphs from VGG patch embeddings, load the pickled graphs, train the
    GAT/SAGPool network with class-balanced sampling, then evaluate it and
    write weights, metrics and plots into ``results/<run name>/`` next to
    this file.
    """
    args = _parse_args()

    # Set environment variables
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    # Check for GPU availability
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        print("Using CUDA")

    # Set image properties
    Image.MAX_IMAGE_PIXELS = None
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    plt.ion()

    # Image transforms  # TODO
    train_transform = transforms.Compose([
        transforms.RandomChoice([
            transforms.ColorJitter(brightness=0.1),
            transforms.ColorJitter(contrast=0.1),
            transforms.ColorJitter(saturation=0.1),
            transforms.ColorJitter(hue=0.1)]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Parameters
    seed = args.seed
    seed_everything(seed)
    train_fraction = args.train_fraction
    subset = False  # TODO
    # TODO: change Dataset to an Iterable dataset. Until then this needs to be
    # larger than one, otherwise the DataLoader can fail when the collate
    # function is only handed a None object.
    slide_batch = 10
    K = args.K
    num_workers = args.num_workers
    batch_size = args.graph_batch_size
    creating_knng = False  # TODO
    creating_embedding = False  # TODO
    train_graph = True  # TODO (not read anywhere yet)
    embedding_vector_size = args.embedding_vector_size
    learning_rate = args.learning_rate
    pooling_ratio = args.pooling_ratio
    heads = args.heads
    num_epochs = args.num_epochs
    TRAIN = True
    TEST = True
    label = 'label'
    patient_id = 'Patient ID'
    dataset_name = args.dataset_name
    n_classes = args.n_classes
    checkpoint = args.checkpoint

    # Output locations. Everything is built with os.path.join so the run
    # writes into the results folder on every platform, not only on Windows.
    current_directory = Path(__file__).resolve().parent
    run_results_folder = f"graph_{dataset_name}_{seed}_{heads}_{pooling_ratio}_{learning_rate}"
    results = os.path.join(current_directory, "results", run_results_folder)
    checkpoints = os.path.join(results, "checkpoints")
    os.makedirs(results, exist_ok=True)
    os.makedirs(checkpoints, exist_ok=True)
    weights_path = os.path.join(results, run_results_folder + ".pth")
    metrics_path = os.path.join(results, run_results_folder + ".csv")
    plots_prefix = os.path.join(results, "")  # results folder with a trailing separator

    # Load the dataset
    df = pd.read_csv(args.PATH_patches, header=0)
    df = df.dropna(subset=[label])

    # Create k-NNG with VGG patch embeddings
    if creating_knng:
        file_ids, train_ids, test_ids = Loaders().train_test_ids(df, train_fraction, seed, patient_id, label, subset)
        train_subset, test_subset = Loaders().df_loader(df, train_transform, test_transform, train_ids, test_ids, patient_id, label, subset=subset)
        train_slides, test_slides = Loaders().slides_dataloader(train_subset, test_subset, train_ids, test_ids, train_transform, test_transform, slide_batch=slide_batch, num_workers=num_workers, shuffle=False, collate=_collate_fn_none, label=label, patient_id=patient_id)

        embedding_net = VGG_embedding(embedding_vector_size=embedding_vector_size, n_classes=n_classes)
        if use_gpu:
            embedding_net.cuda()

        # Save k-NNG with VGG patch embeddings for future use
        slides_dict = {('train_graph_dict_', 'train_embedding_dict_'): train_slides,
                       ('test_graph_dict_', 'test_embedding_dict_'): test_slides}

        for file_prefix, slides in slides_dict.items():
            graph_dict, embedding_dict = create_embeddings_graphs(embedding_net, slides, k=K, mode='connectivity', include_self=False)

            print(f"Started saving {file_prefix[0]} to file")
            _dump_pickle(graph_dict, f"{file_prefix[0]}{dataset_name}.pkl")
            print("Done writing graph dict into pickle file")

            print(f"Started saving {file_prefix[1]} to file")
            _dump_pickle(embedding_dict, f"{file_prefix[1]}{dataset_name}.pkl")
            print("Done writing embedding dict into pickle file")

    # Load the pickled graphs. This is done unconditionally: when the graphs
    # were just created above they are read back from the files that were
    # written, so the training code below always has them defined.
    train_graph_dict = _load_pickle(f"train_graph_dict_{dataset_name}.pkl")
    test_graph_dict = _load_pickle(f"test_graph_dict_{dataset_name}.pkl")

    # TODO
    # NOTE(review): the embedding dictionaries are loaded but not used further
    # down in this function - confirm whether this load is still needed.
    if not creating_embedding:
        train_embedding_dict = _load_pickle(f"train_embedding_dict_{dataset_name}.pkl")
        test_embedding_dict = _load_pickle(f"test_embedding_dict_{dataset_name}.pkl")

    # Weights for minority oversampling
    sampler = _build_balanced_sampler(train_graph_dict)

    # MULTI-STAIN GRAPH
    # TODO: switch to torch_geometric.loader.DataLoader to support mini-batching.
    train_graph_loader = torch.utils.data.DataLoader(train_graph_dict, batch_size=batch_size, shuffle=False, num_workers=num_workers, sampler=sampler, drop_last=False)
    test_graph_loader = torch.utils.data.DataLoader(test_graph_dict, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

    graph_net = GAT_SAGPool(embedding_vector_size, heads=heads, pooling_ratio=pooling_ratio)
    loss_fn = nn.CrossEntropyLoss()
    optimizer_ft = optim.Adam(graph_net.parameters(), lr=learning_rate)
    if use_gpu:
        graph_net.cuda()

    if TRAIN:
        graph_weights, results_dict = train_graph_multi_wsi(graph_net, train_graph_loader, test_graph_loader, loss_fn, optimizer_ft, n_classes=n_classes, num_epochs=num_epochs, checkpoint=checkpoint, checkpoint_path=os.path.join(checkpoints, "checkpoint_"))
        torch.save(graph_weights.state_dict(), weights_path)

        df_results = pd.DataFrame.from_dict(results_dict)
        df_results.to_csv(metrics_path, index=False)

        plot = plot_train_results(df_results, save=plots_prefix)
        plot.plot()

    if TEST:
        # Evaluate the weights stored on disk, i.e. the ones saved by the
        # training step above or by an earlier run with the same settings.
        graph_net.load_state_dict(torch.load(weights_path), strict=True)
        labels, prob, conf_matrix, sensitivity, specificity = test_graph_multi_wsi(graph_net, test_graph_loader, loss_fn, n_classes=n_classes)

        # NOTE(review): the class names are hard-coded for two classes while
        # --n_classes is configurable - confirm for other datasets.
        plot = plot_test_results(labels, prob, conf_matrix, target_names=["Fibroid", "M/Lymphoid"], save=plots_prefix)
        plot.plot()
# %%
# Script entry point: start the pipeline only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()