-
Notifications
You must be signed in to change notification settings - Fork 1
/
create_rotated_MNIST_dataset_42.py
94 lines (72 loc) · 3.34 KB
/
create_rotated_MNIST_dataset_42.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import cv2
import os
def load_data(trainingData, trainingLabel, testingData, testingLabel, dataset = "MNIST"):
    """Load train/test images and labels from ``.npy`` files.

    Each path argument is appended by plain string concatenation to the
    directory held in the environment variable named ``dataset``. Callers
    therefore pass paths with a leading separator, e.g. ``"/X_train.npy"``.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``. Image arrays are float32 and
        reshaped to ``(-1, 1, 28, 28)``; label arrays are uint8.

    Raises:
        KeyError: if the environment variable ``dataset`` is not set.
    """
    root = os.environ[dataset]

    def _images(suffix):
        # Whatever the stored layout, present images as N x 1 x 28 x 28.
        return np.array(np.load(root + suffix), dtype = np.float32).reshape(-1, 1, 28, 28)

    def _labels(suffix):
        return np.array(np.load(root + suffix), dtype = np.uint8)

    # Tuple elements are evaluated left to right, so the files load in the
    # same order as before: train images, train labels, test images, test labels.
    return (
        _images(trainingData),
        _labels(trainingLabel),
        _images(testingData),
        _labels(testingLabel),
    )
def rotateImage(image, angle):
    """Rotate a single-channel image about its centre by ``angle`` degrees.

    Args:
        image: 2-D array ``(rows, cols)``, or a 3-D array whose leading axis
            is dropped by taking ``image[0]`` (e.g. ``(1, rows, cols)``).
        angle: rotation in degrees, passed to ``cv2.getRotationMatrix2D``.
            Per the OpenCV docs, positive values rotate counter-clockwise.

    Returns:
        float32 array of shape ``(1, rows, cols)``. Pixels that fall outside
        the frame after rotation are dropped, and uncovered areas take
        OpenCV's default border value.
    """
    if image.ndim == 3:
        image = image[0]
    rows, cols = image.shape[:2]
    # OpenCV takes the centre as (x, y) == (col, row) and dsize as
    # (width, height). image.shape is (rows, cols), so both must be swapped;
    # the old code only worked because the images here are square.
    center = (cols / 2, rows / 2)
    # float() guards against numpy integer scalars (e.g. from np.random.randint).
    rot_mat = cv2.getRotationMatrix2D(center, float(angle), 1.0)
    result = cv2.warpAffine(image, rot_mat, (cols, rows), flags=cv2.INTER_LINEAR)
    return np.array(result[np.newaxis, :, :], dtype = np.float32)
def extend_image(inputs, size = 40):
    """Zero-pad a batch of single-channel images onto a centred square canvas.

    Args:
        inputs: ``(N, H, W)`` or ``(N, 1, H, W)`` array. A 3-D input gets a
            channel axis inserted at position 1.
        size: side length of the output canvas; must be >= H and >= W.

    Returns:
        float32 array of shape ``(N, 1, size, size)`` with ``inputs`` centred.
        When the padding is odd, the extra pixel goes on the bottom/right.

    Raises:
        ValueError: if ``size`` is smaller than either image dimension.
    """
    if inputs.ndim == 3:
        inputs = inputs[:, np.newaxis, :, :]
    height, width = inputs.shape[2], inputs.shape[3]
    if size < height or size < width:
        raise ValueError(
            "size=%d is smaller than the input images (%d x %d)" % (size, height, width)
        )
    # Integer division: slice bounds must be ints. In Python 3, `/` yields a
    # float and the slice assignment raises TypeError.
    # Rows and columns get separate offsets so non-square inputs are centred too.
    top = (size - height) // 2
    left = (size - width) // 2
    extended_images = np.zeros((inputs.shape[0], 1, size, size), dtype = np.float32)
    extended_images[:, :, top:top + height, left:left + width] = inputs
    return extended_images
# --- Build the rotated-MNIST (42 x 42) dataset --------------------------------
# NOTE(review): this runs at module level (import time), as in the original script.
PAD_SIZE = 40    # canvas the digits are rotated on before cropping
DIGIT_SIZE = 28  # side length of the central window kept after rotation
OUT_SIZE = 42    # side length of the images written to disk


def _random_angles(n):
    """Return ``n`` shuffled integer angles: ``n // 2`` from [-90, 0), the rest from [0, 90)."""
    negative = np.random.randint(low = -90, high = 0, size = n // 2)
    # n - n // 2 (not n // 2) so an odd n still yields exactly one angle per image.
    positive = np.random.randint(low = 0, high = 90, size = n - n // 2)
    angles = np.concatenate([negative, positive])
    np.random.shuffle(angles)
    return angles


def _rotate_split(images, labels):
    """Randomly rotate, shuffle, centre-crop and re-pad one data split.

    ``images`` is the ``(N, 1, PAD_SIZE, PAD_SIZE)`` output of
    ``extend_image``; ``labels`` holds one label per image. Returns
    ``(x, y)`` with ``x`` of shape ``(N, 1, OUT_SIZE, OUT_SIZE)`` and ``y``
    permuted consistently with ``x``.
    """
    n = labels.shape[0]
    angles = _random_angles(n)
    rotated = np.array([rotateImage(images[i], angles[i]) for i in range(n)], dtype = np.float32)
    print(rotated.shape, labels.shape)
    order = np.arange(n)
    np.random.shuffle(order)
    # Keep the central DIGIT_SIZE x DIGIT_SIZE window (6:34 for a 40 px canvas);
    # rotated content outside it is discarded.
    start = (PAD_SIZE - DIGIT_SIZE) // 2
    cropped = rotated[order, 0, start:start + DIGIT_SIZE, start:start + DIGIT_SIZE]
    return extend_image(cropped, OUT_SIZE), labels[order]


X_train, y_train, X_test, y_test = load_data("/X_train.npy", "/Y_train.npy", "/X_test.npy", "/Y_test.npy")
X_test = extend_image(X_test, PAD_SIZE)
X_train = extend_image(X_train, PAD_SIZE)
# Train is processed before test so the global RNG is consumed in the same order as before.
x_train, y_train = _rotate_split(X_train, y_train)
x_test, y_test = _rotate_split(X_test, y_test)
np.savez("/phddata/jiajun/Research/mnist/rotated_mnist_42.npz", x_train = x_train, y_train = y_train, x_test = x_test, y_test=y_test)