forked from gxd1994/TextBoxes-TensorFlow
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path load_batch.py
87 lines (66 loc) · 2.34 KB
/
load_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# *_* coding:utf-8 *_*
"""
Build batched input pipelines (training and evaluation) for TextBoxes.

Exposes ``get_batch``, which reads dataset records through a TF-Slim data
provider, preprocesses each example, encodes its ground-truth boxes against
a set of anchors and returns the results as batched tensors.
"""
import sys
import os
# Put the parent directory of this file on sys.path, presumably so that the
# project-local imports below (datasets, tf_utils, processing) resolve when
# the module is used from inside its own directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..')))
import tensorflow as tf
from datasets import sythtextprovider
import tf_utils
from processing import txt_preprocessing
# TF-Slim alias; tf.contrib exists only in TensorFlow 1.x.
slim = tf.contrib.slim
def get_batch(dataset_dir,
              num_readers,
              batch_size,
              out_shape,
              net,
              anchors,
              FLAGS,
              file_pattern='*.tfrecord',
              is_training=True,
              shuffe=False):
    """Build a batched input pipeline for training or evaluation.

    Records matching ``file_pattern`` inside ``dataset_dir`` are read through
    a TF-Slim ``DatasetDataProvider``. Each example is preprocessed with
    ``txt_preprocessing.preprocess_image``, and its boxes are encoded against
    ``anchors`` with ``net.bboxes_encode``. The results are then grouped into
    batches of ``batch_size``.

    Args:
        dataset_dir: directory handed to ``sythtextprovider.get_datasets``.
        num_readers: number of parallel readers used by the data provider.
        batch_size: number of examples per batch.
        out_shape: target shape forwarded to the preprocessing step.
        net: object exposing ``bboxes_encode(bboxes, anchors, num)``.
        anchors: forwarded to ``net.bboxes_encode``; ``len(anchors)`` also
            tells ``tf_utils.reshape_list`` how many entries the
            localisation and score groups contain after batching.
        FLAGS: flag object; ``use_whiten`` and ``num_preprocessing_threads``
            are read.
        file_pattern: glob of record files inside ``dataset_dir``.
        is_training: selects the shuffled training pipeline (True) or the
            padded evaluation pipeline (False).
        shuffe: passed as ``shuffle`` to the data provider. The misspelt
            name is kept so existing keyword callers keep working.

    Returns:
        When training: ``(image, glocalisations, gscores)`` batches from
        ``tf.train.shuffle_batch``.
        Otherwise: ``(image, glabels, gbboxes, bbox_img, glocalisations,
        gscores)`` batches from ``tf.train.batch`` with ``dynamic_pad=True``.

    NOTE(review): this assumes ``tf_utils.reshape_list`` flattens nested
    lists, and regroups a flat list when given a shape (the SSD-Tensorflow
    convention). Confirm against tf_utils.
    """
    base_queue = 512 * 16
    dataset = sythtextprovider.get_datasets(dataset_dir,
                                            file_pattern=file_pattern)
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=num_readers,
        common_queue_capacity=base_queue + 20 * batch_size,
        common_queue_min=base_queue,
        shuffle=shuffe)
    fields = ['image', 'shape', 'object/label', 'object/bbox',
              'height', 'width']
    image, _shape, glabels, gbboxes, height, width = provider.get(fields)

    # Both modes issue the same preprocessing call; only the number of
    # returned values differs (4 when training, 5 otherwise).
    processed = txt_preprocessing.preprocess_image(
        image, glabels, gbboxes, height, width, out_shape,
        use_whiten=FLAGS.use_whiten, is_training=is_training)

    if not is_training:
        image, glabels, gbboxes, bbox_img, num = processed
        glocalisations, gscores = net.bboxes_encode(gbboxes, anchors, num)
        # Four single tensors followed by two per-anchor groups.
        eval_shape = [1] * 4 + [len(anchors)] * 2
        flat = tf_utils.reshape_list([image, glabels, gbboxes, bbox_img,
                                      glocalisations, gscores])
        batched = tf.train.batch(
            flat,
            batch_size=batch_size,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=50 * batch_size,
            dynamic_pad=True)
        (image, glabels, gbboxes, g_bbox_img,
         glocalisations, gscores) = tf_utils.reshape_list(batched, eval_shape)
        return image, glabels, gbboxes, g_bbox_img, glocalisations, gscores

    image, _labels, gbboxes, num = processed
    glocalisations, gscores = net.bboxes_encode(gbboxes, anchors, num)
    # One image tensor followed by two per-anchor groups.
    train_shape = [1] + [len(anchors)] * 2
    batched = tf.train.shuffle_batch(
        tf_utils.reshape_list([image, glocalisations, gscores]),
        batch_size=batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=100 * batch_size,
        min_after_dequeue=50 * batch_size)
    batch_image, batch_locs, batch_scores = tf_utils.reshape_list(
        batched, train_shape)
    return batch_image, batch_locs, batch_scores