-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpaperStatsGen.py
executable file
·61 lines (41 loc) · 2.01 KB
/
paperStatsGen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from constants import *
from NeuralNetworkAgent import *
import os
def testTrainAccuracies(chkptDir, dataPath, outputFilePath):
    """Measure training and testing accuracy for every checkpoint in a directory.

    Every entry of ``chkptDir`` that does not end in ``.meta`` is treated as a
    checkpoint and processed in sorted (lexicographic) order. For each one the
    policy agent is evaluated on the training and testing splits returned by
    ``PolicyNetworkAgent.readInputs(dataPath)``. Both accuracies are printed,
    and one CSV row is written per checkpoint.

    Args:
        chkptDir: directory listed with ``os.listdir`` to find checkpoints.
        dataPath: dataset path handed to ``PolicyNetworkAgent.readInputs``.
        outputFilePath: CSV file to create (truncated if it exists) with the
            columns ``Filename,Testing Accuracy,Training Accuracy``.
    """
    # `with` guarantees the CSV is flushed and closed even if evaluation fails.
    # A distinct name (not `f`) matters: in Python 2 a list comprehension's
    # loop variable leaks into the enclosing scope and would clobber the handle.
    with open(outputFilePath, 'w') as outFile:
        outFile.write('Filename,Testing Accuracy,Training Accuracy\n')
        # Skip `.meta` files (presumably graph-definition files written next to
        # each checkpoint -- TODO confirm).
        # NOTE(review): any other bookkeeping file in chkptDir will also be
        # treated as a checkpoint.
        chkptFiles = sorted(name for name in os.listdir(chkptDir)
                            if not name.endswith('.meta'))
        slModel = PolicyNetworkAgent(BATCH_SIZE)
        (trainStatesBatch, trainLabelsBatch,
         testStatesBatch, testLabelsBatch) = slModel.readInputs(dataPath)
        for chkpt in chkptFiles:
            print('Loading checkpoint %s' % chkpt)
            # Invoked once per checkpoint; the returned handles are not needed here.
            slModel.createSLPolicyAgent()
            # NOTE(review): chkpt is a bare file name from os.listdir (no
            # chkptDir prefix) -- confirm testAgentAccuracy resolves it correctly.
            trainAccuracy = slModel.testAgentAccuracy(trainStatesBatch, trainLabelsBatch,
                                                      trainStatesBatch.shape[0],
                                                      slModel.batch_size, chkpt)
            testAccuracy = slModel.testAgentAccuracy(testStatesBatch, testLabelsBatch,
                                                     testStatesBatch.shape[0],
                                                     slModel.batch_size, chkpt)
            print('\tTraining Accuracy: %f, Testing Accuracy: %f'
                  % (trainAccuracy, testAccuracy))
            # Column order matches the header: testing first, then training.
            outFile.write('%s,%f,%f\n' % (chkpt, testAccuracy, trainAccuracy))
def playPachi(chkptDir, gamesPerModel, outputFilePath):
    """Play ``gamesPerModel`` games per checkpoint and record the win rate.

    Every entry of ``chkptDir`` that does not end in ``.meta`` is treated as a
    checkpoint and processed in sorted (lexicographic) order. For each one,
    ``RL_Playout`` is run for ``gamesPerModel`` games, and the resulting win
    count is turned into a win rate. The rate is printed and written as a CSV
    row.

    Args:
        chkptDir: directory listed with ``os.listdir`` to find checkpoints.
        gamesPerModel: number of games to play per checkpoint; must be > 0.
        outputFilePath: CSV file to create (truncated if it exists) with the
            columns ``Filename,Win Percentage``.

    Raises:
        ValueError: if ``gamesPerModel`` is not positive (the win rate would be
            undefined).
    """
    if gamesPerModel <= 0:
        raise ValueError('gamesPerModel must be positive, got %r' % (gamesPerModel,))
    # `with` guarantees the CSV is flushed and closed even if a playout fails.
    # A distinct name (not `f`) matters: in Python 2 a list comprehension's
    # loop variable leaks into the enclosing scope and would clobber the handle.
    with open(outputFilePath, 'w') as outFile:
        outFile.write('Filename,Win Percentage\n')
        # Skip `.meta` files (presumably graph-definition files -- TODO confirm).
        chkptFiles = sorted(name for name in os.listdir(chkptDir)
                            if not name.endswith('.meta'))
        rlModel = PolicyNetworkAgent(BATCH_SIZE)
        for chkpt in chkptFiles:
            print('Loading checkpoint %s' % chkpt)
            # Invoked once per checkpoint; the returned handles are not needed here.
            rlModel.createSLPolicyAgent()
            # NOTE(review): chkpt is never passed to rlModel or RL_Playout, so no
            # checkpoint weights appear to be loaded here and every iteration
            # may evaluate the same model -- confirm how the checkpoint should
            # be restored before playing.
            _, _, _, _, winNum = RL_Playout(gamesPerModel, rlModel, filename=None,
                                            opponentModel=None, doRecord=False,
                                            verbose=False, playbyplay=False)
            # float() keeps this a true division under Python 2.
            # NOTE(review): this is a fraction in [0, 1] although the CSV header
            # says "Percentage".
            winRate = float(winNum) / gamesPerModel
            print('\tWin rate: %f' % winRate)
            outFile.write('%s,%f\n' % (chkpt, winRate))
if __name__ == "__main__":
    # Script entry point: evaluate every checkpoint under ./tmpchkpt against
    # the HDF5 dataset below and write the accuracy table to result.csv.
    checkpointDir = './tmpchkpt'
    datasetPath = '/data/go/augmented/human700_augmented.hdf5'
    resultCsv = 'result.csv'
    testTrainAccuracies(checkpointDir, datasetPath, resultCsv)