forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnn.go
129 lines (105 loc) · 3.06 KB
/
nn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package gorgonia
import (
"github.com/chewxy/gorgonia/tensor"
"github.com/pkg/errors"
)
// BinaryXent builds the graph nodes for an element-wise binary cross-entropy
// between output and target. The expression constructed is:
//	-(target * log(output) + (1 - target) * log(1 - output))
// No reduction (sum/mean) is applied by this function.
//
// The constant 1 is chosen to match the dtype of output; only Float64 and
// Float32 are handled, and any other dtype yields an error. An error from any
// intermediate operation is wrapped and returned with a nil node.
func BinaryXent(output, target *Node) (retVal *Node, err error) {
	// Determine the dtype of output so the matching constant one can be picked.
	var dt tensor.Dtype
	dt, err = dtypeOf(output.t)
	if err != nil {
		return nil, errors.Wrapf(err, dtypeExtractionFail, output.t)
	}

	var one *Node
	switch dt {
	case Float32:
		one = onef32
	case Float64:
		one = onef64
	default:
		return nil, errors.Errorf(nyiFail, "BinaryXEnt", dt)
	}

	// log(output)
	logOut, err := Log(output)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// 1 - target
	oneMinusTarget, err := Sub(one, target)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// 1 - output
	oneMinusOutput, err := Sub(one, output)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// target * log(output)
	positiveTerm, err := HadamardProd(target, logOut)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// log(1 - output)
	logOneMinusOutput, err := Log(oneMinusOutput)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// (1 - target) * log(1 - output)
	negativeTerm, err := HadamardProd(oneMinusTarget, logOneMinusOutput)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	total, err := Add(positiveTerm, negativeTerm)
	if err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	// The final negation's error (if any) is passed through unwrapped.
	return Neg(total)
}
// Dropout is a convenience function to implement dropout on x.
//
// When prob is exactly 0, x is returned unchanged. Otherwise a random node with
// the same shape as x is created via UniformRandomNode (called with bounds 0
// and 1), compared against prob with Gt to form a mask, multiplied
// element-wise with x, and the product is divided by the constant 1/prob.
//
// Only Float64 and Float32 inputs are handled; other dtypes yield an error.
//
// NOTE(review): the mask keeps entries whose random draw is greater than prob,
// yet the result is divided by 1/prob rather than by the keep probability.
// Confirm this scaling is intended.
func Dropout(x *Node, prob float64) (retVal *Node, err error) {
	// Nothing to drop: hand the input back untouched.
	if prob == 0.0 {
		return x, nil
	}

	var dt tensor.Dtype
	dt, err = dtypeOf(x.t)
	if err != nil {
		return nil, errors.Wrap(err, dtypeOfFail)
	}

	// invProb holds 1/prob and threshold holds prob, both as scalars of x's dtype.
	var invProb, threshold Value
	switch dt {
	case Float32:
		invProb, _ = anyToScalar(float32(1.0 / prob))
		threshold, _ = anyToScalar(float32(prob))
	case Float64:
		invProb, _ = anyToScalar(1.0 / prob)
		threshold, _ = anyToScalar(prob)
	default:
		return nil, errors.Errorf(nyiTypeFail, "Dropout()", dt)
	}

	thresholdNode := NewConstant(threshold)
	scaleNode := NewConstant(invProb)
	noise := UniformRandomNode(x.g, dt, 0, 1, x.shape...)

	// The third argument is presumably "retSame" (mask in the operands' dtype) - TODO confirm.
	mask, err := Gt(noise, thresholdNode, true)
	if err != nil {
		return nil, errors.Wrap(err, "Greater Than failed")
	}

	masked, err := HadamardProd(x, mask)
	if err != nil {
		return nil, errors.Wrap(err, mulFail)
	}

	return HadamardDiv(masked, scaleNode)
}
// Rectify is a convenience function for creating rectified linear unit (ReLU)
// activation functions. It builds a >= comparison of x against a zero constant
// of matching dtype and multiplies x element-wise by the comparison result.
// The >= form is the canonical version; to use > instead, write your own
// following the same pattern.
//
// Only Float64 and Float32 inputs are handled; other dtypes yield an error.
func Rectify(x *Node) (retVal *Node, err error) {
	// Find out which zero constant matches x.
	var dt tensor.Dtype
	dt, err = dtypeOf(x.t)
	if err != nil {
		return nil, errors.Wrap(err, dtypeOfFail)
	}

	var zero *Node
	switch dt {
	case Float32:
		zero = zerof32
	case Float64:
		zero = zerof64
	default:
		return nil, errors.Errorf(nyiFail, "ReLu", dt)
	}

	gte := newElemBinOp(gteOpType, x, zero)
	// Presumably makes the comparison yield values in x's dtype so the result
	// can be used as a multiplicative mask - TODO confirm.
	gte.retSame = true

	// NOTE(review): the op is constructed from (x, zero) but applied with x as
	// its only operand - confirm that zero does not also need to be passed here.
	mask, err := applyOp(gte, x)
	if err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}

	return HadamardProd(x, mask)
}