-
Notifications
You must be signed in to change notification settings - Fork 11
/
arow_tests.py
35 lines (24 loc) · 935 Bytes
/
arow_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import arow
import unittest
class InstanceTests(unittest.TestCase):
    """Tests for turning a sparse text line into an ``arow`` instance."""

    def test_instance_1(self):
        """Parse one ``label index:value ...`` line.

        No assertion is made on the result; the test passes as long as
        ``arow.instance_from_svm_input`` does not raise.
        """
        svm_line = "-1 1:0.1 2:0.5 9:0.1"
        parsed = arow.instance_from_svm_input(svm_line)
class AROWTests(unittest.TestCase):
    """Smoke test for training and predicting with ``arow.AROW``."""

    def test_arow_1(self):
        """Train AROW on four tiny examples and print predictions before/after.

        The test makes no assertions about the learned model: it passes as
        long as parsing, prediction and training raise no exceptions. The
        printed lists are intended for manual inspection.
        """
        dataset = [
            "-1 1:0.1 2:0.5 9:0.1",
            "+1 1:0.6 2:0.2 8:0.2",
            "-1 1:0.1 2:0.6 8:0.3",
            "+1 1:0.4 2:0.7 9:0.4",
        ]
        data = [arow.instance_from_svm_input(d) for d in dataset]
        cl = arow.AROW()
        # print(...) with a single argument behaves the same under Python 2
        # and is required syntax under Python 3 (the old `print [...]`
        # statement made this whole module a SyntaxError on Python 3).
        # Baseline: predictions and instance costs before calling train().
        print([cl.predict(d).label for d in data])
        print([d.costs for d in data])
        cl.train(data)
        # After training: predicted labels, per-feature weights reported by
        # predict(..., verbose=True), and the instance costs again.
        print([cl.predict(d, verbose=True).label for d in data])
        print([cl.predict(d, verbose=True).featureValueWeights for d in data])
        print([d.costs for d in data])
        # TODO(review): add assertions on the expected labels once the
        # intended behavior of AROW on this dataset is pinned down.
if __name__ == "__main__":
    # When executed as a script, discover and run the TestCases in this
    # module (module="__main__" is unittest.main's default, spelled out).
    unittest.main(module="__main__")