-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvec_test.go
113 lines (87 loc) · 2.54 KB
/
vec_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package seafan
import (
"testing"
"github.com/stretchr/testify/assert"
G "gorgonia.org/gorgonia"
)
// getData builds the fixture shared by the tests in this file: a 7-row GData
// holding x1 (added with AppendC), x2 and x3 (added with AppendD), and x2Oh,
// produced from x2 by MakeOneHot. Each step depends on the previous one, so
// the first setup error aborts the calling test instead of returning a
// partially built fixture.
func getData(t *testing.T) *GData {
	t.Helper()

	x1 := []float64{1, 2, 3, 4, 8, 9, 10}
	x2 := []string{"a", "b", "c", "a", "a", "a", "a"}
	x3 := []int32{4, 5, 6, 1, 2, 2, 2}

	gData := NewGData()
	if e := gData.AppendC(NewRawCast(x1, nil), "x1", false, nil, true); e != nil {
		t.Fatalf("appending x1: %v", e)
	}
	if e := gData.AppendD(NewRawCast(x2, nil), "x2", nil, true); e != nil {
		t.Fatalf("appending x2: %v", e)
	}
	if e := gData.AppendD(NewRawCast(x3, nil), "x3", nil, true); e != nil {
		t.Fatalf("appending x3: %v", e)
	}
	if e := gData.MakeOneHot("x2", "x2Oh"); e != nil {
		t.Fatalf("one-hot encoding x2: %v", e)
	}
	return gData
}
// TestNewVecData checks that a VecData built from a GData reports the same
// number of rows as the underlying x1 feature's summary.
func TestNewVecData(t *testing.T) {
	gData := getData(t)
	vecData := NewVecData("test", gData)
	// testify's argument order is (expected, actual); the source GData's row
	// count is the expectation and the VecData's Rows() is under test.
	assert.Equal(t, gData.Get("x1").Summary.NRows, vecData.Rows())
}
// TestVecData_Batch drains Batch into a single "x1" tensor node and checks
// that the values collected match the x1 feature's data. The drain is done
// twice so that a second pass after exhaustion is exercised as well.
func TestVecData_Batch(t *testing.T) {
	vd := NewVecData("test", getData(t))

	graph := G.NewGraph()
	x1Node := G.NewTensor(graph, G.Float64, 2, G.WithName("x1"), G.WithShape(vd.BatchSize(), 1))
	nodes := G.Nodes{x1Node}

	initErr := vd.Init()
	assert.Nil(t, initErr)

	for pass := 0; pass < 2; pass++ {
		collected := make([]float64, 0)
		for vd.Batch(nodes) {
			batchVals := nodes[0].Value().Data().([]float64)
			collected = append(collected, batchVals...)
		}
		assert.ElementsMatch(t, collected, vd.Get("x1").Data)
	}
}
func TestVecData_Row(t *testing.T) {
gData := getData(t)
vecData := NewVecData("test", gData)
take := 1
newPipe, e := vecData.Row(take)
assert.Nil(t, e)
x2 := newPipe.Get("x2")
x2Big := vecData.Get("x2")
assert.Equal(t, x2Big.Raw.Data[take].(string), x2.Raw.Data[0].(string))
}
func TestVecData_Where(t *testing.T) {
gData := getData(t)
vecData := NewVecData("test", gData)
equalTo := []any{"b", "c"}
newPipe, e := vecData.Where("x2", equalTo)
assert.Nil(t, e)
x2, e := newPipe.GData().GetRaw("x2")
assert.Nil(t, e)
assert.ElementsMatch(t, x2.Data, equalTo)
}
// TestSliceVecData slices the data on the levels of x2 (minimum count 0) and,
// for each slice, drains Batch into an "x1" node. It checks that the x1 values
// collected for each slice match the expected set, and that exactly
// len(x1Exp) slices are visited.
func TestSliceVecData(t *testing.T) {
	vecData := NewVecData("test", getData(t))
	slice, e := NewSlice("x2", 0, vecData, nil)
	assert.Nil(t, e)

	g := G.NewGraph()
	nd := G.NewTensor(g, G.Float64, 2, G.WithName("x1"), G.WithShape(vecData.BatchSize(), 1))
	nds := G.Nodes{nd}
	e = vecData.Init()
	assert.Nil(t, e)

	// Expected x1 values per slice, in the order the slicer is expected to
	// visit x2's levels: "a" -> rows 0,3,4,5,6; "b" -> row 1; "c" -> row 2.
	x1Exp := [][]float64{{1, 4, 8, 9, 10}, {2}, {3}}

	// Run through the slices.
	ind := 0
	for slice.Iter() {
		// Guard the index so that extra slices produce a clear failure
		// rather than an index-out-of-range panic.
		if ind >= len(x1Exp) {
			t.Fatalf("slicer produced more than the %d expected slices", len(x1Exp))
		}
		sl := slice.MakeSlicer()
		newVec, e := vecData.Slice(sl)
		if e != nil {
			// newVec is unusable on error; batching it would panic.
			t.Fatalf("slicing at slice %d: %v", ind, e)
		}
		// Run through the batches, accumulating x1.
		x1act := make([]float64, 0)
		for newVec.Batch(nds) {
			x1act = append(x1act, nds[0].Value().Data().([]float64)...)
		}
		assert.ElementsMatch(t, x1Exp[ind], x1act)
		ind++
	}
	// Without this check the test passes vacuously if Iter yields no slices.
	assert.Equal(t, len(x1Exp), ind, "number of slices visited")
}