forked from gorgonia/gorgonia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
example_broadcast_op_test.go
105 lines (85 loc) · 2.73 KB
/
example_broadcast_op_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package gorgonia
import (
"fmt"
"log"
// . "gorgonia.org/gorgonia"
"gorgonia.org/tensor"
)
// By default, Gorgonia operations do not perform broadcasting.
// To broadcast, you must explicitly call a broadcasting operation such as
// BroadcastAdd and tell it which axes of each operand to broadcast along.
func ExampleBroadcastAdd() {
	g := NewGraph()
	a := NewVector(g, tensor.Float64, WithShape(2), WithName("a"), WithValue(tensor.New(tensor.WithBacking([]float64{100, 100}))))
	b := NewMatrix(g, tensor.Float64, WithShape(2, 2), WithName("b"), WithValue(tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 1, 2, 2}))))

	fmt.Printf("a = %v\nb =\n%v\n", a.Value(), b.Value())

	// The plain Add is attempted only to demonstrate the shape-mismatch error,
	// so its result is deliberately discarded and the error is printed.
	_, err := Add(a, b)
	fmt.Printf("a + b yields an error: %v\n\n", err)

	// Note here the broadcasting of a is on the first axis, not the zeroth axis. Simply put, assume that it's already a (2,1) matrix.
	ab, err := BroadcastAdd(a, b, []byte{1}, nil)
	if err != nil {
		// Abort instead of continuing: the returned node must not be used
		// once an error has been reported.
		log.Fatalf("uh oh, something went wrong: %v\n", err)
	}

	// Same addition with the operands swapped, so the broadcast pattern moves
	// to the right-hand side.
	ba, err := BroadcastAdd(b, a, nil, []byte{1})
	if err != nil {
		log.Fatalf("uh oh, something went wrong: %v\n", err)
	}

	// Now, let's run the program
	machine := NewTapeMachine(g)
	defer machine.Close()
	if err = machine.RunAll(); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("a +⃗ b =\n%v\n", ab.Value())
	fmt.Printf("b +⃗ a =\n%v", ba.Value())

	// Output:
	// a = [100 100]
	// b =
	// ⎡1 1⎤
	// ⎣2 2⎦
	//
	// a + b yields an error: Failed to infer shape. Op: + false: Shape mismatch: (2) and (2, 2)
	//
	// a +⃗ b =
	// ⎡101 101⎤
	// ⎣102 102⎦
	//
	// b +⃗ a =
	// ⎡101 101⎤
	// ⎣102 102⎦
}
// ExampleBroadcastGte_creatingTriangleMatrices compares a (3, 1) matrix with a
// (1, 4) matrix using the broadcasting comparison operations BroadcastGte and
// BroadcastLt, and prints the two resulting matrices.
func ExampleBroadcastGte_creatingTriangleMatrices() {
	// Broadcasting makes it simple to build triangular dense matrices.
	graph := NewGraph()

	// Both operands are filled through RangedFrom(0); only their shapes differ.
	col := NewMatrix(graph, tensor.Float64, WithShape(3, 1), WithName("a"), WithInit(RangedFrom(0)))
	row := NewMatrix(graph, tensor.Float64, WithShape(1, 4), WithName("b"), WithInit(RangedFrom(0)))

	// a >= b gives the matrix printed as "lower" below.
	lower, err := BroadcastGte(col, row, true, []byte{1}, []byte{0})
	if err != nil {
		log.Fatalf("uh oh. Something went wrong %v", err)
	}

	// a < b gives the matrix printed as "upper" below.
	upper, err := BroadcastLt(col, row, true, []byte{1}, []byte{0})
	if err != nil {
		log.Fatalf("uh oh. Something went wrong %v", err)
	}

	vm := NewTapeMachine(graph)

	// PEDAGOGICAL:
	// Uncomment the following code if you want to see what happens behind the scenes
	// vm.Close()
	// logger := log.New(os.Stderr, "",0)
	// vm = NewTapeMachine(graph, WithLogger(logger), WithWatchlist())
	defer vm.Close()

	if runErr := vm.RunAll(); runErr != nil {
		log.Fatal(runErr)
	}

	fmt.Printf("triangular, lower:\n%v\n", lower.Value())
	fmt.Printf("triangular, upper:\n%v\n", upper.Value())

	// Output:
	// triangular, lower:
	// ⎡1 0 0 0⎤
	// ⎢1 1 0 0⎥
	// ⎣1 1 1 0⎦
	//
	// triangular, upper:
	// ⎡0 1 1 1⎤
	// ⎢0 0 1 1⎥
	// ⎣0 0 0 1⎦
}