Commit
0 parents
commit 2682584
Showing 33 changed files with 28,611 additions and 0 deletions.
@@ -0,0 +1,25 @@
MIT License

All contributions from “Superpixel Segmentation with Fully Convolutional Network”:
Copyright (c) 2020 Fengting Yang

All contributions from “FlowNetPytorch”:
Copyright (c) 2017 Clément Pinard

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,115 @@
# SpixelFCN: Superpixel Segmentation with Fully Convolutional Network

This is a PyTorch implementation of the superpixel segmentation network introduced in our CVPR-20 paper:

[Superpixel Segmentation with Fully Convolutional Network](http://personal.psu.edu/fuy34)

[Fengting Yang](http://personal.psu.edu/fuy34/), [Qian Sun](https://www.linkedin.com/in/qiansuun), [Hailin Jin](https://research.adobe.com/person/hailin-jin/), and [Zihan Zhou](https://faculty.ist.psu.edu/zzhou/Home.html)

Please contact Fengting Yang ([email protected]) if you have any questions.

## Prerequisites
The training code was mainly developed and tested with Python 2.7, PyTorch 0.4.1, CUDA 9, and Ubuntu 16.04.

At test time, we use the component connection method from [SSN](https://github.com/NVlabs/ssn_superpixels) to enforce connectivity
within superpixels. The code is included in ```/third_party/cython```. To compile it:
```
cd third_party/cython/
python setup.py install --user
cd ../..
```
## Demo
The demo script ```run_demo.py``` generates superpixels with grid size ```16 x 16``` using our pre-trained model (in ```/pretrained_ckpt```).
Feel free to try your own images by copying them into ```/demo/inputs``` and running
```
python run_demo.py --data_dir=./demo/inputs --data_suffix=jpg --output=./demo
```
The results will be generated in a new folder under ```/demo``` called ```spixel_viz```.

## Data preparation
To generate the training and test datasets, please first download the original [BSDS500 dataset](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_full.tgz)
and extract it to ```<BSDS_DIR>```. Then, run
```
cd data_preprocessing
python pre_processing_bsd500.py --dataset=<BSDS_DIR> --dump_root=<DUMP_DIR>
python pre_processing_bsd500_ori.py --dataset=<BSDS_DIR> --dump_root=<DUMP_DIR>
cd ..
```
The code will generate three folders under ```<DUMP_DIR>```, named ```/train```, ```/val```, and ```/test```, and three ```.txt``` files
that record the absolute paths of the images, named ```train.txt```, ```val.txt```, and ```test.txt```.
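
As a rough sketch of how the generated lists could be consumed, the snippet below reads a list file and derives the label path for each image. It assumes the ```<name>_<k>_img.jpg``` / ```<name>_<k>_label.png``` naming used for the train and val dumps by ```pre_processing_bsd500.py```; the helpers ```read_image_list``` and ```label_path_for``` are hypothetical and not part of the repository.

```python
import os
import tempfile


def read_image_list(list_path):
    """Return the non-empty lines of a list file such as train.txt."""
    with open(list_path, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def label_path_for(img_path):
    """Map '<name>_<k>_img.jpg' to '<name>_<k>_label.png' (assumed naming)."""
    suffix = '_img.jpg'
    if not img_path.endswith(suffix):
        raise ValueError('unexpected image name: {}'.format(img_path))
    return img_path[:-len(suffix)] + '_label.png'


if __name__ == '__main__':
    # Build a tiny stand-in list file so the sketch runs without the dataset.
    tmp_dir = tempfile.mkdtemp()
    list_path = os.path.join(tmp_dir, 'train.txt')
    with open(list_path, 'w') as f:
        f.write(os.path.join(tmp_dir, 'train', '100075_0_img.jpg') + '\n')

    for img_path in read_image_list(list_path):
        print(img_path, label_path_for(img_path))
```

The actual data loader in the repository may pair images and labels differently; this only illustrates the file layout described above.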

## Training
Once the data is prepared, we can train the model by running
```
python main.py --data=<DUMP_DIR> --savepath=<CKPT_LOG_DIR>
```

To resume training or fine-tune from a pre-trained model, run
```
python main.py --data=<DUMP_DIR> --savepath=<CKPT_LOG_DIR> --pretrained=<PATH_TO_THE_CKPT>
```
The code will start from the recorded state, which includes the optimizer state and epoch number.

The training log can be viewed in `tensorboard` by running
```
tensorboard --logdir=<CKPT_LOG_DIR> --port=8888
```

If everything is set up properly, reasonable segmentation should be observed after 10 epochs.

## Testing
We provide test code to generate: 1) superpixel visualizations and 2) the ```.csv``` files for evaluation.

To test on BSDS500, run
```
python run_infer_bsds.py --data_dir=<DUMP_DIR> --output=<TEST_OUTPUT_DIR> --pretrained=<PATH_TO_THE_CKPT>
```

To test on NYUv2, please first extract our pre-processed dataset from ```/nyu_test_set/nyu_preprocess_tst.tar.gz```
to ```<NYU_TEST>```, or follow the [instructions in the superpixel benchmark](https://github.com/davidstutz/superpixel-benchmark/blob/master/docs/DATASETS.md)
to generate the test dataset, and then run
```
python run_infer_nyu.py --data_dir=<NYU_TEST> --output=<TEST_OUTPUT_DIR> --pretrained=<PATH_TO_THE_CKPT>
```

To test on other datasets, please first collect all the images into one folder ```<CUSTOM_DIR>```, convert them to the same
format (e.g. ```.png``` or ```.jpg```) if necessary, and run
```
python run_demo.py --data_dir=<CUSTOM_DIR> --data_suffix=<IMG_SUFFIX> --output=<TEST_OUTPUT_DIR> --pretrained=<PATH_TO_THE_CKPT>
```
Superpixels with grid size ```16 x 16``` will be generated by default. To generate superpixels with a different grid size, simply
resize the images to the appropriate resolution before passing them through the code. Please refer to ```run_infer_nyu.py``` for details.
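
One plausible way to pick the "appropriate resolution" is sketched below: since each grid cell covers ```16 x 16``` pixels, an image of size ```H x W``` yields roughly ```(H/16) * (W/16)``` superpixels, so we can scale the image until that count matches a target. The function ```resolution_for_superpixels``` is a hypothetical helper written for illustration; it is an assumption about the intent here and may differ from what ```run_infer_nyu.py``` actually does.

```python
import math


def resolution_for_superpixels(height, width, n_spixels, grid=16):
    """Return (new_h, new_w), both multiples of `grid`, such that
    (new_h / grid) * (new_w / grid) is close to n_spixels while
    roughly keeping the original aspect ratio."""
    # Uniform scale that makes the number of grid cells equal n_spixels.
    scale = math.sqrt(float(n_spixels) * grid * grid / (height * width))
    # Snap each side to the nearest positive multiple of the grid size.
    new_h = max(1, int(round(height * scale / grid))) * grid
    new_w = max(1, int(round(width * scale / grid))) * grid
    return new_h, new_w


if __name__ == '__main__':
    # A 320 x 480 BSDS500 crop already gives 20 * 30 = 600 cells.
    print(resolution_for_superpixels(320, 480, 600))
    # Asking for a quarter as many superpixels halves each side.
    print(resolution_for_superpixels(320, 480, 150))
```

The resized image can then be passed through the network, and the resulting superpixel map resized back to the original resolution with nearest-neighbor interpolation so that label values are not blended.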

## Evaluation
We use the code from the [superpixel benchmark](https://github.com/davidstutz/superpixel-benchmark) for superpixel evaluation.
Detailed [instructions](https://github.com/davidstutz/superpixel-benchmark/blob/master/docs/BUILDING.md) are available in that repository. Please

(1) download the code and build it accordingly;

(2) edit the variables ```$SUPERPIXELS```, ```IMG_PATH``` and ```GT_PATH``` in ```/eval_spixel/my_eval.sh```;

(3) run
```
cp /eval_spixel/my_eval.sh <path/to/the/benchmark>/examples/bash/
cd <path/to/the/benchmark>/examples/
bash my_eval.sh
```
several files should be generated in the ```map_csv``` folders in the corresponding test outputs;

(4) run
```
cd eval_spixel
python copy_resCSV.py --src=<TEST_OUTPUT_DIR> --dst=<PATH_TO_COLLECT_EVAL_RES>
```
(5) open ```/eval_spixel/plot_benchmark_curve.m```, set ```our1l_res_path``` to ```<PATH_TO_COLLECT_EVAL_RES>``` and modify ```num_list```
according to the test setting. The default setting is for our BSDS500 test set;

(6) run ```plot_benchmark_curve.m```; the ```ASA Score```, ```CO Score```, and ```BR-BP curve``` of our method should
be shown on the screen. If you wish to compare our method with others, you can first run those methods, organize the data
as described above, and uncomment the code in ```plot_benchmark_curve.m``` to generate a figure similar to the one in our paper.

## Acknowledgement
Our code is developed based on the training framework provided by [FlowNetPytorch](https://github.com/ClementPinard/FlowNetPytorch).
@@ -0,0 +1,182 @@
import os
import numpy as np
import cv2
from scipy.io import loadmat
import argparse
from glob import glob
from joblib import Parallel, delayed
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="", help="where the filtered dataset is stored")
parser.add_argument("--dump_root", type=str, default="", help="Where to dump the data")
# NOTE: with type=bool, argparse treats any non-empty string (even "False") as True;
# omit the flag entirely to keep the default of False.
parser.add_argument("--b_filter", type=bool, default=False, help="We do not use this in our paper")
parser.add_argument("--num_threads", type=int, default=4, help="number of threads to use")
args = parser.parse_args()

'''
Extract each pair of image and label from .mat to generate the data for TRAINING and VALIDATION
Please generate TEST data with pre_process_bsd500_ori_sz.py in the same folder
We follow the SSN configuration to discard all samples that have more than 50 classes in their segments, and
we use exactly the same train, val, and test lists as SSN, see the train/val/test.txt in the data_preprocessing folder for details
author: Fengting Yang
March. 1st 2019
'''

def make_dataset(dir):
    # The train/val name lists are expected in the current working directory
    cwd = os.getcwd()
    train_list_path = cwd + '/train.txt'
    val_list_path = cwd + '/val.txt'
    train_list = []
    val_list = []

    try:
        with open(train_list_path, 'r') as tf:
            train_list_0 = tf.readlines()
            for path in train_list_0:
                # strip() also handles a last line without a trailing newline
                img_path = os.path.join(dir, 'BSR/BSDS500/data/images/train', path.strip() + '.jpg')
                if not os.path.isfile(img_path):
                    print('The train images are missing in {}'.format(os.path.dirname(img_path)))
                    print('Please pre-process the BSDS500 as README states and provide the correct dataset path.')
                    exit(1)
                train_list.append(img_path)

        with open(val_list_path, 'r') as vf:
            val_list_0 = vf.readlines()
            for path in val_list_0:
                img_path = os.path.join(dir, 'BSR/BSDS500/data/images/val', path.strip() + '.jpg')
                if not os.path.isfile(img_path):
                    print('The validation images are missing in {}'.format(os.path.dirname(img_path)))
                    print('Please pre-process the BSDS500 as README states and provide the correct dataset path.')
                    exit(1)
                val_list.append(img_path)

    except IOError:
        # Exit instead of returning None, which would break the tuple unpacking in main()
        print('Error: no available list')
        exit(1)

    return train_list, val_list

def convert_label(label):
    # One-hot encode up to 50 distinct label values, then squash them to 0..N-1
    problabel = np.zeros((label.shape[0], label.shape[1], 50)).astype(np.float32)

    ct = 0
    for t in np.unique(label).tolist():
        if ct >= 50:
            print('give up sample because label shape is larger than 50: {0}'.format(np.unique(label).shape))
            break
        else:
            problabel[:, :, ct] = (label == t)  # one hot
        ct = ct + 1

    label2 = np.squeeze(np.argmax(problabel, axis=-1))  # squashed label e.g. [1, 3, 9, 10] --> [0, 1, 2, 3], (h*w)

    return label2, problabel

def BSD_loader(path_imgs, path_label, b_filter=False):

    img_ = cv2.imread(path_imgs)

    # origin size 481*321 or 321*481
    H_, W_, _ = img_.shape

    # crop to 16*n size
    if H_ == 321 and W_ == 481:
        img = img_[:320, :480, :]
    elif H_ == 481 and W_ == 321:
        img = img_[:480, :320, :]
    else:
        print('It is not BSDS500 images')
        exit(1)

    if b_filter:
        img = cv2.bilateralFilter(img, 5, 75, 75)

    gtseg_lst = []

    gtseg_all = loadmat(path_label)
    for t in range(len(gtseg_all['groundTruth'][0])):
        gtseg = gtseg_all['groundTruth'][0][t][0][0][0]

        label_, _ = convert_label(gtseg)
        if H_ == 321 and W_ == 481:
            label = label_[:320, :480]
        elif H_ == 481 and W_ == 321:
            label = label_[:480, :320]

        gtseg_lst.append(label)

    return img, gtseg_lst

def dump_example(n, n_total, dataType, img_path):
    global args
    if n % 100 == 0:
        print('Progress {0} {1}/{2}....'.format(dataType, n, n_total))

    img, label_lst = BSD_loader(img_path, img_path.replace('images', 'groundTruth')[:-4] + '.mat', b_filter=args.b_filter)

    if args.b_filter:
        dump_dir = os.path.join(args.dump_root, dataType + '_b_filter_' + str(args.b_filter))
    else:
        dump_dir = os.path.join(args.dump_root, dataType)

    # Parallel workers may race to create the same folders, so tolerate "already exists"
    viz_dir = os.path.join(dump_dir, 'label_viz')
    for folder in (dump_dir, viz_dir):
        if not os.path.isdir(folder):
            try:
                os.makedirs(folder)
            except OSError:
                if not os.path.isdir(folder):
                    raise

    img_name = os.path.basename(img_path)[:-4]
    for k, label in enumerate(label_lst):
        # save images
        dump_img_file = os.path.join(dump_dir, '{0}_{1}_img.jpg'.format(img_name, k))
        cv2.imwrite(dump_img_file, img.astype(np.uint8))

        # save label
        dump_label_file = os.path.join(dump_dir, '{0}_{1}_label.png'.format(img_name, k))
        cv2.imwrite(dump_label_file, label.astype(np.uint8))

        # save label viz
        dump_label_viz = os.path.join(viz_dir, '{0}_{1}_label_viz.png'.format(img_name, k))
        plt.imshow(label)
        plt.axis('off')
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(dump_label_viz, bbox_inches='tight', pad_inches=0)
        plt.close()


def main():
    datadir = args.dataset
    train_list, val_list = make_dataset(datadir)

    # for debug only
    # for n, train_samp in enumerate(train_list):
    #     dump_example(n, len(train_list), 'train', train_samp)

    # run in parallel for speed
    Parallel(n_jobs=args.num_threads)(delayed(dump_example)(n, len(train_list), 'train', train_samp) for n, train_samp in enumerate(train_list))
    # report validation progress against the validation list length
    Parallel(n_jobs=args.num_threads)(delayed(dump_example)(n, len(val_list), 'val', val_samp) for n, val_samp in enumerate(val_list))

    with open(args.dump_root + '/train.txt', 'w') as trnf:
        imfiles = glob(os.path.join(args.dump_root, 'train', '*_img.jpg'))
        for frame in imfiles:
            trnf.write(frame + '\n')

    with open(args.dump_root + '/val.txt', 'w') as valf:
        imfiles = glob(os.path.join(args.dump_root, 'val', '*_img.jpg'))
        for frame in imfiles:
            valf.write(frame + '\n')

if __name__ == '__main__':
    main()
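
The label squashing done by `convert_label` (arbitrary segment IDs mapped to consecutive indices) can also be illustrated with `np.unique(..., return_inverse=True)`. This is an alternative technique shown only as a sketch, not the code used in this commit: the helper `squash_labels` is hypothetical, and unlike `convert_label` it raises an error rather than printing a message when there are more than 50 distinct labels.

```python
import numpy as np


def squash_labels(label, max_classes=50):
    """Map arbitrary integer labels to 0..N-1, keeping the array shape.

    Hypothetical helper: raises if the map has more than `max_classes`
    distinct values (the script above prints a message instead).
    """
    uniq, inverse = np.unique(label, return_inverse=True)
    if uniq.size > max_classes:
        raise ValueError('label map has {} classes (> {})'.format(uniq.size, max_classes))
    # inverse holds, for each pixel, the index of its value in the sorted unique list;
    # reshape so the result has the input shape regardless of NumPy version.
    return inverse.reshape(label.shape)


if __name__ == '__main__':
    toy = np.array([[1, 3], [9, 10]])
    print(squash_labels(toy))  # consecutive indices replace the raw IDs
```

For label maps with at most 50 distinct values, both approaches assign indices in the sorted order of the original labels.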