-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathmnist.h
170 lines (135 loc) · 3.09 KB
/
mnist.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#ifndef __MNIST_H__
#define __MNIST_H__
/*
* MNIST loader by Nuri Park - https://github.com/projectgalateia/mnist
*/
#ifdef USE_MNIST_LOADER /* Fundamental macro to make the code active */
#ifdef __cplusplus
extern "C" {
#endif
/*
* Makes the mnist_load function static.
* Define MNIST_STATIC when the header is included multiple times.
*/
#ifdef MNIST_STATIC
#define _STATIC static
#else
#define _STATIC
#endif
/*
* Makes the MNIST loader store image data as double.
* Each unsigned char value is divided by 255.0, so the results range from 0.0 to 1.0.
*/
#ifdef MNIST_DOUBLE
#define MNIST_DATA_TYPE double
#else
#define MNIST_DATA_TYPE unsigned char
#endif
/*
 * One MNIST sample: the pixel grid of a single image plus its label.
 * The element type of the grid is MNIST_DATA_TYPE, which is double when
 * MNIST_DOUBLE is defined and unsigned char otherwise.
 */
typedef struct mnist_data {
MNIST_DATA_TYPE data[28][28]; /* 28x28 image data, filled in the order the values appear in the image file */
unsigned int label; /* label read from the label file: 0 to 9 */
} mnist_data;
/*
* If MNIST_HDR_ONLY is defined, only the function prototype is made visible.
*/
#ifdef MNIST_HDR_ONLY
/*
 * Prototype of the MNIST loader (the definition is compiled when
 * MNIST_HDR_ONLY is not defined).
 *
 * image_filename : path of the image file
 * label_filename : path of the label file
 * data           : on success, receives a pointer to a malloc()ed array of samples
 * count          : on success, receives the number of samples in that array
 *
 * Returns 0 on success and a negative code on failure.
 */
_STATIC int mnist_load(
const char *image_filename,
const char *label_filename,
mnist_data **data,
unsigned int *count);
#else
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Decode an unsigned int from four consecutive raw bytes.
 * The byte at v[0] is the most significant one (MSB first).
 */
static unsigned int mnist_bin_to_int(char *v)
{
	unsigned int value = 0;
	int pos = 0;

	while (pos < 4) {
		/* Make room for the next byte, then append it as an unsigned quantity. */
		value = (value << 8) | (unsigned char)v[pos];
		++pos;
	}

	return value;
}
/*
 * Read one 4-byte, MSB-first header field from fp into *out.
 * Returns 1 on success, 0 if fewer than 4 bytes could be read.
 */
static int mnist_read_u32(FILE *fp, unsigned int *out)
{
	char raw[4];

	if (fread(raw, 1, sizeof raw, fp) != sizeof raw)
		return 0;

	*out = mnist_bin_to_int(raw);
	return 1;
}

/*
 * MNIST dataset loader.
 *
 * Reads the image file and the label file and, on success, stores a
 * malloc()ed array of samples in *data and the number of samples in
 * *count. The caller owns the array and releases it with free().
 * On failure *data and *count are left untouched.
 *
 * Returns 0 on success, otherwise:
 *   -1  one of the files could not be opened
 *   -2  not a valid image file (wrong magic number, dimensions other
 *       than 28x28, or the file ends early)
 *   -3  not a valid label file (wrong magic number, or the file ends early)
 *   -4  the element counts of the two files mismatch
 *   -5  out of memory
 */
_STATIC int mnist_load(
	const char *image_filename,
	const char *label_filename,
	mnist_data **data,
	unsigned int *count)
{
	int return_code = 0;
	unsigned int i;
	unsigned int magic;
	unsigned int image_cnt, label_cnt;
	unsigned int image_dim[2];
	mnist_data *items = NULL;

	FILE *ifp = fopen(image_filename, "rb");
	FILE *lfp = fopen(label_filename, "rb");

	if (!ifp || !lfp) {
		return_code = -1; /* No such files */
		goto cleanup;
	}

	if (!mnist_read_u32(ifp, &magic) || magic != 2051) {
		return_code = -2; /* Not a valid image file */
		goto cleanup;
	}

	if (!mnist_read_u32(lfp, &magic) || magic != 2049) {
		return_code = -3; /* Not a valid label file */
		goto cleanup;
	}

	if (!mnist_read_u32(ifp, &image_cnt)) {
		return_code = -2; /* Image header ends early */
		goto cleanup;
	}

	if (!mnist_read_u32(lfp, &label_cnt)) {
		return_code = -3; /* Label header ends early */
		goto cleanup;
	}

	if (image_cnt != label_cnt) {
		return_code = -4; /* Element counts of 2 files mismatch */
		goto cleanup;
	}

	for (i = 0; i < 2; ++i) {
		if (!mnist_read_u32(ifp, &image_dim[i])) {
			return_code = -2; /* Image header ends early */
			goto cleanup;
		}
	}

	if (image_dim[0] != 28 || image_dim[1] != 28) {
		return_code = -2; /* Not a valid image file */
		goto cleanup;
	}

	/* Refuse counts whose byte size would wrap around size_t. */
	if (image_cnt > (size_t)-1 / sizeof(mnist_data)) {
		return_code = -5; /* Out of memory */
		goto cleanup;
	}

	/*
	 * malloc(0) is allowed to return NULL, which would be mistaken for
	 * an allocation failure, so always request room for one element.
	 * The cast keeps the header usable from C++.
	 */
	items = (mnist_data *)malloc(sizeof(mnist_data) * (image_cnt ? image_cnt : 1));
	if (!items) {
		return_code = -5; /* Out of memory */
		goto cleanup;
	}

	for (i = 0; i < image_cnt; ++i) {
		unsigned char read_data[28 * 28];
		unsigned char label;
		mnist_data *d = &items[i];
#ifdef MNIST_DOUBLE
		int j;
#endif

		if (fread(read_data, 1, sizeof read_data, ifp) != sizeof read_data) {
			return_code = -2; /* Image file ends early */
			goto cleanup;
		}

#ifdef MNIST_DOUBLE
		for (j = 0; j < 28 * 28; ++j) {
			d->data[j / 28][j % 28] = read_data[j] / 255.0;
		}
#else
		memcpy(d->data, read_data, sizeof read_data);
#endif

		/* Read the label as unsigned char so bytes >= 128 are not sign-extended. */
		if (fread(&label, 1, 1, lfp) != 1) {
			return_code = -3; /* Label file ends early */
			goto cleanup;
		}
		d->label = label;
	}

	/* Publish the results only once everything was read successfully. */
	*data = items;
	*count = image_cnt;
	items = NULL; /* Ownership moved to the caller */

cleanup:
	free(items); /* Non-NULL only when loading failed part-way */
	if (ifp) fclose(ifp);
	if (lfp) fclose(lfp);

	return return_code;
}
#endif /* MNIST_HDR_ONLY */
#ifdef __cplusplus
}
#endif
#endif /* USE_MNIST_LOADER */
#endif /* __MNIST_H__ */