End-to-end example of training a fully 3D ResNet on MRI images with faimed3d
from faimed3d.all import *
from torchvision.models.video import r3d_18
from fastai.distributed import *
from fastai.callback.all import SaveModelCallback
The Stanford MRNet dataset is a collection of 1,370 knee MRI examinations performed at Stanford University Medical Center. Each examination consists of three different sequences. In this example the sagittal TIRM sequence is used, as it is the best sequence to detect bone edema, a sensitive sign of a global abnormality.
The dataset can be downloaded at https://stanfordmlgroup.github.io/competitions/mrnet/. Please check out their paper for more information on the dataset and the MRNet model.
NBS_DIR = Path(os.getcwd())
DATA_DIR = NBS_DIR.parent/'data'
MODEL_DIR = NBS_DIR.parent/'models'
d = pd.read_csv(DATA_DIR/'train_knee.csv') # data not available in this repository.
d['file_names'] = [fn.replace('coronal', 'sagittal') for fn in d.file_names]
d.columns
Index(['file_names', 'abnormal', 'is_valid'], dtype='object')
In the dataframe, the first column is the absolute path to the image file and the second column is the label. There is also a strong imbalance towards pathological images, so the normal images will be oversampled. To avoid ending up with the same image in both the train and the validation dataset due to oversampling, an index column is used to define the validation split.
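Since the csv file is not part of this repository, the following is a minimal sketch of the expected layout; the file paths are hypothetical and only illustrate the structure:

d = pd.DataFrame({
    'file_names': ['/path/to/sagittal/0001', '/path/to/sagittal/0002'],  # absolute paths to the images
    'abnormal':   [1, 0],                                                # binary label
    'is_valid':   [False, True]})                                        # marks the validation split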
d_oversampled = pd.concat((d, d[d.abnormal == 0], d[d.abnormal == 0])) # triple the normal cases to reduce the class imbalance
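To verify the effect of the oversampling, one can compare the label distribution before and after, e.g. with plain pandas:

d.abnormal.value_counts(normalize = True)             # strongly skewed towards abnormal
d_oversampled.abnormal.value_counts(normalize = True) # considerably more balanced after tripling the normal cases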
A progressive resizing approach is used to classify the sagittal images. For this, the images are first resized to 20 x 112 x 112 px, which allows the use of a large batch size. The training is performed in parallel on two Nvidia RTX 2080 Ti GPUs with 11 GB of VRAM each, but would also be possible on a single GPU.
dls = ImageDataLoaders3D.from_df(d_oversampled, path = '/',
splitter = ColSplitter('is_valid'),
item_tfms = ResizeCrop3D(crop_by = (0, 0, 0), # don't crop the images
resize_to = (20, 112, 112)),
batch_tfms = aug_transforms_3d(p_all = 0.2), # all tfms with p = 0.2 for each tfms to get called
bs = 64, val_bs = 64)
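A quick sanity check on the batch geometry can catch resizing mistakes early. A minimal sketch; the channel count C is an assumption and depends on how faimed3d preprocesses the single-channel MRI data:

xb, yb = dls.one_batch()
xb.shape, yb.shape # expected: ([64, C, 20, 112, 112], [64])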
In a more serious approach, one would utilize callbacks such as SaveModelCallback and GradientClip, perhaps use LabelSmoothingCrossEntropy as the loss function, and clearly define the train and validation sets with an index to not accidentally mix up the two datasets. This is not done in this notebook to keep it simple.
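For reference, such a setup could look like the sketch below. GradientClip and LabelSmoothingCrossEntropy are fastai classes; the max_norm value is only an illustrative choice:

learn = cnn_learner_3d(dls, r3d_18,
                       loss_func = LabelSmoothingCrossEntropy(),
                       metrics = [accuracy, RocAucBinary()],
                       cbs = [SaveModelCallback(monitor = 'accuracy'),
                              GradientClip(max_norm = 1.)],
                       model_dir = MODEL_DIR)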
learn = cnn_learner_3d(dls, r3d_18, metrics = [accuracy, RocAucBinary()],
cbs = SaveModelCallback(monitor='accuracy'), model_dir = MODEL_DIR)
learn.to_fp16() # train in mixed precision to save VRAM
learn = learn.to_parallel() # distribute training over both GPUs
learn.fine_tune(30, freeze_epochs = 1, wd = 1e-4)
epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
---|---|---|---|---|---|
0 | 1.344143 | 7.139972 | 0.395954 | 0.543551 | 00:30 |
Better model found at epoch 0 with accuracy value: 0.3959537446498871.
epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
---|---|---|---|---|---|
0 | 1.099595 | 0.794597 | 0.647399 | 0.701516 | 00:32 |
1 | 0.950920 | 0.837331 | 0.679191 | 0.798607 | 00:31 |
2 | 0.930535 | 0.839099 | 0.684971 | 0.809049 | 00:32 |
3 | 0.877060 | 0.566988 | 0.754335 | 0.842821 | 00:32 |
4 | 0.834508 | 0.479512 | 0.777457 | 0.860039 | 00:31 |
5 | 0.779111 | 0.725158 | 0.702312 | 0.883037 | 00:32 |
6 | 0.758589 | 0.691208 | 0.734104 | 0.842699 | 00:33 |
7 | 0.765461 | 0.956451 | 0.644509 | 0.810795 | 00:33 |
8 | 0.734858 | 0.632678 | 0.705202 | 0.857856 | 00:32 |
9 | 0.687992 | 0.452573 | 0.791907 | 0.880453 | 00:32 |
10 | 0.612130 | 0.487902 | 0.794798 | 0.895697 | 00:33 |
11 | 0.565593 | 0.430797 | 0.835260 | 0.920756 | 00:33 |
12 | 0.529080 | 0.477395 | 0.806358 | 0.905773 | 00:32 |
13 | 0.519458 | 0.448552 | 0.780347 | 0.880662 | 00:33 |
14 | 0.513008 | 0.530449 | 0.780347 | 0.907607 | 00:32 |
15 | 0.500788 | 0.404986 | 0.841040 | 0.916285 | 00:33 |
16 | 0.481881 | 0.457320 | 0.841040 | 0.924755 | 00:33 |
17 | 0.482219 | 0.578239 | 0.725434 | 0.860161 | 00:33 |
18 | 0.458752 | 0.555119 | 0.765896 | 0.903590 | 00:33 |
19 | 0.442184 | 0.473160 | 0.797688 | 0.902909 | 00:34 |
20 | 0.427631 | 0.473112 | 0.794798 | 0.924772 | 00:34 |
21 | 0.436852 | 0.398998 | 0.843931 | 0.932246 | 00:32 |
22 | 0.404188 | 0.393856 | 0.843931 | 0.930482 | 00:33 |
23 | 0.375834 | 0.380219 | 0.838150 | 0.927566 | 00:32 |
24 | 0.380004 | 0.449176 | 0.835260 | 0.919376 | 00:33 |
25 | 0.394911 | 0.389391 | 0.852601 | 0.940820 | 00:34 |
26 | 0.406509 | 0.488678 | 0.809249 | 0.934446 | 00:33 |
27 | 0.411663 | 0.488464 | 0.794798 | 0.922694 | 00:35 |
28 | 0.419972 | 0.485111 | 0.817919 | 0.929138 | 00:35 |
29 | 0.422603 | 0.511886 | 0.797688 | 0.906227 | 00:34 |
Better model found at epoch 0 with accuracy value: 0.647398829460144.
Better model found at epoch 1 with accuracy value: 0.6791907548904419.
Better model found at epoch 2 with accuracy value: 0.6849710941314697.
Better model found at epoch 3 with accuracy value: 0.7543352842330933.
Better model found at epoch 4 with accuracy value: 0.7774566411972046.
Better model found at epoch 9 with accuracy value: 0.7919074892997742.
Better model found at epoch 10 with accuracy value: 0.7947976589202881.
Better model found at epoch 11 with accuracy value: 0.8352600932121277.
Better model found at epoch 15 with accuracy value: 0.8410404920578003.
Better model found at epoch 21 with accuracy value: 0.8439306616783142.
Better model found at epoch 25 with accuracy value: 0.852601170539856.
learn.recorder.plot_loss()
Changes in MRI are often subtle and may disappear with aggressive downsampling of the images. In the next step the image resolution is increased and the model is trained some more.
dls = ImageDataLoaders3D.from_df(d_oversampled, path = '/',
splitter = RandomSplitter(seed = 42),
item_tfms = ResizeCrop3D(crop_by = (0, 0, 0), # don't crop the images
resize_to = (20, 224, 224)),
batch_tfms = aug_transforms_3d(p_all = 0.15),
bs = 16, val_bs = 16)
learn = cnn_learner_3d(dls, r3d_18, metrics = [accuracy, RocAucBinary()],
cbs = SaveModelCallback(monitor='accuracy'), model_dir = MODEL_DIR)
learn.to_fp16()
learn = learn.load('model') # continue from the weights of the low-resolution stage
learn = learn.to_parallel()
learn.fine_tune(10, freeze_epochs = 1, wd = 1e-4)
epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
---|---|---|---|---|---|
0 | 0.620913 | 1.198181 | 0.543353 | 0.806269 | 01:39 |
Better model found at epoch 0 with accuracy value: 0.5433526039123535.
epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
---|---|---|---|---|---|
0 | 0.537233 | 0.602706 | 0.734104 | 0.856645 | 02:23 |
1 | 0.510992 | 0.588361 | 0.742775 | 0.870363 | 02:21 |
2 | 0.482025 | 0.509262 | 0.783237 | 0.893975 | 02:26 |
3 | 0.482205 | 0.576721 | 0.754335 | 0.887249 | 02:19 |
4 | 0.422891 | 0.550677 | 0.760116 | 0.908524 | 02:21 |
5 | 0.439781 | 0.794941 | 0.693642 | 0.904984 | 02:20 |
6 | 0.418022 | 0.628537 | 0.768786 | 0.903533 | 02:19 |
7 | 0.427365 | 0.394326 | 0.852601 | 0.921871 | 02:18 |
8 | 0.419326 | 0.459509 | 0.806358 | 0.916879 | 02:22 |
9 | 0.410638 | 0.475910 | 0.791907 | 0.920083 | 02:19 |
Better model found at epoch 0 with accuracy value: 0.7341040372848511.
Better model found at epoch 1 with accuracy value: 0.7427745461463928.
Better model found at epoch 2 with accuracy value: 0.7832369804382324.
Better model found at epoch 7 with accuracy value: 0.852601170539856.
learn.recorder.plot_loss()
For interpretation, the DataLoaders are built again, this time with the original, non-oversampled data, as metrics such as precision and accuracy are influenced by the class prevalence.
dls = ImageDataLoaders3D.from_df(d, path = '/',
splitter = RandomSplitter(seed = 42),
item_tfms = ResizeCrop3D(crop_by = (0, 0, 0), # don't crop the images
resize_to = (20, 224, 224)),
batch_tfms = aug_transforms_3d(p_all = 0.15),
bs = 16, val_bs = 16)
learn = cnn_learner_3d(dls, r3d_18, metrics = [accuracy, RocAucBinary()], cbs = SaveModelCallback(), model_dir = MODEL_DIR)
learn.to_fp16()
learn = learn.load('model')
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
interp.print_classification_report()
class | precision | recall | f1-score | support |
---|---|---|---|---|
0 | 0.72 | 0.96 | 0.82 | 46 |
1 | 0.99 | 0.92 | 0.95 | 204 |
accuracy | | | 0.92 | 250 |
macro avg | 0.86 | 0.94 | 0.89 | 250 |
weighted avg | 0.94 | 0.92 | 0.93 | 250 |
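Sensitivity and specificity, which are not influenced by the prevalence, can also be derived from the predictions. A minimal sketch, assuming class 1 denotes an abnormal examination:

from sklearn.metrics import confusion_matrix

preds, targs = learn.get_preds()
tn, fp, fn, tp = confusion_matrix(targs, preds.argmax(dim = 1)).ravel()
sensitivity = tp / (tp + fn) # recall of the abnormal class
specificity = tn / (tn + fp) # recall of the normal class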
These results are pretty close to the published state of the art, with an accuracy of 0.92 vs. 0.85 for the Stanford group and a recall of 0.92 for detecting global abnormalities (vs. 0.879 for the Stanford group, see Table 2 of their publication). However, only one of the three sequences was used in this example, and the metrics are calculated on the validation dataset, not on the hidden test dataset of the Stanford MRNet challenge.