Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gaviota support for uci and selfplay #2044

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/lczero-common
Submodule lczero-common updated 1 files
+0 −88 proto/net.proto
3 changes: 2 additions & 1 deletion meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.

project('lc0', 'cpp',
project('lc0', ['c', 'cpp'],
default_options : ['cpp_std=c++17', 'b_ndebug=if-release', 'warning_level=3', 'b_lto=true', 'b_vscrt=mt'],
meson_version: '>=0.55')

Expand Down Expand Up @@ -691,6 +691,7 @@ endif

if get_option('lc0')
files += common_files
deps += subproject('gaviotatb').get_variable('gaviotatb_dep')
executable('lc0', 'src/main.cc',
files, include_directories: includes, dependencies: deps, install: true)
endif
Expand Down
5 changes: 4 additions & 1 deletion src/benchmark/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,16 @@ void Benchmark::Run() {
tree.ResetToPosition(position, {});

const auto start = std::chrono::steady_clock::now();
std::unique_ptr<bool> gaviotaEnabled_;
gaviotaEnabled_ = std::make_unique<bool>(false);

auto search = std::make_unique<Search>(
tree, network.get(),
std::make_unique<CallbackUciResponder>(
std::bind(&Benchmark::OnBestMove, this, std::placeholders::_1),
std::bind(&Benchmark::OnInfo, this, std::placeholders::_1)),
MoveList(), start, std::move(stopper), false, false, option_dict,
&cache, nullptr);
&cache, nullptr, &gaviotaEnabled_);
search->StartThreads(option_dict.Get<int>(kThreadsOptionId));
search->Wait();
const auto end = std::chrono::steady_clock::now();
Expand Down
39 changes: 38 additions & 1 deletion src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ const OptionId kSyzygyTablebaseId{
"List of Syzygy tablebase directories, list entries separated by system "
"separator (\";\" for Windows, \":\" for Linux).",
's'};
const OptionId kGaviotaTablebaseId{"gaviotatb-paths", "GaviotaPath",
"List of Gaviota tablebase directories. If both Syzygy and Gaviota are "
"provided, Gaviota will take precedence when only 5 pieces remain. "
"Note that if this parameter is set it is assumed that all Gaviota "
"tables (3, 4 and 5-men) are available, but this is not checked, "
"so using this parameter without all of these is not supported."};
const OptionId kPonderId{"", "Ponder",
"This option is ignored. Here to please chess GUIs."};
const OptionId kUciChess960{
Expand Down Expand Up @@ -116,6 +122,7 @@ void EngineController::PopulateOptions(OptionsParser* options) {
options->UnhideOption(SearchParams::kMultiPvId);
}
options->Add<StringOption>(kSyzygyTablebaseId);
options->Add<StringOption>(kGaviotaTablebaseId);
// Add "Ponder" option to signal to GUIs that we support pondering.
// This option is currently not used by lc0 in any way.
options->Add<BoolOption>(kPonderId) = true;
Expand All @@ -139,6 +146,13 @@ void EngineController::ResetMoveTimer() {
move_start_time_ = std::chrono::steady_clock::now();
}

// Needed for Gaviota
#ifdef _WIN32
#define SEP_CHAR ';'
#else
#define SEP_CHAR ':'
#endif

// Updates values from Uci options.
void EngineController::UpdateFromUciOptions() {
SharedLock lock(busy_mutex_);
Expand All @@ -158,6 +172,29 @@ void EngineController::UpdateFromUciOptions() {
tb_paths_.clear();
}

// Init Gaviota, if a path is given
auto dtmPaths = options_.Get<std::string>(kGaviotaTablebaseId);
if (dtmPaths.size() != 0) {
std::stringstream path_string_stream(dtmPaths);
std::string path;
auto paths = tbpaths_init();
while (std::getline(path_string_stream, path, SEP_CHAR)) {
paths = tbpaths_add(paths, path.c_str());
}
tb_init(0, tb_CP4, paths);
tbcache_init(64 * 1024 * 1024, 64);
if (tb_availability() != 63) {
std::cerr << "UNEXPECTED gaviota availability" << std::endl;
gaviotaEnabled_ = std::make_unique<bool>(false);
return;
} else {
gaviotaEnabled_ = std::make_unique<bool>(true);
std::cerr << "Found Gaviota TBs" << std::endl;
}
} else {
gaviotaEnabled_ = std::make_unique<bool>(false);
}

// Network.
const auto network_configuration =
NetworkFactory::BackendConfiguration(options_);
Expand Down Expand Up @@ -394,7 +431,7 @@ void EngineController::Go(const GoParams& params) {
*tree_, network_.get(), std::move(responder),
StringsToMovelist(params.searchmoves, tree_->HeadPosition().GetBoard()),
*move_start_time_, std::move(stopper), params.infinite, params.ponder,
options_, &cache_, syzygy_tb_.get());
options_, &cache_, syzygy_tb_.get(), &gaviotaEnabled_);

LOGFILE << "Timer started at "
<< FormatTime(SteadyClockToSystemClock(*move_start_time_));
Expand Down
1 change: 1 addition & 0 deletions src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class EngineController {
std::unique_ptr<Search> search_;
std::unique_ptr<NodeTree> tree_;
std::unique_ptr<SyzygyTablebase> syzygy_tb_;
std::unique_ptr<bool> gaviotaEnabled_ = nullptr;
std::unique_ptr<Network> network_;
NNCache cache_;

Expand Down
212 changes: 200 additions & 12 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,215 @@ namespace {
// Maximum delay between outputting "uci info" when nothing interesting happens.
const int kUciInfoMinimumFrequencyMs = 5000;

void gaviota_tb_probe_hard(const Position& pos, unsigned int& info,
unsigned int& dtm) {
unsigned int wsq[17];
unsigned int bsq[17];
unsigned char wpc[17];
unsigned char bpc[17];

auto stm = pos.IsBlackToMove() ? tb_BLACK_TO_MOVE : tb_WHITE_TO_MOVE;
auto& board = pos.IsBlackToMove() ? pos.GetThemBoard() : pos.GetBoard();
auto epsq = tb_NOSQUARE;
for (auto sq : board.en_passant()) {
// Our internal representation stores en_passant 2 rows away
// from the actual sq.
if (sq.row() == 0) {
epsq = (TB_squares)(sq.as_int() + 16);
} else {
epsq = (TB_squares)(sq.as_int() - 16);
}
}
int idx = 0;
for (auto sq : (board.ours() & board.kings())) {
wsq[idx] = (TB_squares)sq.as_int();
wpc[idx] = tb_KING;
idx++;
}
for (auto sq : (board.ours() & board.knights())) {
wsq[idx] = (TB_squares)sq.as_int();
wpc[idx] = tb_KNIGHT;
idx++;
}
for (auto sq : (board.ours() & board.queens())) {
wsq[idx] = (TB_squares)sq.as_int();
wpc[idx] = tb_QUEEN;
idx++;
}
for (auto sq : (board.ours() & board.rooks())) {
wsq[idx] = (TB_squares)sq.as_int();
wpc[idx] = tb_ROOK;
idx++;
}
for (auto sq : (board.ours() & board.bishops())) {
wsq[idx] = (TB_squares)sq.as_int();
wpc[idx] = tb_BISHOP;
idx++;
}
for (auto sq : (board.ours() & board.pawns())) {
wsq[idx] = (TB_squares)sq.as_int();
wpc[idx] = tb_PAWN;
idx++;
}
wsq[idx] = tb_NOSQUARE;
wpc[idx] = tb_NOPIECE;

idx = 0;
for (auto sq : (board.theirs() & board.kings())) {
bsq[idx] = (TB_squares)sq.as_int();
bpc[idx] = tb_KING;
idx++;
}
for (auto sq : (board.theirs() & board.knights())) {
bsq[idx] = (TB_squares)sq.as_int();
bpc[idx] = tb_KNIGHT;
idx++;
}
for (auto sq : (board.theirs() & board.queens())) {
bsq[idx] = (TB_squares)sq.as_int();
bpc[idx] = tb_QUEEN;
idx++;
}
for (auto sq : (board.theirs() & board.rooks())) {
bsq[idx] = (TB_squares)sq.as_int();
bpc[idx] = tb_ROOK;
idx++;
}
for (auto sq : (board.theirs() & board.bishops())) {
bsq[idx] = (TB_squares)sq.as_int();
bpc[idx] = tb_BISHOP;
idx++;
}
for (auto sq : (board.theirs() & board.pawns())) {
bsq[idx] = (TB_squares)sq.as_int();
bpc[idx] = tb_PAWN;
idx++;
}
bsq[idx] = tb_NOSQUARE;
bpc[idx] = tb_NOPIECE;

tb_probe_hard(stm, epsq, tb_NOCASTLE, wsq, bsq, wpc, bpc, &info, &dtm);
}

bool root_probe_gaviota(const Position& pos, std::vector<Move>* safe_moves) {
// if the position is winning the strategy is trivial: shortest mate for the winning side, longest mate for the losing side.
// if the position is draw, all non-losing moves are equal.

// Generate the list of legal moves.
auto root_moves = pos.GetBoard().GenerateLegalMoves();

// Create a vector to store dtm information in.
std::vector<unsigned int> dtms (root_moves.size());
// And a vector for info information.
std::vector<unsigned int> infos (root_moves.size());
unsigned int minimum_dtm = 1000;
unsigned int maximum_dtm = 0;
unsigned int target_dtm = 0;
int dtm_idx = 0;
bool at_least_there_is_a_draw = false;

// for all legal moves identify minimum and maximum dtm, if any.
for (auto& move : root_moves) {
Position next_pos = Position(pos, move);
unsigned int info;
unsigned int dtm;
gaviota_tb_probe_hard(next_pos, info, dtm);
// LOGFILE << "DTM for move: " << move.as_string() << " is " << dtm << " and info is " << info << "\n";
dtms[dtm_idx] = dtm;
infos[dtm_idx] = info;
dtm_idx++;
// info == 0 means draw.
// info == 1 and info == 2 implies a decisive move

// if some moves have info == 1 and others have info == 2, then the info == 1 moves appears to be
// moves that lose in a otherwise won position.

// when info == 2 && dtm is odd then root is losing
// when info == 2 && dtm is even then root is winning does the same hold when info == 1?

if (info == 2 || info == 1){
if(dtm % 2 == 0 || dtm == 0){
// root is winning, minimise
if(dtm < minimum_dtm) minimum_dtm = dtm;
} else {
// it root is losing, maximise
if(dtm > maximum_dtm) maximum_dtm = dtm;
}
} else {
// Set the draw flag if not already set
if (!at_least_there_is_a_draw){
at_least_there_is_a_draw = true;
}
}
}

// Set a target DTM if the game is not drawn, as implied by
// minium_dtm or maximum_dtm are changed from there default values
if (minimum_dtm != 1000) {
target_dtm = minimum_dtm;
// LOGFILE << "Winning, opting for lowest dtm = " << target_dtm;
} else {
// Only care about how to lose when actually losing.
if (!at_least_there_is_a_draw) {
target_dtm = maximum_dtm;
// LOGFILE << "Losing, opting for highest dtm = " << target_dtm;
}
}

if (minimum_dtm != 1000 || !at_least_there_is_a_draw) {
dtm_idx = 0;
for (auto& move : root_moves) {
if (dtms[dtm_idx] == target_dtm && /* stalemate is also dtm 0, make sure to pick a proper mate */ infos[dtm_idx] != 0) {
safe_moves->push_back(move);
}
dtm_idx++;
}
} else {
// LOGFILE << "Drawing is optimal exclude losing moves";
// Draw is the optimal outcome, but keep only drawing moves (info == 0 means draw).
dtm_idx = 0;
for (auto& move : root_moves) {
if (infos[dtm_idx] == 0) {
safe_moves->push_back(move);
}
dtm_idx++;
}
}
return true;
}

MoveList MakeRootMoveFilter(const MoveList& searchmoves,
SyzygyTablebase* syzygy_tb,
const PositionHistory& history, bool fast_play,
std::atomic<int>* tb_hits, bool* dtz_success) {
std::atomic<int>* tb_hits, bool* dtz_success,
std::unique_ptr<bool>* gaviotaEnabled) {
assert(tb_hits);
assert(dtz_success);
// Search moves overrides tablebase.
if (!searchmoves.empty()) return searchmoves;
const auto& board = history.Last().GetBoard();
MoveList root_moves;
if (!syzygy_tb || !board.castlings().no_legal_castle() ||
(board.ours() | board.theirs()).count() > syzygy_tb->max_cardinality()) {
return root_moves;
}
if (syzygy_tb->root_probe(

// Select TB to use.
// If gaviota is available and at most 5 pieces left, then use gaviota, else use syzygy.

if (gaviotaEnabled && (board.ours() | board.theirs()).count() <= 5 &&
root_probe_gaviota(history.Last(), &root_moves)){
tb_hits->fetch_add(1, std::memory_order_acq_rel);
} else {
// Try syzygy instead
if (!syzygy_tb || !board.castlings().no_legal_castle() ||
(board.ours() | board.theirs()).count() > syzygy_tb->max_cardinality()) {
return root_moves;
}
if (syzygy_tb->root_probe(
history.Last(), fast_play || history.DidRepeatSinceLastZeroingMove(),
false, &root_moves)) {
*dtz_success = true;
tb_hits->fetch_add(1, std::memory_order_acq_rel);
} else if (syzygy_tb->root_probe_wdl(history.Last(), &root_moves)) {
tb_hits->fetch_add(1, std::memory_order_acq_rel);
*dtz_success = true;
tb_hits->fetch_add(1, std::memory_order_acq_rel);
} else if (syzygy_tb->root_probe_wdl(history.Last(), &root_moves)) {
tb_hits->fetch_add(1, std::memory_order_acq_rel);
}
}
return root_moves;
}
Expand Down Expand Up @@ -155,7 +343,7 @@ Search::Search(const NodeTree& tree, Network* network,
std::chrono::steady_clock::time_point start_time,
std::unique_ptr<SearchStopper> stopper, bool infinite,
bool ponder, const OptionsDict& options, NNCache* cache,
SyzygyTablebase* syzygy_tb)
SyzygyTablebase* syzygy_tb, std::unique_ptr<bool>* gaviotaEnabled)
: ok_to_respond_bestmove_(!infinite && !ponder),
stopper_(std::move(stopper)),
root_node_(tree.GetCurrentHead()),
Expand All @@ -169,7 +357,7 @@ Search::Search(const NodeTree& tree, Network* network,
initial_visits_(root_node_->GetN()),
root_move_filter_(MakeRootMoveFilter(
searchmoves_, syzygy_tb_, played_history_,
params_.GetSyzygyFastPlay(), &tb_hits_, &root_is_in_dtz_)),
params_.GetSyzygyFastPlay(), &tb_hits_, &root_is_in_dtz_, gaviotaEnabled)),
uci_responder_(std::move(uci_responder)) {
if (params_.GetMaxConcurrentSearchers() != 0) {
pending_searchers_.store(params_.GetMaxConcurrentSearchers(),
Expand Down
9 changes: 8 additions & 1 deletion src/mcts/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "neural/cache.h"
#include "neural/network.h"
#include "syzygy/syzygy.h"
#include "gtb-probe.h"
#include "utils/logging.h"
#include "utils/mutex.h"

Expand All @@ -55,7 +56,7 @@ class Search {
std::chrono::steady_clock::time_point start_time,
std::unique_ptr<SearchStopper> stopper, bool infinite, bool ponder,
const OptionsDict& options, NNCache* cache,
SyzygyTablebase* syzygy_tb);
SyzygyTablebase* syzygy_tb, std::unique_ptr<bool>* gaviotaEnabled);

~Search();

Expand Down Expand Up @@ -167,6 +168,12 @@ class Search {
// Fixed positions which happened before the search.
const PositionHistory& played_history_;

// Probes Gaviota tables to determine which moves are on the optimal play path.
// Thread safe.
// Returns false if the position is not in the tablebase.
// Safe moves are added to the safe_moves output paramater.
bool root_probe_gaviota(const Position& pos, std::vector<Move>* safe_moves);

Network* const network_;
const SearchParams params_;
const MoveList searchmoves_;
Expand Down
Loading