-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_fcn_dataset.py
executable file
·120 lines (94 loc) · 4.34 KB
/
convert_fcn_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import logging
import os
import cv2
import numpy as np
import tensorflow as tf
from vgg import vgg_16
# Command-line flags (TF1 `tf.app.flags`); they are parsed when the script is
# started through `tf.app.run()` and read in `main` via FLAGS.
flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
FLAGS = flags.FLAGS
# PASCAL VOC class names; the position in this list is the integer label
# written into the TFRecords.
classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
           'dog', 'horse', 'motorbike', 'person', 'potted plant',
           'sheep', 'sofa', 'train', 'tv/monitor']

# RGB color used for each class in the segmentation masks;
# colormap[i] is the color of classes[i].
colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
            [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
            [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
            [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
            [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]

# Lookup table from a packed 24-bit color ((R * 256 + G) * 256 + B) to a class
# index. uint8 is wide enough for 21 classes and keeps the table at 16 MiB
# instead of the 128 MiB a default float64 table would take.
# Any color not listed in `colormap` maps to 0 (background).
# NOTE(review): this includes VOC's outline/"void" color if the masks contain
# it -- confirm that treating it as background is intended.
cm2lbl = np.zeros(256 ** 3, dtype=np.uint8)
for i, cm in enumerate(colormap):
    cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i


def image2label(im):
    """Convert a color segmentation mask into an array of class indices.

    Args:
        im: Array indexed as ``im[row, col, channel]`` with channels in BGR
            order, as returned by ``cv2.imread`` with its default flags.

    Returns:
        A ``(height, width)`` uint8 array whose values index ``classes``.
    """
    # Widen before packing so (R * 256 + G) * 256 + B cannot overflow uint8.
    data = im.astype('int32')
    # cv2.imread's default channel layout is BGR, so channel 2 is red.
    idx = (data[:, :, 2] * 256 + data[:, :, 1]) * 256 + data[:, :, 0]
    # Advanced (integer-array) indexing already returns a fresh array.
    return cm2lbl[idx]
def dict_to_tf_example(data, label):
    """Build a ``tf.train.Example`` for one image / segmentation-mask pair.

    Args:
        data: Path to the image file; its raw bytes are stored unchanged
            under ``image/encoded`` (tagged with format ``jpeg``).
        label: Path to the color segmentation mask. It is decoded with
            ``cv2.imread`` and converted to per-pixel class indices.

    Returns:
        The ``tf.train.Example``, or ``None`` when the mask is smaller than
        ``vgg_16.default_image_size`` in height or width.

    Raises:
        ValueError: If ``cv2.imread`` cannot decode the mask file.
    """
    with open(data, 'rb') as inf:
        encoded_data = inf.read()

    img_label = cv2.imread(label)
    # cv2.imread signals failure by returning None instead of raising; turn
    # that into a ValueError so callers get a catchable, descriptive error
    # rather than an AttributeError deeper inside image2label.
    if img_label is None:
        raise ValueError('Could not decode label image: %s' % label)

    height, width = img_label.shape[0], img_label.shape[1]
    # Ensure the image is large enough for the later random crop to the
    # network input size. Checking before image2label avoids converting
    # masks that would be discarded anyway.
    if height < vgg_16.default_image_size or width < vgg_16.default_image_size:
        return None

    img_mask = image2label(img_label)
    # Class indices as raw uint8 bytes in C (row-major) order, height*width.
    encoded_label = img_mask.astype(np.uint8).tobytes()

    feature_dict = {
        'image/height': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[height])),
        'image/width': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[width])),
        'image/filename': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[data.encode('utf8')])),
        'image/encoded': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded_data])),
        'image/label': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded_label])),
        'image/format': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=['jpeg'.encode('utf8')])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example
def create_tf_record(output_filename, file_pars):
    """Write one TFRecord file from ``(image path, label path)`` pairs.

    A pair is skipped, with a logged warning, when either file is missing or
    when ``dict_to_tf_example`` raises ``ValueError`` for it. A pair is
    skipped silently when ``dict_to_tf_example`` returns ``None``.
    Progress is printed every 100 pairs.

    Args:
        output_filename: Path of the TFRecord file to write.
        file_pars: Iterable of ``(data_path, label_path)`` tuples, such as
            the result of ``read_images_names``.
    """
    # The context manager closes the writer even if an unexpected exception
    # escapes the loop, instead of leaking the open file.
    with tf.python_io.TFRecordWriter(output_filename) as writer:
        for idx, (data_path, label_path) in enumerate(file_pars, 1):
            if idx % 100 == 0:
                print('Image index %d' % idx)
            if not os.path.exists(data_path) or not os.path.exists(label_path):
                logging.warning('Could not find [%s,%s], ignoring example.',
                                data_path, label_path)
                continue
            try:
                tf_example = dict_to_tf_example(data_path, label_path)
            except ValueError:
                logging.warning('Invalid example: [%s,%s], ignoring.',
                                data_path, label_path)
                continue
            # None marks an image that fails the size requirement; skip it.
            if tf_example is None:
                continue
            writer.write(tf_example.SerializeToString())
def read_images_names(root, train=True):
    """Pair every image of a VOC segmentation split with its label mask.

    The split list ``<root>/ImageSets/Segmentation/{train,val}.txt`` holds
    whitespace-separated image ids. Each id becomes the pair
    ``(<root>/JPEGImages/<id>.jpg, <root>/SegmentationClass/<id>.png)``.

    Args:
        root: Dataset root directory.
        train: Read ``train.txt`` when True, otherwise ``val.txt``.

    Returns:
        A ``zip`` iterator of ``(image_path, label_path)`` tuples.
    """
    split_file = 'train.txt' if train else 'val.txt'
    list_path = os.path.join(root, 'ImageSets/Segmentation/', split_file)
    with open(list_path, 'r') as handle:
        names = handle.read().split()
    image_paths = ['%s/JPEGImages/%s.jpg' % (root, name) for name in names]
    mask_paths = ['%s/SegmentationClass/%s.png' % (root, name) for name in names]
    return zip(image_paths, mask_paths)
def main(_):
    """Convert the train and val splits under FLAGS.data_dir into TFRecords.

    Writes ``fcn_train.record`` and ``fcn_val.record`` into
    FLAGS.output_dir, creating that directory first if it does not exist.
    The positional argument (leftover argv passed in by ``tf.app.run``) is
    unused.
    """
    logging.info('Prepare dataset file names')
    # Create the output directory up front so a fresh output path works.
    # An empty value means the current directory; os.makedirs('') would raise.
    if FLAGS.output_dir:
        os.makedirs(FLAGS.output_dir, exist_ok=True)
    train_output_path = os.path.join(FLAGS.output_dir, 'fcn_train.record')
    val_output_path = os.path.join(FLAGS.output_dir, 'fcn_val.record')
    train_files = read_images_names(FLAGS.data_dir, True)
    val_files = read_images_names(FLAGS.data_dir, False)
    create_tf_record(train_output_path, train_files)
    create_tf_record(val_output_path, val_files)
# Script entry point: tf.app.run() parses the command-line flags and then
# invokes main() with the remaining argv.
if __name__ == '__main__':
    tf.app.run()