-
Notifications
You must be signed in to change notification settings - Fork 2
/
generate_games.py
88 lines (74 loc) · 3.52 KB
/
generate_games.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import os
import random
import re
import sys
from uuid import uuid4
import pyffish as sf
from tqdm import tqdm
import uci
def get_pieces(fen):
    """Return the pieces in the board field of a FEN as a sorted tuple.

    Only the first space-separated field of *fen* is inspected. A piece token
    is a single ASCII letter, optionally preceded by ``+`` (promoted-piece
    notation); digits, ``/`` and any other characters are skipped. Because the
    tokens are sorted, two FENs with the same material yield equal tuples
    regardless of where the pieces stand.
    """
    board_field = fen.split(' ')[0]
    piece_tokens = re.findall(r'(?:\+)?[A-Za-z]', board_field)
    piece_tokens.sort()
    return tuple(piece_tokens)
def generate_fens(engine, variant, book, **limits):
    """Play engine self-play games forever and yield one EPD record per position.

    Each game starts from a FEN chosen at random from *book* (one FEN/EPD per
    line, blank lines skipped) or, when *book* is falsy, from the variant's
    start position. After every engine move the resulting position is
    recorded, and once the game is over its records are yielded as strings of
    the form ``<fen>;variant <v>;bm <move>;hmvc <n>;result <r>;game <uuid>``
    where:

    * ``bm`` is the move the engine played next from that position
      (``none`` for the final position),
    * ``hmvc`` is the number of plies since the last change of the piece set
      as computed by ``get_pieces``,
    * ``result`` is ``1-0``, ``0-1`` or ``1/2-1/2`` from White's point of view,
    * ``game`` is a UUID shared by all records of the same game.

    :param engine: engine wrapper providing ``setoption``, ``newgame``,
        ``position`` and ``go``.
    :param variant: variant name; must be listed by ``sf.variants()``.
    :param book: path to an EPD opening book, or a falsy value for none.
    :param limits: search limits forwarded to ``engine.go``.
    :raises Exception: if *variant* is unsupported or *book* holds no FENs.
    """
    if variant not in sf.variants():
        raise Exception("Unsupported variant: {}".format(variant))

    if book:
        with open(book, encoding='utf-8') as epdfile:
            # Skip blank lines: an empty string is not a usable start FEN.
            startfens = [line.strip() for line in epdfile if line.strip()]
        if not startfens:
            raise Exception("No positions found in book: {}".format(book))
    else:
        startfens = [sf.start_fen(variant)]

    engine.setoption('UCI_Variant', variant)
    while True:
        engine.newgame()
        move_stack = []
        start_fen = random.choice(startfens)
        fens = []
        hmvc = []
        last_change = 0
        # Seed the comparison with the normalised start position so that a
        # piece-set change on the very first move also resets the counter.
        prev_pieces = get_pieces(sf.get_fen(variant, start_fen, []))
        while (sf.legal_moves(variant, start_fen, move_stack)
               and not sf.is_optional_game_end(variant, start_fen, move_stack)[0]):
            engine.position(start_fen, move_stack)
            bestmove, _ = engine.go(**limits)
            move_stack.append(bestmove)
            fen = sf.get_fen(variant, start_fen, move_stack)
            fens.append(fen)
            pieces = get_pieces(fen)
            if pieces != prev_pieces:
                last_change = len(move_stack)
            prev_pieces = pieces
            hmvc.append(len(move_stack) - last_change)

        # The score is treated as relative to the side to move in the final
        # position, hence the sign flip below to express it for White.
        if not sf.legal_moves(variant, start_fen, move_stack):
            pov_score = sf.game_result(variant, start_fen, move_stack)
        else:
            _, pov_score = sf.is_optional_game_end(variant, start_fen, move_stack)
        color = sf.get_fen(variant, start_fen, move_stack).split(' ')[1]
        white_score = pov_score if color == 'w' else -pov_score
        result = '1-0' if white_score > 0 else '0-1' if white_score < 0 else '1/2-1/2'

        game_uuid = uuid4()
        # fens[i] is the position after move_stack[i]; its "best move" is the
        # next move actually played, and the last position has none.
        for fen, move, halfmove in zip(fens, move_stack[1:] + ['none'], hmvc):
            yield '{};variant {};bm {};hmvc {};result {};game {}'.format(
                fen, variant, move, halfmove, result, game_uuid)
def write_fens(stream, engine, variant, count, book, **limits):
    """Write *count* EPD records from ``generate_fens`` to *stream*, one per line.

    :param stream: writable text stream (e.g. ``sys.stdout``).
    :param engine: engine wrapper passed through to ``generate_fens``.
    :param variant: variant name passed through to ``generate_fens``.
    :param count: number of records (positions) to write.
    :param book: EPD opening book path or a falsy value, passed through.
    :param limits: search limits passed through to ``generate_fens``.
    """
    generator = generate_fens(engine, variant, book, **limits)
    for _ in tqdm(range(count)):
        # Use '\n', not os.linesep: text-mode streams already translate '\n'
        # to the platform line ending, so os.linesep would give '\r\r\n' on
        # Windows.
        stream.write(next(generator) + '\n')
def parse_uci_option(text):
    """Parse a ``key=value`` command-line argument into a ``(key, value)`` pair.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=``.

    :raises argparse.ArgumentTypeError: if *text* contains no ``=``.
    """
    key, sep, value = text.partition('=')
    if not sep:
        raise argparse.ArgumentTypeError(
            "expected key=value, got {!r}".format(text))
    return key, value


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--engine', required=True, help='chess variant engine path, e.g., to Fairy-Stockfish')
    parser.add_argument('-o', '--ucioptions', type=parse_uci_option, action='append', default=[],
                        help='UCI option as key=value pair. Repeat to add more options.')
    parser.add_argument('-v', '--variant', default='chess', help='variant to generate positions for')
    parser.add_argument('-c', '--count', type=int, default=1000, help='number of positions')
    parser.add_argument('-d', '--depth', type=int, default=None, help='search depth')
    parser.add_argument('-t', '--movetime', type=int, default=None, help='search movetime (ms)')
    parser.add_argument('-b', '--book', type=str, default=None, help='EPD opening book')
    args = parser.parse_args()

    # Validate the search limits before the engine is created, so a usage
    # error does not start the engine needlessly.
    limits = dict()
    if args.depth:
        limits['depth'] = args.depth
    if args.movetime:
        limits['movetime'] = args.movetime
    if not limits:
        parser.error('At least one of --depth and --movetime is required.')

    engine = uci.Engine([args.engine], dict(args.ucioptions))
    # Keep pyffish's variant definitions in sync with the engine's.
    sf.set_option("VariantPath", engine.options.get("VariantPath", ""))
    write_fens(sys.stdout, engine, args.variant, args.count, args.book, **limits)