This repository provides the official implementation and pretrained models for BROW. More details will be included in the related paper, which will be released soon.
- BROW is a foundation model for Whole Slide Image (WSI) analysis. It was pretrained on a dataset containing more than 10,000 WSIs without using any labels or annotations.
- The model produces robust, high-quality feature representations for WSIs.
- The features can be directly employed with classifiers on slide-level multi-class subtyping problems. The trained model also performs well on patch-level classification tasks with slight fine-tuning.
Main Requirements
Linux (Tested on Ubuntu 18.04)
Python==3.9.16
Pytorch==1.12.0
torchvision==0.13.0
openslide-python==1.2.0
opencv-python==4.7.0.72
Installation
Training is performed with PyTorch in a Linux environment. It requires the main packages mentioned above as well as a number of other third-party packages. To set up all the required dependencies for training and evaluation, follow the instructions below:
conda (Recommended) - Clone the repository, then create and activate a BROW conda environment using the provided environment definition environment.yaml:
conda env create -f environment.yaml
conda activate BROW
Please refer to Installation guide for more details about installation.
Dataset Links
Data Preparation
The models were trained on a large WSI dataset containing more than 10,000 slides from multiple sources, including about 6,000 slides from The Cancer Genome Atlas Program (TCGA), 1,000 slides from CAMELYON17 and more than 3,000 private slides. For each slide, we used CLAM to segment the tissue and exclude the blank areas, then extracted patches within the segmented regions and saved the patch coordinates in a .npy file. The following example assumes that the whole slide images, in well-known standard formats (.svs, .tiff, etc.), and the coordinate files are stored under a folder named DATA_DIRECTORY:
DATA_DIRECTORY/
SUBDATASET1/
├── slide_1.svs
├── slide_1.npy
├── slide_2.svs
├── slide_2.npy
└── ...
SUBDATASET2/
├── slide_1.tiff
├── slide_1.npy
├── slide_2.tiff
├── slide_2.npy
└── ...
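Before launching training, it can help to confirm that every slide has its coordinate file. The following is a minimal sketch, not part of the repository: the helper name `pair_slides_with_coords` and the `.tif` extension are assumptions added for illustration, and the only convention taken from the layout above is that a slide and its `.npy` file share the same base name.

```python
from pathlib import Path

# Extensions treated as slides. ".svs" and ".tiff" come from the layout above;
# ".tif" is an assumption added for convenience.
SLIDE_EXTENSIONS = {".svs", ".tiff", ".tif"}


def pair_slides_with_coords(data_directory):
    """Return (pairs, missing).

    pairs   -- list of (slide_path, coords_path) where the .npy file exists
    missing -- list of slide paths that have no matching .npy file
    """
    pairs, missing = [], []
    for slide_path in sorted(Path(data_directory).rglob("*")):
        if not slide_path.is_file():
            continue
        if slide_path.suffix.lower() not in SLIDE_EXTENSIONS:
            continue
        # slide_1.svs -> slide_1.npy in the same sub-dataset folder
        coords_path = slide_path.with_suffix(".npy")
        if coords_path.is_file():
            pairs.append((slide_path, coords_path))
        else:
            missing.append(slide_path)
    return pairs, missing
```

Calling `pair_slides_with_coords("DATA_DIRECTORY")` returns the usable pairs and the slides that still need their coordinates extracted.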
This codebase was developed with Python 3.9.16, PyTorch 1.12.0, CUDA 11.7 and torchvision 0.13.0 on NVIDIA A100 GPUs. The training log can be found at Links/Model. The following is a vanilla training example on 1 node with 4 GPUs (4 GPUs in total):
python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=1 \
--node_rank=0 \
--master_addr="xx.xxx.xxx.xxx" \
--master_port=xxxx \
train.py \
--patch_size 16 \
--arch "vit_base" \
--batch_size_per_gpu xxx \
--use_fp16 0 \
--output_dir ./output_dir
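When adapting the command above to a different number of nodes or GPUs, the number of samples per optimizer step changes with it. The sketch below shows that arithmetic. It assumes one process per GPU and no gradient accumulation, and the helper name `global_batch_size` is hypothetical; the per-GPU batch size is left unspecified in the command above, so any value passed in is a placeholder.

```python
def global_batch_size(batch_size_per_gpu, nproc_per_node, nnodes=1):
    """Samples processed per optimizer step across all processes.

    Assumes one process per GPU and no gradient accumulation.
    """
    if min(batch_size_per_gpu, nproc_per_node, nnodes) < 1:
        raise ValueError("all arguments must be positive integers")
    return batch_size_per_gpu * nproc_per_node * nnodes
```

For instance, with a placeholder per-GPU batch size of 64, `global_batch_size(64, nproc_per_node=4, nnodes=1)` gives 256.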
You can use the pre-trained model for various downstream tasks; the weights can be found at Links/Model. The model can be easily initialized with the backbone weights using the genmodel() function in genmodel.py.
Slide-level multi-class subtyping task
For this task, we adopted the multiple instance learning (MIL) framework and tested the models' performance on several datasets, including TCGA-BRCA, TCGA-RCC, TCGA-NSCLC, CAMELYON16, PANDA, etc. Because WSIs are very large, the features for each slide are pre-extracted, and the MIL classifier is then trained on these features following common practice. The extracted feature embeddings, the trained models' weights and the test results are provided:
Dataset | Acc | AUC | Download | | |
---|---|---|---|---|---|
TCGA-BRCA | 0.8897 | 0.9224 | args | weights | embeddings |
TCGA-RCC | 0.9511 | 0.9942 | args | weights | embeddings |
TCGA-NSCLC | 0.8818 | 0.9606 | args | weights | embeddings |
CAMELYON16 | 0.9535 | 0.9756 | args | weights | embeddings |
PANDA | 0.9407 | 0.9802 | args | weights | embeddings |
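To make the MIL step described in this section concrete, the sketch below shows generic attention pooling: each patch feature receives a score, the scores are normalized with a softmax, and the slide-level feature is the weighted sum of patch features. This is an illustration only. It is not the classifier used in this repository or CLAM's gated attention, it uses plain Python lists instead of PyTorch tensors for brevity, and `attention_vector` is a hypothetical stand-in for learned attention parameters.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(patch_features, attention_vector):
    """Collapse a bag of patch features into one slide-level feature.

    patch_features   -- list of equal-length feature vectors, one per patch
    attention_vector -- hypothetical learned vector; a patch's score is its
                        dot product with this vector
    Returns (slide_feature, weights).
    """
    scores = [
        sum(f * w for f, w in zip(feature, attention_vector))
        for feature in patch_features
    ]
    weights = softmax(scores)
    dim = len(patch_features[0])
    slide_feature = [
        sum(weight * feature[i] for weight, feature in zip(weights, patch_features))
        for i in range(dim)
    ]
    return slide_feature, weights
```

The returned weights sum to 1 and indicate how much each patch contributes to the slide-level feature, which a linear classifier can then map to subtype scores.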
python eval.py \
--dataset <name of dataset> \
--data_root_dir <directory to your data> \
--models_exp_code <directory to checkpoints> \
--save_exp_code <directory to save the eval results, it will be under ./eval_results/> \
--labelcsv_dir <path to the label csv file, which can be found under ./dataset_csv> \
--splits_dir <data split folder, which can be found at ./splits> \
--k <cross validation folds number>
Here is an example of evaluation on the TCGA-BRCA dataset. It assumes the feature embeddings are stored at ./BRCA/pt_files, the checkpoints at ./FINAL_CKPT_CLAM/clam_BRCA, and the eval results are saved to ./eval_results/clam_BRCA:
python eval.py \
--dataset BRCA \
--data_root_dir ./BRCA/pt_files \
--models_exp_code ./FINAL_CKPT_CLAM/clam_BRCA \
--save_exp_code clam_BRCA \
--labelcsv_dir ./dataset_csv/BRCA_subtyping2.csv \
--splits_dir ./splits/BRCA_subtyping2 \
--k 10
Here, we provide a complete example using CLAM as the classifier for training and testing on the TCGA-BRCA dataset.
Data Preparation
Download the feature embeddings from the embeddings column of the table in the Slide-level multi-class subtyping task section. Alternatively, download the original WSI data from the Dataset Links and generate the embeddings with the pre-trained models provided at Links/Model:
cd "Slide-level multi-class subtyping task/feature_extract"
python extract_features.py \
--dataset BRCA \
--data_root_path <data_root_path> \
--save_pt_path <path_saving_features> \
--modelpath <path_to_ckpt> \
--file_ext .svs
The following example assumes the embedding files are stored under a folder named FEAT_DIRECTORY.
FEAT_DIRECTORY/
<path_saving_features>/
├── slide_1.pt
├── slide_2.pt
├── slide_3.pt
└── ...
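A quick consistency check between the label file and the embedding folder can catch missing features before training starts. The sketch below is an assumption-laden illustration, not repository code: the helper name `find_missing_features` is hypothetical, and it assumes the label csv has a `slide_id` column whose values match the `.pt` file names (adjust `id_column` if the csv uses a different header).

```python
import csv
from pathlib import Path


def find_missing_features(label_csv, feature_dir, id_column="slide_id"):
    """Return slide ids listed in label_csv with no <slide_id>.pt in feature_dir.

    id_column is an assumed column name; change it to match the actual csv header.
    """
    feature_dir = Path(feature_dir)
    missing = []
    with open(label_csv, newline="") as handle:
        for row in csv.DictReader(handle):
            slide_id = row[id_column]
            if not (feature_dir / f"{slide_id}.pt").is_file():
                missing.append(slide_id)
    return missing
```

An empty list means every slide in the label file has a pre-extracted embedding.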
The arguments used during training can be found in the args column of the table in the Slide-level multi-class subtyping task section.
Then train and test the model with:
cd ..
python train.py \
--dataset BRCA \
--data_root_dir <FEAT_DIRECTORY/path_saving_features> \
--split_dir 'splits/BRCA_subtyping2' \
--exp_info 'args/experiment_task_2_tumor_subtyping_brca.txt' \
--csv_path 'dataset_csv/BRCA_subtyping2.csv' \
--exp_code 'task_2_tumor_subtyping_brca'
python eval.py \
--dataset BRCA \
--data_root_dir <FEAT_DIRECTORY/path_saving_features> \
--models_exp_code './results/task_2_tumor_subtyping_brca_s1' \
--save_exp_code 'task_2_tumor_subtyping_brca' \
--labelcsv_dir 'dataset_csv/BRCA_subtyping2.csv' \
--splits_dir 'splits/BRCA_subtyping2' \
--k 10
- Code for slide-level subtyping tasks was largely adapted from CLAM.
The related paper will be released soon.