
Commit

refactor
CoinCheung committed Jul 7, 2021
1 parent 5099ace commit c0422a1
Showing 18 changed files with 185 additions and 486 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -114,4 +114,5 @@ tensorrt/build/*
datasets/coco/train.txt
datasets/coco/val.txt
pretrained/*
dist_train.sh

88 changes: 64 additions & 24 deletions README.md
@@ -3,17 +3,25 @@
My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV2](https://arxiv.org/abs/2004.02147).


The mIOU evaluation result of the models trained and evaluated on cityscapes train/val set is:
mIOUs and fps on cityscapes val set:
| model | ss | ssc | msf | mscf | fps(fp16/fp32) | link |
|------|:--:|:---:|:---:|:----:|:---:|:----:|
| bisenetv1 | 75.55 | 76.90 | 77.40 | 78.91 | 60/19 | [download](https://drive.google.com/file/d/140MBBAt49N1z1wsKueoFA6HB_QuYud8i/view?usp=sharing) |
| bisenetv2 | 74.12 | 74.18 | 75.89 | 75.87 | 50/16 | [download](https://drive.google.com/file/d/1qq38u9JT4pp1ubecGLTCHHtqwntH0FCY/view?usp=sharing) |
| bisenetv1 | 75.10 | 76.90 | 77.22 | 78.73 | 60/19 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city.pth) |
| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 50/16 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |

> Where **ss** means single scale evaluation, **ssc** means single scale crop evaluation, **msf** means multi-scale evaluation with flip augment, and **mscf** means multi-scale crop evaluation with flip augment. The eval scales of multi-scale evaluation are `[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]`, and the crop size of crop evaluation is `[1024, 1024]`.
mIOUs on cocostuff val2017 set:
| model | ss | ssc | msf | mscf | link |
|------|:--:|:---:|:---:|:----:|:----:|
| bisenetv1 | 31.89 | 31.62 | 32.81 | 32.72 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_coco.pth) |
| bisenetv2 | 30.49 | 30.55 | 31.81 | 31.73 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_coco.pth) |

> Where **ss** means single scale evaluation, **ssc** means single scale crop evaluation, **msf** means multi-scale evaluation with flip augment, and **mscf** means multi-scale crop evaluation with flip augment. The eval scales and crop size of multi-scale evaluation can be found in the [configs](./configs/).
> The fps is tested in a different way from the paper. For more information, please see [here](./tensorrt).
Note that the model has a fairly large variance, which means that results from repeated training runs can differ by a relatively large margin. For example, if you train bisenetv2 several times, you will observe that its **ss** evaluation result varies between 72.1 and 74.4.
> For the cocostuff dataset: the authors of the `bisenetv2` paper used the "old split" of 9k training and 1k validation images, while I used the "new split" of 118k training and 5k validation images, so the above cocostuff results do not match the paper. The authors of bisenetv1 did not report results on cocostuff, so here I simply provide a "make it work" result. Following the convention in object detection, I used the "1x" (90k iterations) and "2x" (180k iterations) schedules to train bisenetv1 and bisenetv2 respectively. You may get better results by tuning the hyper-parameters more carefully.
Note that the model has a fairly large variance, which means that results from repeated training runs can differ by a relatively large margin. For example, if you train bisenetv2 several times, you will observe that its **ss** evaluation result varies between 73.1 and 75.1.
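
To make the evaluation modes above more concrete, here is a minimal, generic sketch of multi-scale flip (**msf**) inference. It is only an illustration under stated assumptions, not the repo's code: the real logic lives in `tools/evaluate.py`, and the sketch assumes the model returns a single `(N, C, H, W)` logits tensor in eval mode.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def msf_probs(model, im, scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75), flip=True):
    """Average softmax probabilities over several scales (and horizontal flips).

    im: normalized image tensor of shape (N, 3, H, W).
    """
    _, _, H, W = im.shape
    acc, count = None, 0
    for s in scales:
        size = (int(H * s), int(W * s))
        x = F.interpolate(im, size=size, mode='bilinear', align_corners=False)
        variants = [x, torch.flip(x, dims=(3,))] if flip else [x]
        for i, v in enumerate(variants):
            logits = model(v)  # assumed: (N, C, h, w) logits in eval mode
            if i == 1:         # undo the horizontal flip on the prediction
                logits = torch.flip(logits, dims=(3,))
            logits = F.interpolate(logits, size=(H, W), mode='bilinear',
                                   align_corners=False)
            prob = logits.softmax(dim=1)
            acc = prob if acc is None else acc + prob
            count += 1
    return acc / count  # argmax over dim=1 gives the predicted label map
```

Roughly speaking, the crop variants (**ssc**, **mscf**) additionally slide a fixed window (the `eval_crop` size in the configs) over the image and stitch the per-crop predictions back together before averaging.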


## platform
@@ -22,8 +30,8 @@ My platform is like this:
* nvidia Tesla T4 gpu, driver 450.51.05
* cuda 10.2
* cudnn 7
* miniconda python 3.6.9
* pytorch 1.6.0
* miniconda python 3.8.8
* pytorch 1.8.1


## get started
@@ -47,7 +55,23 @@ $ unzip leftImg8bit_trainvaltest.zip
$ unzip gtFine_trainvaltest.zip
```

2. custom dataset
2. cocostuff
Download `train2017.zip`, `val2017.zip` and `stuffthingmaps_trainval2017.zip` from the official [website](https://cocodataset.org/#download), then do the following:
```
$ unzip train2017.zip
$ unzip val2017.zip
$ mv train2017/ /path/to/BiSeNet/datasets/coco/images
$ mv val2017/ /path/to/BiSeNet/datasets/coco/images
$ unzip stuffthingmaps_trainval2017.zip
$ mv train2017/ /path/to/BiSeNet/datasets/coco/labels
$ mv val2017/ /path/to/BiSeNet/datasets/coco/labels
$ cd /path/to/BiSeNet
$ python tools/gen_coco_annos.py
```

3. custom dataset

If you want to train on your own dataset, you should first generate annotation files in a format like this:
```
@@ -56,30 +80,46 @@ frankfurt_000001_079206_leftImg8bit.png,frankfurt_000001_079206_gtFine_labelIds.
...
```
Each line is a pair of training image path and ground-truth label path, separated by a single comma `,`.
Then you need to change the `im_root` and `train/val_im_anns` fields in the configuration files.
Then you need to change the `im_root` and `train/val_im_anns` fields in the configuration files. If what is done in `cityscapes_cv2.py` is not clear to you, you can also refer to `coco.py`.
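
Here is a minimal sketch of generating such an annotation file for a custom dataset. The directory layout, the file extensions, and the choice to write paths relative to `im_root` are assumptions for illustration, not something the repo prescribes; adjust them to your data (and see `tools/gen_coco_annos.py` for how the cocostuff lists are produced).

```python
import os
import os.path as osp

im_dir = 'datasets/mydata/images'   # hypothetical layout
lb_dir = 'datasets/mydata/labels'

with open('datasets/mydata/train.txt', 'w') as fw:
    for name in sorted(os.listdir(im_dir)):
        stem, _ = osp.splitext(name)
        lb_name = stem + '.png'
        if not osp.exists(osp.join(lb_dir, lb_name)):
            continue  # skip images that have no ground-truth label file
        # one line per sample: image path, a comma, then the label path
        fw.write(f'images/{name},labels/{lb_name}\n')
```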

## train
In order to train the model, you can run a command like this:
```
$ export CUDA_VISIBLE_DEVICES=0,1

# if you want to train with apex
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --config configs/bisenetv2_city.py # or bisenetv1
# if you want to train with pytorch fp16 feature from torch 1.6
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --config configs/bisenetv2_city.py # or bisenetv1
## train
I used the following commands to train the models:
```bash
# bisenetv1 cityscapes
export CUDA_VISIBLE_DEVICES=0,1
cfg_file=configs/bisenetv1_city.py
NGPUS=2
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file

# bisenetv2 cityscapes
export CUDA_VISIBLE_DEVICES=0,1
cfg_file=configs/bisenetv2_city.py
NGPUS=2
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file

# bisenetv1 cocostuff
export CUDA_VISIBLE_DEVICES=0,1,2,3
cfg_file=configs/bisenetv1_coco.py
NGPUS=4
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file

# bisenetv2 cocostuff
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
cfg_file=configs/bisenetv2_coco.py
NGPUS=8
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
```

Note that though `bisenetv2` has fewer flops, it requires many more training iterations, so the training time of `bisenetv1` is shorter.
Note:
1. Though `bisenetv2` has fewer flops, it requires many more training iterations, so the training time of `bisenetv1` is shorter.
2. I used an overall batch size of 16 to train all models. Since cocostuff has 171 categories, training on it requires more memory, so I split the 16 images across more GPUs than the 2 used for cityscapes (see the sketch below).
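
For what it's worth, the batch-size bookkeeping behind note 2 can be written out. Only the cocostuff values are visible in this commit (`ims_per_gpu=4` in `configs/bisenetv1_coco.py`); the other `ims_per_gpu` values below are assumptions derived from the stated overall batch size of 16.

```python
# overall batch size = NGPUS x ims_per_gpu, kept at 16 for every setup
setups = {
    'bisenetv1_city': (2, 8),  # (NGPUS, ims_per_gpu); 8 per GPU is assumed
    'bisenetv2_city': (2, 8),  # assumed
    'bisenetv1_coco': (4, 4),  # ims_per_gpu=4 from configs/bisenetv1_coco.py
    'bisenetv2_coco': (8, 2),  # derived: 16 images over 8 GPUs
}
for name, (ngpus, ims_per_gpu) in setups.items():
    assert ngpus * ims_per_gpu == 16, name
```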


## finetune from trained model
You can also load the trained model weights and finetune from it:
You can also load the trained model weights and finetune from it, like this:
```
$ export CUDA_VISIBLE_DEVICES=0,1
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --finetune-from ./res/model_final.pth --config ./configs/bisenetv2_city.py # or bisenetv1
# same with pytorch fp16 feature
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --finetune-from ./res/model_final.pth --config ./configs/bisenetv2_city.py # or bisenetv1
```

@@ -94,6 +134,6 @@ $ python tools/evaluate.py --config configs/bisenetv1_city.py --weight-path /pat
You can go to [tensorrt](./tensorrt) for details.


### Be aware that this is the refactored version of the original codebase. You can go to the `old` directory for the original implementation.
### Be aware that this is the refactored version of the original codebase. You can go to the `old` directory for the original implementation if you need it, though I believe you will not.


23 changes: 23 additions & 0 deletions configs/bisenetv1_coco.py
@@ -0,0 +1,23 @@

cfg = dict(
model_type='bisenetv1',
n_cats=171,
num_aux_heads=2,
lr_start=1e-2,
weight_decay=1e-4,
warmup_iters=1000,
max_iter=90000,
dataset='CocoStuff',
im_root='./datasets/coco',
train_im_anns='./datasets/coco/train.txt',
val_im_anns='./datasets/coco/val.txt',
scales=[0.5, 2.],
cropsize=[512, 512],
eval_crop=[512, 512],
eval_scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
ims_per_gpu=4,
eval_ims_per_gpu=1,
use_fp16=True,
use_sync_bn=True,
respth='./res',
)
4 changes: 2 additions & 2 deletions configs/bisenetv2_coco.py
@@ -2,12 +2,12 @@
## bisenetv2
cfg = dict(
model_type='bisenetv2',
n_cats=182,
n_cats=171,
num_aux_heads=4,
lr_start=5e-3,
weight_decay=1e-4,
warmup_iters=1000,
max_iter=20000,
max_iter=180000,
dataset='CocoStuff',
im_root='./datasets/coco',
train_im_anns='./datasets/coco/train.txt',
20 changes: 15 additions & 5 deletions dist_train.sh
@@ -1,13 +1,23 @@

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
PORT=52339
NGPUS=8
export CUDA_VISIBLE_DEVICES=0,1
PORT=52332
NGPUS=2
# cfg_file=configs/bisenetv1_city.py
# cfg_file=configs/bisenetv1_coco.py
# cfg_file=configs/bisenetv2_city.py
cfg_file=configs/bisenetv2_coco.py

python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file --port $PORT
# python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file --port $PORT
python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --finetune-from ./res/modelzoo/model_final_v2_city.pth --config ./configs/bisenetv2_city.py # or bisenetv1

## train, use run
# python -m torch.distributed.run --nnode=1 --rdzv_backend=c10d --rdzv_id=001 --rdzv_endpoint=127.0.0.1:$PORT --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file --port $PORT




# python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train.py --config $cfg_file --port $PORT

# python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/evaluate.py --config $cfg_file --port $PORT --weight-path res/model_final.pth
# python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/evaluate.py --config $cfg_file --port $PORT --weight-path res/modelzoo/model_final_v2_coco.pth


26 changes: 0 additions & 26 deletions lib/base_dataset.py
@@ -11,7 +11,6 @@
import cv2
import numpy as np

import lib.transform_cv2 as T
from lib.sampler import RepeatedDistSampler


@@ -58,31 +57,6 @@ def __len__(self):
return self.len


class TransformationTrain(object):

def __init__(self, scales, cropsize):
self.trans_func = T.Compose([
T.RandomResizedCrop(scales, cropsize),
T.RandomHorizontalFlip(),
T.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4
),
])

def __call__(self, im_lb):
im_lb = self.trans_func(im_lb)
return im_lb


class TransformationVal(object):

def __call__(self, im_lb):
im, lb = im_lb['im'], im_lb['lb']
return dict(im=im, lb=lb)


if __name__ == "__main__":
from tqdm import tqdm
from torch.utils.data import DataLoader
16 changes: 12 additions & 4 deletions lib/coco.py
@@ -16,7 +16,7 @@
from lib.base_dataset import BaseDataset

'''
91 + 91 = 182 classes, label proportions are:
91(thing) + 91(stuff) = 182 classes, label proportions are:
[0.0901445377, 0.00157896236, 0.00611962763, 0.00494526505, 0.00335260064, 0.00765355955, 0.00772972804, 0.00631509744,
0.00270457286, 0.000697793344, 0.00114085574, 0.0, 0.00114084131, 0.000705729068, 0.00359758029, 0.00162208938, 0.00598373796,
0.00440213609, 0.00362085441, 0.00193052224, 0.00271001196, 0.00492864603, 0.00186985393, 0.00332902228, 0.00334420294, 0.0,
@@ -38,7 +39,9 @@
0.00112924659, 0.001457768, 0.00190406757, 0.00173232644, 0.0116980759, 0.000850599027, 0.00565381261, 0.000787379463, 0.0577763754,
0.00214883711, 0.00553984356, 0.0443605019, 0.0218570174, 0.0027310644, 0.00225446528, 0.00903008323, 0.00644298871, 0.00442167269,
0.000129279566, 0.00176047379, 0.0101637834, 0.00255549522]
11 classes have no annos, proportions are 0
11 thing classes have no annos, proportions are 0:
[11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90]
'''


@@ -47,9 +49,15 @@ class CocoStuff(BaseDataset):

def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
super(CocoStuff, self).__init__(dataroot, annpath, trans_func, mode)
self.n_cats = 182 # 91 stuff, 91 thing, 11 of thing have no annos
self.n_cats = 171 # 91 stuff, 91 thing, 11 of thing have no annos
self.lb_ignore = 255
self.lb_map = None

## label mapping, remove non-existing labels
missing = [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90]
remain = [ind for ind in range(182) if not ind in missing]
self.lb_map = np.arange(256)
for ind in remain:
self.lb_map[ind] = remain.index(ind)

self.to_tensor = T.ToTensor(
mean=(0.46962251, 0.4464104, 0.40718787), # coco, rgb
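
To make the new label mapping in `lib/coco.py` more concrete, here is a small stand-alone illustration of what `lb_map` does to a label image. The tiny label patch is made up, and the `enumerate` loop is just an equivalent rewrite of the `remain.index(ind)` loop in the diff above.

```python
import numpy as np

missing = [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90]
remain = [ind for ind in range(182) if ind not in missing]

lb_map = np.arange(256)
for new_id, old_id in enumerate(remain):
    lb_map[old_id] = new_id        # the 171 surviving ids become contiguous 0..170
# ids that never occur keep their original values; 255 stays 255 (the ignore label)

raw = np.array([[0, 12, 181],
                [255, 91, 45]], dtype=np.uint8)   # fake cocostuff label patch
mapped = lb_map[raw]               # vectorized lookup: new_label = lb_map[old_label]
print(mapped)
# [[  0  11 170]
#  [255  80  40]]   (12 shifts down by one because id 11 has no annotations)
```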
28 changes: 27 additions & 1 deletion lib/get_dataloader.py
@@ -3,12 +3,38 @@
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist

import lib.transform_cv2 as T
from lib.sampler import RepeatedDistSampler
from lib.base_dataset import TransformationTrain, TransformationVal
from lib.cityscapes_cv2 import CityScapes
from lib.coco import CocoStuff



class TransformationTrain(object):

def __init__(self, scales, cropsize):
self.trans_func = T.Compose([
T.RandomResizedCrop(scales, cropsize),
T.RandomHorizontalFlip(),
T.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4
),
])

def __call__(self, im_lb):
im_lb = self.trans_func(im_lb)
return im_lb


class TransformationVal(object):

def __call__(self, im_lb):
im, lb = im_lb['im'], im_lb['lb']
return dict(im=im, lb=lb)


def get_data_loader(cfg, mode='train', distributed=True):
if mode == 'train':
trans_func = TransformationTrain(cfg.scales, cfg.cropsize)
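
Since `TransformationTrain` and `TransformationVal` now live in `lib/get_dataloader.py`, a quick usage sketch may help. The dummy arrays are made up, the `scales`/`cropsize` values are the ones from the coco configs, and the exact output shapes depend on the repo's `transform_cv2` implementation.

```python
import numpy as np
from lib.get_dataloader import TransformationTrain, TransformationVal

im = np.random.randint(0, 256, (768, 1024, 3), dtype=np.uint8)  # dummy HWC image
lb = np.random.randint(0, 171, (768, 1024), dtype=np.uint8)     # dummy label map

train_trans = TransformationTrain(scales=[0.5, 2.], cropsize=[512, 512])
out = train_trans(dict(im=im, lb=lb))    # random resized crop, flip, color jitter
print(out['im'].shape, out['lb'].shape)  # roughly cropsize after the train transform

val_trans = TransformationVal()
out = val_trans(dict(im=im, lb=lb))      # identity pass-through for evaluation
```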
5 changes: 3 additions & 2 deletions lib/models/bisenetv1.py
@@ -62,10 +62,11 @@ class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.up_factor = up_factor
out_chan = n_classes * up_factor * up_factor
out_chan = n_classes
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True)
self.up = nn.PixelShuffle(up_factor)
self.up = nn.Upsample(scale_factor=up_factor,
mode='bilinear', align_corners=False)
self.init_weight()

def forward(self, x):
Expand Down
