-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathMatcher.cpp
112 lines (92 loc) · 3.06 KB
/
Matcher.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include "config.h"
#include <limits>
#include <cmath>
#include <unordered_map>
#include "Matcher.h"
#include "FastBoard.h"
#include "Utils.h"
#include "GTP.h"
#include "MCPolicy.h"
#include "Patterns.h"
#include "PatternHash.h"
// Accessor for the one shared Matcher.
//
// The instance is a function-local static: it is constructed on the first
// call (C++11 makes that initialization thread-safe) and lives until program
// exit. Every call returns the same pointer.
Matcher* Matcher::get_Matcher() {
    static Matcher instance;
    Matcher* const shared = &instance;
    return shared;
}
// First-level hash: selects a bucket of the displacement table G.
// Folds bits 3 and up back onto the low bits, then reduces modulo G_SIZE.
int Matcher::PatHashG(uint32 p) const {
    const uint32 folded = p ^ (p >> 3);
    return folded % G_SIZE;
}
// Second-level hash: mixes the raw pattern `p` with the displacement `d`
// taken from G, yielding a slot in [0, V_SIZE).
// All arithmetic is done in uint32, so (p + d) wraps modulo 2^32.
int Matcher::PatHashV(uint32 d, uint32 p) const {
    const uint32 rotated = Utils::rotl(p, 11);
    const uint32 seeded = (p + d) >> 6;
    const uint32 mixed = p ^ rotated ^ seeded;
    return mixed % V_SIZE;
}
// Slot of a raw pattern in the per-color tables m_patterns[color].
// Two-level scheme: PatHashG picks an entry of G, and that entry is used as
// the displacement for PatHashV. This looks like a hash-and-displace perfect
// hash; presumably G was generated so that every entry of Matcher::Patterns
// maps to a distinct slot. TODO confirm against the table generator.
int Matcher::PatIndex(uint32 pattern) const {
    return PatHashV(G[PatHashG(pattern)], pattern);
}
// Look up a raw pattern for the side `color`.
//
// Returns the entry stored in m_patterns[color] at PatIndex(pattern).
// For patterns the constructor mapped, this is the position of the reduced
// pattern in PolicyWeights::live_patterns. Otherwise it is the constructor's
// fill value.
// `color` indexes m_patterns directly with no bounds check; presumably it is
// FastBoard::BLACK or FastBoard::WHITE.
int Matcher::matches(int color, int pattern) const {
    const auto& table = m_patterns[color];
    return table[PatIndex(pattern)];
}
// Build the matcher tables.
//
// First rescales the MC policy weights (rescale_policy_weights). Then, for
// every raw 3x3 pattern in Matcher::Patterns, it:
//   1. decodes the pattern onto a tiny board,
//   2. asks FastBoard for its reduced form, once per color,
//   3. stores the reduced form's dense index into PolicyWeights::live_patterns
//      at m_patterns[color][PatIndex(raw)].
// Slots whose reduced pattern is not a live pattern keep the fill value.
Matcher::Matcher() {
    rescale_policy_weights();

#ifdef DEBUG
    // Sentinel for unmapped slots; presumably out of range as a weight
    // index, so a lookup of an unmapped pattern crashes loudly.
    unsigned short fill = std::numeric_limits<decltype(fill)>::max();
#else
    // Unmapped slots read as index 0. This never faults, but it silently
    // aliases the first live pattern.
    unsigned short fill = 0;
#endif
    // Size and initialize both tables through the same color indices that
    // matches() reads them with.
    m_patterns[FastBoard::BLACK].assign(V_SIZE, fill);
    m_patterns[FastBoard::WHITE].assign(V_SIZE, fill);

    // Dense numbering of the live patterns. The i-th distinct entry gets
    // index i; duplicates keep their first index, because emplace does not
    // overwrite. size() is evaluated before the insertion takes place.
    std::unordered_map<int, size_t> pattern_indexes;
    for (auto const& pat : PolicyWeights::live_patterns) {
        pattern_indexes.emplace(pat, pattern_indexes.size());
    }

    // A 3x3 board is the smallest one on which the centre point has all
    // eight neighbours.
    FastBoard board;
    board.reset_board(3);
    const int center = board.get_vertex(1, 1);

    for (auto raw : Matcher::Patterns) {
        // Bit layout of the raw pattern:
        //   - the low 16 bits hold eight 2-bit square states; the lowest two
        //     bits go to direction 7, the next two to direction 6, and so on;
        //   - whatever remains after shifting those out is passed on as the
        //     augmentation spec.
        int w = raw;
        for (int k = 7; k >= 0; k--) {
            board.set_square(center + board.get_extra_dir(k),
                             static_cast<FastBoard::square_t>(w & 3));
            w >>= 2;
        }

        // The bool argument selects the per-color variant. The original
        // mapping is kept: false -> BLACK table, true -> WHITE table.
        const int reduced_black =
            board.get_pattern3_augment_spec(center, w, false);
        const int reduced_white =
            board.get_pattern3_augment_spec(center, w, true);
        const int slot = PatIndex(raw);

        const auto black = pattern_indexes.find(reduced_black);
        if (black != pattern_indexes.end()) {
            m_patterns[FastBoard::BLACK][slot] = black->second;
        }
        const auto white = pattern_indexes.find(reduced_white);
        if (white != pattern_indexes.end()) {
            m_patterns[FastBoard::WHITE][slot] = white->second;
        }
    }
}
// Fold the supervised-learning weights into the MC policy weights and apply
// the softmax temperature t = cfg_mc_softmax.
//
// For every feature and pattern weight w, with SL counterpart w_sl:
//     w := (w * w_sl)^(1 / t)
// Assuming weights are stored exponentiated (w = e^x), this uses the
// identity (e^x)^(1/t) = e^(x/t), which divides the exponent by the
// temperature.
//
// Not idempotent: each call multiplies the SL weights in again and
// re-applies the temperature. It is called from the Matcher constructor.
void Matcher::rescale_policy_weights() {
    // Loop-invariant exponent. `auto` keeps the exact type of
    // 1.0f / cfg_mc_softmax, so the same std::pow overload is selected.
    const auto inv_temperature = 1.0f / cfg_mc_softmax;

    for (size_t i = 0; i < NUM_FEATURES; i++) {
        auto& weight = PolicyWeights::feature_weights[i];
        weight *= PolicyWeights::feature_weights_sl[i];
        weight = std::pow(weight, inv_temperature);
    }
    for (size_t i = 0; i < NUM_PATTERNS; i++) {
        auto& weight = PolicyWeights::pattern_weights[i];
        weight *= PolicyWeights::pattern_weights_sl[i];
        weight = std::pow(weight, inv_temperature);
    }
}