forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoperatorLinAlg_const.go
82 lines (68 loc) · 1.79 KB
/
operatorLinAlg_const.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package gorgonia
import "github.com/chewxy/hm"
// Parallel dispatch tables for the binary linear-algebra operators.
//
// Each table is sized by maxĀBinaryOperator and is presumably indexed by the
// operator value, so slot i of every table describes the same operator. The
// slot order is therefore significant and must stay in sync across all four
// tables: matrix×matrix, matrix×vector, vector dot product, outer product.
//
// NOTE(review): if maxĀBinaryOperator is larger than the number of populated
// slots (for instance a batched-matmul operator whose "×××" string is
// commented out below), the trailing slots hold zero values ("" or a nil
// func) — confirm no caller indexes them.
var (
	// āBinOpStrs maps each binary linear-algebra operator to its display
	// string. It should be held constant.
	āBinOpStrs = [maxĀBinaryOperator]string{
		"×", // matrix × matrix
		"×", // matrix × vector
		"⋅", // vector dot product
		"⊗", // outer product
		// "×××",
	}

	// āBinOpDiffExprs holds, per operator, a function taking the transpose
	// flags (tA, tB), the inputs x and y, the output z and the incoming
	// gradient grad, and returning Nodes. Presumably these build the
	// symbolic gradient expressions for the inputs — confirm against callers.
	āBinOpDiffExprs = [maxĀBinaryOperator]func(tA, tB bool, x, y, z, grad *Node) (Nodes, error){
		matMulDiffExpr,
		matVecMulDiffExpr,
		vecDotDiffExpr,
		outerProdDiffExpr,
	}

	// āBinOpDiffs holds, per operator, a function taking the transpose flags
	// and the x, y, z nodes and returning only an error. Presumably these
	// compute gradients in place rather than building new nodes — confirm.
	āBinOpDiffs = [maxĀBinaryOperator]func(tA, tB bool, x, y, z *Node) error{
		matMulDiff,
		matVecMulDiff,
		vecDotDiff,
		outerProdDiff,
	}

	// āBinOpTypes holds, per operator, a constructor for that operator's
	// Hindley–Milner function type.
	āBinOpTypes = [maxĀBinaryOperator]func() hm.Type{
		matMulType,
		matVecMulType,
		vecDotType,
		outerProdType,
	}
)
/* TYPES FOR LINALG BINARY OPS */
// matVecMulType builds the type signature of matrix-vector multiplication:
//
//	matVecMulOp :: (Float a) ⇒ Matrix a → Vector a → Vector a
//
// The matrix is the first operand and the vector the second, matching the
// argument order passed to hm.NewFnType below. The result is a vector with
// the same element type a.
//
// For the moment only floats are allowed. NOTE(review): the type variable a
// carries no constraint in the returned type, so the float-only restriction
// is presumably enforced elsewhere — confirm.
func matVecMulType() hm.Type {
	a := hm.TypeVariable('a')
	v := newTensorType(1, a)
	m := newTensorType(2, a)
	// Operand order: matrix first, then vector; the last argument is the result.
	return hm.NewFnType(m, v, v)
}
// matMulType builds the type signature of matrix-matrix multiplication:
//
//	matMulOp :: (Float a) ⇒ Matrix a → Matrix a → Matrix a
//
// Both operands and the result share one matrix type value over a single
// element type variable. For the moment only floats are allowed; the
// returned type itself leaves the element type unconstrained, so the float
// restriction is presumably enforced elsewhere.
func matMulType() hm.Type {
	var (
		elem   = hm.TypeVariable('a')
		matrix = newTensorType(2, elem)
	)
	return hm.NewFnType(matrix, matrix, matrix)
}
// vecDotType builds the type signature of the vector dot product:
//
//	vecDotOp :: (Float a) ⇒ Vector a → Vector a → a
//
// Two vectors sharing an element type reduce to a bare value of that element
// type. For the moment only floats are allowed; the returned type itself
// leaves the element type unconstrained.
func vecDotType() hm.Type {
	var (
		elem   = hm.TypeVariable('a')
		vector = newTensorType(1, elem)
	)
	return hm.NewFnType(vector, vector, elem)
}
// outerProdType builds the type signature of the vector outer product:
//
//	outerProdOp :: (Float a) ⇒ Vector a → Vector a → Matrix a
//
// Two vectors sharing an element type produce a matrix of that element type.
// For the moment only floats are allowed; the returned type itself leaves the
// element type unconstrained.
func outerProdType() hm.Type {
	elem := hm.TypeVariable('a')
	var (
		vector = newTensorType(1, elem)
		matrix = newTensorType(2, elem)
	)
	return hm.NewFnType(vector, vector, matrix)
}