-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathim2col_cython.pyx
89 lines (70 loc) · 3.59 KB
/
im2col_cython.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Fei-Fei Li, Andrej Karpathy, Justin Johnson (2015)
import numpy as np
cimport numpy as np
cimport cython
# Element type accepted by every routine in this module. A fused type makes
# Cython emit one specialization per listed type, so float32 and float64
# arrays are both handled natively without a dtype conversion.
ctypedef fused DTYPE_t:
    np.float32_t
    np.float64_t
def im2col_cython(np.ndarray[DTYPE_t, ndim=4] x, int field_height,
                  int field_width, int padding, int stride):
    """Unfold the sliding receptive fields of ``x`` into a 2-D column matrix.

    Parameters
    ----------
    x : ndarray, 4-D, float32 or float64
        Input batch indexed as (N, C, H, W).
    field_height, field_width : int
        Receptive-field (kernel) size.
    padding : int
        Number of zeros added on each side of the two trailing axes.
    stride : int
        Step between neighbouring receptive fields.

    Returns
    -------
    ndarray, 2-D, same dtype as ``x``
        Shape (C * field_height * field_width, N * HH * WW), where
        HH = (H + 2 * padding - field_height) // stride + 1 and
        WW = (W + 2 * padding - field_width) // stride + 1.
        Each column holds one receptive field for one batch item.
    """
    cdef int batch = x.shape[0]
    cdef int channels = x.shape[1]
    cdef int height = x.shape[2]
    cdef int width = x.shape[3]
    # Number of receptive-field positions along each spatial axis.
    cdef int out_h = (height + 2 * padding - field_height) // stride + 1
    cdef int out_w = (width + 2 * padding - field_width) // stride + 1

    # Only the spatial axes are zero-padded; batch and channel axes are not.
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.pad(x, pad_spec, mode='constant')

    cdef int n_rows = channels * field_height * field_width
    cdef int n_cols = batch * out_h * out_w
    cdef np.ndarray[DTYPE_t, ndim=2] cols = np.zeros((n_rows, n_cols), dtype=x.dtype)

    # The copy loop lives in a cdef helper without bounds checking; the
    # original authors observed no measurable speed-up from this for im2col.
    im2col_cython_inner(cols, x_padded, batch, channels, height, width,
                        out_h, out_w, field_height, field_width,
                        padding, stride)
    return cols
# Copy every receptive field of x_padded into cols.
#
# Layout written:
#   row = c * field_height * field_width + ii * field_width + jj
#   col = yy * WW * N + xx * N + i
#   cols[row, col] = x_padded[i, c, stride * yy + ii, stride * xx + jj]
#
# Bounds checking is disabled. The caller must pass cols of shape
# (C * field_height * field_width, N * HH * WW) and a padded input large
# enough for every (yy, xx) window.
# Returns 0; declared `except? -1` so a Python exception raised inside can
# propagate to the caller.
@cython.boundscheck(False)
cdef int im2col_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
                             np.ndarray[DTYPE_t, ndim=4] x_padded,
                             int N, int C, int H, int W, int HH, int WW,
                             int field_height, int field_width, int padding,
                             int stride) except? -1:
    cdef int c, ii, jj, row, yy, xx, i, col
    for c in range(C):
        for yy in range(HH):
            for xx in range(WW):
                for ii in range(field_height):
                    for jj in range(field_width):
                        # Row-major over (channel, kernel row, kernel column):
                        # each kernel row spans field_width entries. Scaling ii
                        # by field_height instead made rows collide for
                        # non-square kernels and, for tall kernels, indexed past
                        # the end of cols with bounds checks off.
                        row = c * field_height * field_width + ii * field_width + jj
                        for i in range(N):
                            col = yy * WW * N + xx * N + i
                            cols[row, col] = x_padded[i, c, stride * yy + ii, stride * xx + jj]
    return 0
def col2im_cython(np.ndarray[DTYPE_t, ndim=2] cols, int N, int C, int H, int W,
                  int field_height, int field_width, int padding, int stride):
    """Fold a column matrix back into an (N, C, H, W) array, summing overlaps.

    Every entry of ``cols`` is added into the input position it corresponds
    to. Positions covered by several receptive fields therefore accumulate
    the sum of all their contributions.

    Parameters
    ----------
    cols : ndarray, 2-D, float32 or float64
        Must have shape (C * field_height * field_width, N * HH * WW), where
        HH = (H + 2 * padding - field_height) // stride + 1 and
        WW = (W + 2 * padding - field_width) // stride + 1.
    N, C, H, W : int
        Shape of the (unpadded) output.
    field_height, field_width, padding, stride : int
        Geometry that produced ``cols``.

    Returns
    -------
    ndarray, shape (N, C, H, W), same dtype as ``cols``
        When ``padding > 0`` this is a view into a larger padded buffer.

    Raises
    ------
    ValueError
        If ``cols`` does not have the shape implied by the geometry, checked
        whenever at least one receptive-field position exists.
    """
    cdef int HH = (H + 2 * padding - field_height) // stride + 1
    cdef int WW = (W + 2 * padding - field_width) // stride + 1
    # The inner loop reads cols without bounds checking, so a mis-shaped cols
    # would be read out of bounds. Reject it up front. When HH or WW <= 0 the
    # loop never touches cols, and that degenerate case is left as before.
    if HH > 0 and WW > 0 and (cols.shape[0] != C * field_height * field_width
                              or cols.shape[1] != N * HH * WW):
        raise ValueError(
            "cols has shape (%d, %d); expected (%d, %d) for this geometry"
            % (cols.shape[0], cols.shape[1],
               C * field_height * field_width, N * HH * WW))
    # Accumulate into a zero-initialised padded buffer. The previously
    # allocated (and never used) np.empty((N, C, H, W)) has been removed.
    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros(
        (N, C, H + 2 * padding, W + 2 * padding), dtype=cols.dtype)
    # Moving the inner loop to a C function with no bounds checking improves
    # performance quite a bit for col2im.
    col2im_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
                        field_height, field_width, padding, stride)
    if padding > 0:
        # Strip the padding border; the result is a view of x_padded.
        return x_padded[:, :, padding:-padding, padding:-padding]
    return x_padded
# Scatter-add every column of cols back into x_padded.
#
# Layout read:
#   row = c * field_height * field_width + ii * field_width + jj
#   col = yy * WW * N + xx * N + i
#   x_padded[i, c, stride * yy + ii, stride * xx + jj] += cols[row, col]
#
# Bounds checking is disabled. The caller must pass cols of shape
# (C * field_height * field_width, N * HH * WW). x_padded is accumulated
# into, so it should start zeroed.
# Returns 0; declared `except? -1` so a Python exception raised inside can
# propagate to the caller.
@cython.boundscheck(False)
cdef int col2im_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
                             np.ndarray[DTYPE_t, ndim=4] x_padded,
                             int N, int C, int H, int W, int HH, int WW,
                             int field_height, int field_width, int padding,
                             int stride) except? -1:
    cdef int c, ii, jj, row, yy, xx, i, col
    for c in range(C):
        for ii in range(field_height):
            for jj in range(field_width):
                # Row-major over (channel, kernel row, kernel column): each
                # kernel row spans field_width entries. Scaling ii by
                # field_height instead read the wrong rows for non-square
                # kernels and, for tall kernels, read past the end of cols.
                row = c * field_height * field_width + ii * field_width + jj
                for yy in range(HH):
                    for xx in range(WW):
                        for i in range(N):
                            col = yy * WW * N + xx * N + i
                            x_padded[i, c, stride * yy + ii, stride * xx + jj] += cols[row, col]
    return 0