-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathwymlp1.hpp
79 lines (76 loc) · 3.81 KB
/
wymlp1.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include "wyhash.h"
#include <math.h>
// Softsign activation x/(1+|x|): smooth, odd, and bounded by 1 in magnitude.
static inline float wymlp_act(float x){
	const float denom = x > 0 ? 1 + x : 1 - x;	// 1+|x| without calling fabsf
	return x / denom;
}
// Softsign derivative written in terms of the activation OUTPUT y: (1-|y|)^2.
// Callers pass the already-activated value, not the pre-activation.
static inline float wymlp_gra(float x){
	const float t = x > 0 ? 1 - x : 1 + x;	// 1-|x|
	return t * t;
}
/*
 * wymlp: run one sample through a small fully-connected network stored in a flat
 * weight array and, unless predicting, apply one SGD step to those weights in place.
 *
 * Template parameters:
 *   type   scalar type of weights, inputs and outputs (e.g. float).
 *   input  number of input features; a constant-1 bias input is appended internally.
 *   hidden width of every hidden layer; unit 0 of each hidden layer is pinned to 1
 *          and serves as the bias of the following layer.
 *   depth  number of hidden layers (>= 1).
 *   output number of network outputs.
 *   task   0: regression (linear outputs), 1: logistic (sigmoid outputs),
 *          2: softmax (the target class index is read from y[0] when training).
 *
 * Parameters:
 *   weight  flat weight array; pass NULL to get the required number of weights back.
 *   x       the `input` feature values.
 *   y       eta<0: receives the `output` predictions; otherwise holds the training target.
 *   eta     learning rate; eta<0 means predict only, the weights are not touched.
 *   seed    seed for the per-input dropout mask (input i is hashed with wyhash64(i,seed)).
 *   dropout while training, input i (the bias input included) is dropped when
 *           wyhash64(i,seed) < dropout*2^64; at prediction every input is kept and the
 *           input layer is scaled by (1-dropout). Values are clamped to [0,1]; NaN acts as 0.
 *
 * Returns the weight count when weight==NULL, otherwise 0.
 * Define WYMLP_RNN to share a single hidden-to-hidden matrix across all hidden layers.
 */
template<class type, unsigned input, unsigned hidden, unsigned depth, unsigned output, unsigned task>
double wymlp(type *weight, type *x, type *y, type eta, uint64_t seed, double dropout) {
	// depth==0 would overflow the activation buffer below; hidden==0 divides by zero;
	// any other task would silently train on raw outputs.
	static_assert(depth>=1, "wymlp: depth must be at least 1");
	static_assert(hidden>=1, "wymlp: hidden must be at least 1");
	static_assert(task<=2, "wymlp: task must be 0 (regression), 1 (logistic) or 2 (softmax)");
	// woff(i,0): offset of the row of weights leaving input i (i==input is the bias input).
	// woff(i,l), l>=1: offset of the row of weights entering unit i of layer l (l==depth is the output layer).
#ifdef WYMLP_RNN
	if(weight==NULL) return (input+1)*hidden+hidden*hidden+output*hidden;
#define woff(i,l) ((l)?((l)<depth?(input+1)*hidden+(i)*hidden:(input+1)*hidden+hidden*hidden+(i)*hidden):(i)*hidden)
#else
	if(weight==NULL) return (input+1)*hidden+(depth-1)*hidden*hidden+output*hidden;
#define woff(i,l) ((l)?(input+1)*hidden+((l)-1)*hidden*hidden+(i)*hidden:(i)*hidden)
#endif
	// Clamp dropout: converting an out-of-range or NaN double to uint64_t below is undefined behavior.
	if(!(dropout>0)) dropout=0;
	else if(dropout>1) dropout=1;
	// a: hidden activations (depth*hidden), d: their gradients (depth*hidden), o: outputs (output).
	type a[2*depth*hidden+output]={}, *d=a+depth*hidden, *o=a+2*depth*hidden;
	// 1/sqrt(fan-in) scaling of pre-activations; at prediction wi also applies the (1-dropout) compensation.
	type wh=1/sqrtf(hidden), wi=(1-(eta<0)*dropout)/sqrtf(input+1);
	// Input i is kept when wyhash64(i,seed)>=drop; dropout==1 uses the maximal threshold
	// because 1*2^64 does not fit in uint64_t.
	uint64_t drop=dropout<1?static_cast<uint64_t>(dropout*~0ull):~0ull;
	// Input layer: accumulate every kept input times its outgoing weight row.
	for(unsigned i=0; i<=input; i++) {
		type *w=weight+woff(i,0), s=(i==input?1:x[i])*(eta<0||wyhash64(i,seed)>=drop);
		for(unsigned j=0; j<hidden; j++) a[j]+=s*w[j];
	}
	// Activate the first hidden layer; unit 0 is pinned to 1 as the bias of the next layer.
	for(unsigned i=0; i<hidden; i++) a[i]=i?wymlp_act(wi*a[i]):1;
	// Remaining layers: hidden->hidden (activated, unit 0 pinned) for l<depth, hidden->output (linear) for l==depth.
	for(unsigned l=1; l<=depth; l++) {
		type *p=a+(l-1)*hidden, *q=(l==depth?o:a+l*hidden);
		for(unsigned i=0; i<(l==depth?output:hidden); i++) {
			type *w=weight+woff(i,l), s=0;
			for(unsigned j=0; j<hidden; j++) s+=w[j]*p[j];
			q[i]=(l==depth?s*wh:(i?wymlp_act(s*wh):1));
		}
	}
	// Output transform: on prediction write y; on training overwrite o with eta*(prediction-target).
	switch(task) {
	case 0: { for(unsigned i=0; i<output; i++) if(eta<0) y[i]=o[i]; else o[i]=(o[i]-y[i])*eta; } break;
	case 1: { for(unsigned i=0; i<output; i++) if(eta<0) y[i]=1/(1+expf(-o[i])); else o[i]=(1/(1+expf(-o[i]))-y[i])*eta; } break;
	case 2: { type m=o[0], s=0;
		for(unsigned i=1; i<output; i++) if(o[i]>m) m=o[i];	// subtract the max before expf to avoid overflow
		for(unsigned i=0; i<output; i++) s+=(o[i]=expf(o[i]-m));
		for(unsigned i=0; i<output; i++) if(eta<0) y[i]=o[i]/s; else o[i]=(o[i]/s-(i==(unsigned)y[0]))*eta;
	} break;
	}
	if(eta<0) return 0;
	// Backward pass, output layer first: accumulate gradients for the layer below with the
	// current weights, then update each weight row right after using it.
	for(unsigned l=depth; l; l--) {
		type *p=a+(l-1)*hidden, *q=(l==depth?o:a+l*hidden), *g=d+(l-1)*hidden, *h=(l==depth?o:d+l*hidden);
		for(unsigned i=0; i<(l==depth?output:hidden); i++) {
			// wymlp_gra(1)==0, so the pinned bias unit 0 contributes no gradient.
			type *w=weight+woff(i,l), s=(l==depth?q[i]:h[i]*wymlp_gra(q[i]))*wh;
			for(unsigned j=0; j<hidden; j++) { g[j]+=s*w[j]; w[j]-=s*p[j]; }
		}
	}
	// Back through the first hidden layer's activation and its wi scaling.
	for(unsigned i=0; i<hidden; i++) d[i]*=wymlp_gra(a[i])*wi;
	// Update the input-layer weights using the same dropout mask as the forward pass.
	for(unsigned i=0; i<=input; i++) {
		type *w=weight+woff(i,0), s=(i==input?1:x[i])*(eta<0||wyhash64(i,seed)>=drop);
		for(unsigned j=0; j<hidden; j++) w[j]-=s*d[j];
	}
	return 0;
}
#undef woff
/*
Author: Wang Yi <[email protected]>
Example:
int main(void){
float x[12]={1,2,3,5}, y[1]={2}; //x must hold `input` (here 12) values
vector<float> weight(wymlp<float,12,32,4,1,0>(NULL,NULL,NULL,0,0,-1)); //pass weight=NULL to get the number of weights
for(size_t i=0; i<weight.size(); i++) weight[i]=3.0*rand()/RAND_MAX-1.5;
for(unsigned i=0; i<1000000; i++){
x[0]+=0.01; y[0]+=0.1; //some "new" data
wymlp<float,12,32,4,1,0>(weight.data(), x, y, 0.1, wygrand(), 0.5); // training: eta>=0 updates the weights
wymlp<float,12,32,4,1,0>(weight.data(), x, y, -1, wygrand(), 0.5); // prediction: eta<0 only writes the outputs to y
}
return 0;
}
Comments:
0: task=0: regression; task=1: logistic; task=2: softmax
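1: weight==NULL makes the call return the number of weights, i.e. the required size of the weight array.
2: eta<0 means prediction only: the outputs are written to y and the weights are not updated.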
3: The expected |X[i]|, |Y[i]| should be around 1. Normalize your input and output first.
4: In practice, it is OK to call the model function from multiple threads in parallel; however, this may be slower for small nets.
5: The code is portable; however, if -Ofast is used on x86, the compiler can use SSE, AVX or even AVX-512 to produce very fast code!
6: The suggested model shares one hidden-hidden weight matrix across all hidden layers; enable it by defining WYMLP_RNN before including this header. Without WYMLP_RNN the vanilla MLP layout is used:
if(weight==NULL) return (input+1)*hidden+(depth-1)*hidden*hidden+output*hidden;
#define woff(i,l) (l?(input+1)*hidden+(l-1)*hidden*hidden+i*hidden:i*hidden)
*/