-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtool_show_size.py
86 lines (51 loc) · 1.82 KB
/
tool_show_size.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import tensorflow as tf
import matplotlib.pyplot as plt
import tools
#%%
# Load the test image as a NumPy array (uint8 for a JPEG).
# Forward slashes throughout: the previous literal contained "\D", an invalid
# escape sequence that Python 3.12+ reports with a SyntaxWarning.
cat = plt.imread('C:/Users/fcyoung/Desktop/test.jpg')  # uint8
plt.imshow(cat)
# The first conv kernel below expects 3 input channels, so the image is
# assumed to be [H, W, 3] (RGB) -- a grayscale/RGBA file would not fit.
cat = tf.cast(cat, tf.float32)  # [H, W, 3]
# Prepend a batch axis instead of hard-coding the image size: the former
# tf.reshape(cat, [1, 1080, 1920, 3]) only worked for 1080x1920 images.
# The static shape stays fully defined, which the shape report relies on.
x = tf.expand_dims(cat, 0)  # [1, H, W, 3]
#%%
def _conv_relu_pool(inputs, kernel_shape, n_filters, pool_name, use_max_pool):
    """Build one conv -> bias -> ReLU -> 2x2 pool stage and return every step.

    Must be called inside the desired ``tf.variable_scope``; the kernel and
    bias come from ``tools.weight`` / ``tools.bias`` (presumably scoped
    variables -- confirm in tools.py).

    Returns:
        (kernel, conv, bias, biased, activated, pooled) so the caller can
        inspect the shape of each intermediate tensor.
    """
    # kernel_shape = [filter_height, filter_width, in_channels, out_channels]
    kernel = tools.weight(kernel_shape, is_uniform=True)
    # strides = [batch, height, width, channels]; 'SAME' zero-pads so a
    # stride-1 convolution keeps the spatial size.
    conv = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME')
    bias = tools.bias([n_filters])
    biased = tf.nn.bias_add(conv, bias)
    activated = tf.nn.relu(biased)
    # 2x2 window, stride 2; max or average pooling depending on use_max_pool.
    pooled = tools.pool(pool_name, activated,
                        kernel=[1, 2, 2, 1], stride=[1, 2, 2, 1],
                        is_max_pool=use_max_pool)
    return kernel, conv, bias, biased, activated, pooled


# First conv: 3 -> 16 channels, max pooling.
with tf.variable_scope('conv1'):
    w, x_w, b, x_b, x_relu, x_pool = _conv_relu_pool(
        x, [3, 3, 3, 16], 16, 'test1', use_max_pool=True)

# Second conv: 16 -> 32 channels, average pooling, then batch normalization.
with tf.variable_scope('conv2'):
    w2, x_w2, b2, x_b2, x_relu2, x_pool2 = _conv_relu_pool(
        x_pool, [3, 3, 16, 32], 32, 'test2', use_max_pool=False)
    x_BN = tools.batch_norm(x_pool2)
#%%
def shape(x):
    """Return the static shape of ``x`` rendered as a string.

    ``x`` is anything exposing ``get_shape()`` (e.g. a tf.Tensor or
    tf.Variable); the result is whatever ``str()`` gives for that shape.
    """
    static_shape = x.get_shape()
    return str(static_shape)
# Report the static shape of every intermediate tensor, stage by stage.

# First conv
print('\n')
print('** First conv: **\n')
for _label, _tensor in (
        ('input size: ', x),
        ('w size:', w),
        ('x_w size: ', x_w),
        ('b size: ', b),
        ('x_b size: ', x_b),
        ('x_relu size: ', x_relu),
        ('x_pool size: ', x_pool),
):
    print(_label, shape(_tensor))
print('\n')

# Second conv
print('** Second conv: **\n')
for _label, _tensor in (
        ('input size: ', x_pool),
        ('w2 size:', w2),
        ('x_w2 size: ', x_w2),
        ('b2 size: ', b2),
        ('x_b2 size: ', x_b2),
        ('x_relu2 size: ', x_relu2),
        ('x_pool2 size: ', x_pool2),
        ('x_BN size: ', x_BN),
):
    print(_label, shape(_tensor))
print('\n')