This repository is the official PyTorch implementation of Active Learning for Deep Object Detection via Probabilistic Modeling, ICCV 2021.
The proposed method is implemented on top of SSD pytorch.
Our approach relies on mixture density networks to estimate, in a single forward pass of a single model, both localization and classification uncertainties, and leverages them in the scoring function for active learning.
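Our method performs on par with multiple-model methods (e.g., ensembles and MC-Dropout), which require several models or several forward passes per image. Because it needs only a single forward pass of a single model, our method provides the best trade-off between accuracy and computational cost.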
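To make the idea of splitting one mixture prediction into two kinds of uncertainty concrete, here is a minimal sketch for a single box coordinate predicted as a 1-D Gaussian mixture with weights `pi`, means `mu` and variances `var`. It assumes the standard law-of-total-variance decomposition: the expected component variance is treated as aleatoric uncertainty, and the variance of the component means as epistemic uncertainty. The function names and the max-over-coordinates aggregation in `box_uncertainties` are hypothetical illustrations, not the code in this repository.

```python
from typing import List, Sequence, Tuple


def mixture_uncertainties(pi: Sequence[float], mu: Sequence[float],
                          var: Sequence[float]) -> Tuple[float, float]:
    """Split the predictive variance of a 1-D Gaussian mixture.

    aleatoric = sum_k pi_k * var_k              (expected component variance)
    epistemic = sum_k pi_k * (mu_k - mean)^2    (variance of the component means)
    """
    if not (len(pi) == len(mu) == len(var)) or len(pi) == 0:
        raise ValueError("pi, mu and var must be non-empty and of equal length")
    total = sum(pi)
    weights = [p / total for p in pi]  # renormalize in case weights do not sum to 1
    mean = sum(w * m for w, m in zip(weights, mu))
    aleatoric = sum(w * v for w, v in zip(weights, var))
    epistemic = sum(w * (m - mean) ** 2 for w, m in zip(weights, mu))
    return aleatoric, epistemic


def box_uncertainties(
    coords: List[Tuple[Sequence[float], Sequence[float], Sequence[float]]]
) -> Tuple[float, float]:
    """Hypothetical box-level aggregation: take the largest value over the
    (pi, mu, var) mixtures of the box coordinates, separately for each kind."""
    per_coord = [mixture_uncertainties(*c) for c in coords]
    return (max(a for a, _ in per_coord), max(e for _, e in per_coord))


if __name__ == "__main__":
    # Two equally weighted components: means 0 and 2, variances 1 and 3.
    print(mixture_uncertainties([0.5, 0.5], [0.0, 2.0], [1.0, 3.0]))  # (2.0, 1.0)
```

The two terms add up to the total variance of the mixture (here 2.0 + 1.0 = 3.0), so a single forward pass yields both quantities without sampling or ensembling.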
To view the NVIDIA Source Code License for this work, visit https://github.com/NVlabs/AL-MDN/blob/main/LICENSE
For setup and data preparation, please refer to the README in SSD pytorch.
The code was tested in a virtual environment with Python 3+ and PyTorch 1.1.
- Make a directory with `mkdir weights` and `cd weights`.
- Download the FC-reduced VGG-16 backbone weight into the `weights` directory, and `cd ..`.
- If necessary, change `VOC_ROOT` in `data/voc0712.py` or `COCO_ROOT` in `data/coco.py`.
- Please refer to `data/config.py` for configuration.
- Run the training code:
# Supervised learning
CUDA_VISIBLE_DEVICES=<GPU_ID> python train_ssd_gmm_supervised_learning.py
# Active learning
CUDA_VISIBLE_DEVICES=<GPU_ID> python train_ssd_gmm_active_learining.py
- To evaluate on MS-COCO, change `COCO_ROOT_EVAL` in `data/coco_eval.py`.
- Run the evaluation code:
# Evaluation on PASCAL VOC
python eval_voc.py --trained_model <trained weight path>
# Evaluation on MS-COCO
python eval_coco.py --trained_model <trained weight path>
- Run the visualization code:
python demo.py --trained_model <trained weight path>
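Pool-based active learning typically alternates between training a detector, scoring the unlabeled images, and moving the highest-scoring images into the labeled set. The sketch below illustrates only that selection step under stated assumptions: each unlabeled image already has a list of per-detection uncertainty scores, and an image is scored by its most uncertain detection. The names `image_score`, `select_for_labeling` and `run_selection_cycle`, and the max aggregation, are hypothetical and are not taken from `train_ssd_gmm_active_learining.py`.

```python
from typing import Dict, List, Set


def image_score(detection_scores: List[float]) -> float:
    # Assumption: an image is as uncertain as its most uncertain detection.
    # Images without detections get the lowest possible score here.
    return max(detection_scores, default=0.0)


def select_for_labeling(unlabeled: Dict[str, List[float]], budget: int) -> List[str]:
    # Rank unlabeled images by score, highest first; ties keep insertion order
    # because Python's sort is stable.
    ranked = sorted(unlabeled, key=lambda img: image_score(unlabeled[img]), reverse=True)
    return ranked[:budget]


def run_selection_cycle(labeled: Set[str], unlabeled: Dict[str, List[float]],
                        budget: int) -> List[str]:
    # Move the selected images from the unlabeled pool to the labeled set
    # (in practice, this is where human annotation would happen).
    picked = select_for_labeling(unlabeled, budget)
    for img in picked:
        labeled.add(img)
        del unlabeled[img]
    return picked


if __name__ == "__main__":
    pool = {"a": [0.1, 0.4], "b": [0.9], "c": [], "d": [0.2, 0.3]}
    labeled: Set[str] = set()
    print(run_selection_cycle(labeled, pool, budget=2))  # ['b', 'a']
```

After each cycle, the detector would be retrained on the enlarged labeled set before the remaining pool is scored again.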
@InProceedings{Choi_2021_ICCV,
author = {Choi, Jiwoong and Elezi, Ismail and Lee, Hyuk-Jae and Farabet, Clement and Alvarez, Jose M.},
title = {Active Learning for Deep Object Detection via Probabilistic Modeling},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2021},
pages = {10264-10273}
}