
Example - Train a 3D Classifier on the Stanford MR-NET Dataset

An end-to-end example of how to train a fully 3D ResNet on MRI images with faimed3d.

from faimed3d.all import *
from torchvision.models.video import r3d_18
from fastai.distributed import *
from fastai.callback.all import SaveModelCallback

The Stanford MR-NET dataset is a collection of 1,370 knee MRI examinations performed at Stanford University Medical Center. Each examination consists of three different sequences. In this example the sagittal TIRM is used, as it is the best sequence to detect bone edema, a sensitive sign of a global abnormality.
The dataset can be downloaded at https://stanfordmlgroup.github.io/competitions/mrnet/. Please check out their paper for more information on the dataset and the MR-NET.

NBS_DIR = Path(os.getcwd())
DATA_DIR = NBS_DIR.parent/'data'
MODEL_DIR = NBS_DIR.parent/'models'
d = pd.read_csv(DATA_DIR/'train_knee.csv') # data not available in this repository.
d['file_names'] = [fn.replace('coronal', 'sagittal') for fn in d.file_names]
d.columns
Index(['file_names', 'abnormal', 'is_valid'], dtype='object')

In the dataframe, the first column is the absolute path to the image file and the second column is the label. There is also a strong imbalance towards pathological images, so the normal images will be oversampled. To avoid the same images ending up in both the training and validation dataset due to oversampling, an index column is used to define the validation dataset.

d_oversampled = pd.concat((d, d[d.abnormal == 0], d[d.abnormal == 0])) # append the normal cases twice more, so each appears three times
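As a quick sanity check, the class balance before and after oversampling can be compared (a minimal sketch using the dataframes defined above):

```python
# compare the label distribution before and after oversampling
print(d.abnormal.value_counts())              # original, skewed towards abnormal
print(d_oversampled.abnormal.value_counts())  # normal cases now counted three times
```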

Baseline at 112 px

A progressive resizing approach will be used to classify the sagittal images. For this, the images are first resized to 20 x 112 x 112 px, which allows the use of a large batch size. The training is performed in parallel on two Nvidia RTX 2080 Ti GPUs with 11 GB of VRAM each, but it would also be possible on a single GPU.

dls = ImageDataLoaders3D.from_df(d_oversampled, path = '/',
                                splitter = ColSplitter('is_valid'),
                                item_tfms = ResizeCrop3D(crop_by = (0, 0, 0), # don't crop the images
                                                         resize_to = (20, 112, 112)), 
                                batch_tfms = aug_transforms_3d(p_all = 0.2), # each augmentation is applied with p = 0.2
                                bs = 64, val_bs = 64)
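Before training, it can help to draw a single batch and confirm the tensor dimensions (a sketch; the exact channel dimension depends on how faimed3d preprocesses the volumes):

```python
# fetch one batch from the training dataloader and inspect its shape
xb, yb = dls.one_batch()
print(xb.shape)  # expected roughly: torch.Size([64, 1, 20, 112, 112])
```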

In a more serious approach, one would utilize additional callbacks such as gradient clipping, maybe use label smoothing as the loss function, and clearly define the train and validation set with an index, to not accidentally mix up the two datasets. This is not done in this notebook to keep it simple.
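Such a setup could look roughly like the following sketch (not used in this notebook; GradientClip and LabelSmoothingCrossEntropyFlat ship with recent fastai versions):

```python
from fastai.callback.training import GradientClip
from fastai.losses import LabelSmoothingCrossEntropyFlat

# a more defensive configuration: smoothed loss and clipped gradients
learn = cnn_learner_3d(dls, r3d_18,
                       loss_func = LabelSmoothingCrossEntropyFlat(),
                       metrics = [accuracy, RocAucBinary()],
                       cbs = [SaveModelCallback(monitor='accuracy'),
                              GradientClip(max_norm=1.0)],
                       model_dir = MODEL_DIR)
```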

learn = cnn_learner_3d(dls, r3d_18, metrics = [accuracy, RocAucBinary()], 
                       cbs = SaveModelCallback(monitor='accuracy'), model_dir = MODEL_DIR)
learn.to_fp16()
learn = learn.to_parallel()
learn.fine_tune(30, freeze_epochs = 1, wd = 1e-4)
| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
| --- | --- | --- | --- | --- | --- |
| 0 | 1.344143 | 7.139972 | 0.395954 | 0.543551 | 00:30 |

Better model found at epoch 0 with accuracy value: 0.3959537446498871.
| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
| --- | --- | --- | --- | --- | --- |
| 0 | 1.099595 | 0.794597 | 0.647399 | 0.701516 | 00:32 |
| 1 | 0.950920 | 0.837331 | 0.679191 | 0.798607 | 00:31 |
| 2 | 0.930535 | 0.839099 | 0.684971 | 0.809049 | 00:32 |
| 3 | 0.877060 | 0.566988 | 0.754335 | 0.842821 | 00:32 |
| 4 | 0.834508 | 0.479512 | 0.777457 | 0.860039 | 00:31 |
| 5 | 0.779111 | 0.725158 | 0.702312 | 0.883037 | 00:32 |
| 6 | 0.758589 | 0.691208 | 0.734104 | 0.842699 | 00:33 |
| 7 | 0.765461 | 0.956451 | 0.644509 | 0.810795 | 00:33 |
| 8 | 0.734858 | 0.632678 | 0.705202 | 0.857856 | 00:32 |
| 9 | 0.687992 | 0.452573 | 0.791907 | 0.880453 | 00:32 |
| 10 | 0.612130 | 0.487902 | 0.794798 | 0.895697 | 00:33 |
| 11 | 0.565593 | 0.430797 | 0.835260 | 0.920756 | 00:33 |
| 12 | 0.529080 | 0.477395 | 0.806358 | 0.905773 | 00:32 |
| 13 | 0.519458 | 0.448552 | 0.780347 | 0.880662 | 00:33 |
| 14 | 0.513008 | 0.530449 | 0.780347 | 0.907607 | 00:32 |
| 15 | 0.500788 | 0.404986 | 0.841040 | 0.916285 | 00:33 |
| 16 | 0.481881 | 0.457320 | 0.841040 | 0.924755 | 00:33 |
| 17 | 0.482219 | 0.578239 | 0.725434 | 0.860161 | 00:33 |
| 18 | 0.458752 | 0.555119 | 0.765896 | 0.903590 | 00:33 |
| 19 | 0.442184 | 0.473160 | 0.797688 | 0.902909 | 00:34 |
| 20 | 0.427631 | 0.473112 | 0.794798 | 0.924772 | 00:34 |
| 21 | 0.436852 | 0.398998 | 0.843931 | 0.932246 | 00:32 |
| 22 | 0.404188 | 0.393856 | 0.843931 | 0.930482 | 00:33 |
| 23 | 0.375834 | 0.380219 | 0.838150 | 0.927566 | 00:32 |
| 24 | 0.380004 | 0.449176 | 0.835260 | 0.919376 | 00:33 |
| 25 | 0.394911 | 0.389391 | 0.852601 | 0.940820 | 00:34 |
| 26 | 0.406509 | 0.488678 | 0.809249 | 0.934446 | 00:33 |
| 27 | 0.411663 | 0.488464 | 0.794798 | 0.922694 | 00:35 |
| 28 | 0.419972 | 0.485111 | 0.817919 | 0.929138 | 00:35 |
| 29 | 0.422603 | 0.511886 | 0.797688 | 0.906227 | 00:34 |
Better model found at epoch 0 with accuracy value: 0.647398829460144.
Better model found at epoch 1 with accuracy value: 0.6791907548904419.
Better model found at epoch 2 with accuracy value: 0.6849710941314697.
Better model found at epoch 3 with accuracy value: 0.7543352842330933.
Better model found at epoch 4 with accuracy value: 0.7774566411972046.
Better model found at epoch 9 with accuracy value: 0.7919074892997742.
Better model found at epoch 10 with accuracy value: 0.7947976589202881.
Better model found at epoch 11 with accuracy value: 0.8352600932121277.
Better model found at epoch 15 with accuracy value: 0.8410404920578003.
Better model found at epoch 21 with accuracy value: 0.8439306616783142.
Better model found at epoch 25 with accuracy value: 0.852601170539856.
learn.recorder.plot_loss()

[plot: training and validation loss]

Resizing

Changes in MRI are often subtle and may disappear with aggressive downsampling of the images. In the next step the image resolution is increased and the model is trained some more.

dls = ImageDataLoaders3D.from_df(d_oversampled, path = '/',
                                splitter = RandomSplitter(seed = 42),
                                item_tfms = ResizeCrop3D(crop_by = (0, 0, 0), # don't crop the images
                                                         resize_to = (20, 224, 224)), 
                                batch_tfms = aug_transforms_3d(p_all = 0.15), 
                                bs = 16, val_bs = 16)
learn = cnn_learner_3d(dls, r3d_18, metrics = [accuracy, RocAucBinary()], 
                       cbs = SaveModelCallback(monitor='accuracy'), model_dir = MODEL_DIR)
learn.to_fp16()
learn = learn.load('model') # load the best weights saved during the 112 px baseline
learn = learn.to_parallel()
learn.fine_tune(10, freeze_epochs = 1, wd = 1e-4)
| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
| --- | --- | --- | --- | --- | --- |
| 0 | 0.620913 | 1.198181 | 0.543353 | 0.806269 | 01:39 |

Better model found at epoch 0 with accuracy value: 0.5433526039123535.
| epoch | train_loss | valid_loss | accuracy | roc_auc_score | time |
| --- | --- | --- | --- | --- | --- |
| 0 | 0.537233 | 0.602706 | 0.734104 | 0.856645 | 02:23 |
| 1 | 0.510992 | 0.588361 | 0.742775 | 0.870363 | 02:21 |
| 2 | 0.482025 | 0.509262 | 0.783237 | 0.893975 | 02:26 |
| 3 | 0.482205 | 0.576721 | 0.754335 | 0.887249 | 02:19 |
| 4 | 0.422891 | 0.550677 | 0.760116 | 0.908524 | 02:21 |
| 5 | 0.439781 | 0.794941 | 0.693642 | 0.904984 | 02:20 |
| 6 | 0.418022 | 0.628537 | 0.768786 | 0.903533 | 02:19 |
| 7 | 0.427365 | 0.394326 | 0.852601 | 0.921871 | 02:18 |
| 8 | 0.419326 | 0.459509 | 0.806358 | 0.916879 | 02:22 |
| 9 | 0.410638 | 0.475910 | 0.791907 | 0.920083 | 02:19 |
Better model found at epoch 0 with accuracy value: 0.7341040372848511.
Better model found at epoch 1 with accuracy value: 0.7427745461463928.
Better model found at epoch 2 with accuracy value: 0.7832369804382324.
Better model found at epoch 7 with accuracy value: 0.852601170539856.
learn.recorder.plot_loss()

[plot: training and validation loss]

Interpretation

For the interpretation, the dataloaders are built again, and this time the original, not oversampled, data is used, as AUC and sensitivity are influenced by the prevalence.

dls = ImageDataLoaders3D.from_df(d, path = '/', 
                                splitter = RandomSplitter(seed = 42),
                                item_tfms = ResizeCrop3D(crop_by = (0, 0, 0), # don't crop the images
                                                         resize_to = (20, 224, 224)), 
                                batch_tfms = aug_transforms_3d(p_all = 0.15), 
                                bs = 16, val_bs = 16)
learn = cnn_learner_3d(dls, r3d_18, metrics = [accuracy, RocAucBinary()], cbs = SaveModelCallback(), model_dir = MODEL_DIR)
learn.to_fp16()
learn = learn.load('model')
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

[plot: confusion matrix]
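The interpretation object can also display the validation cases with the highest loss, which often reveals systematic errors; whether the 3D volumes render nicely depends on faimed3d's show support (a sketch):

```python
# inspect the most confidently wrong predictions
interp.plot_top_losses(4)
```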

interp.print_classification_report()
| | precision | recall | f1-score | support |
| --- | --- | --- | --- | --- |
| 0 | 0.72 | 0.96 | 0.82 | 46 |
| 1 | 0.99 | 0.92 | 0.95 | 204 |
| accuracy | | | 0.92 | 250 |
| macro avg | 0.86 | 0.94 | 0.89 | 250 |
| weighted avg | 0.94 | 0.92 | 0.93 | 250 |
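The same numbers can be derived directly from the raw validation predictions, for example with scikit-learn (a sketch, assuming scikit-learn is installed):

```python
from sklearn.metrics import classification_report

# predictions and targets for the validation set
preds, targs = learn.get_preds()
print(classification_report(targs, preds.argmax(dim=1)))
```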

These results are pretty close to the published state of the art, with an accuracy of 0.92 vs. 0.85 for the Stanford group and a recall of 0.92 for detecting global abnormalities (vs. 0.879 for the Stanford group, see Table 2 of their publication). However, in this example only one of the three sequences was used, and the metrics were calculated on the validation dataset, not on the hidden test dataset of the Stanford MR-NET challenge.