forked from tensorflow/minigo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbigtable_input.py
757 lines (642 loc) · 28.6 KB
/
bigtable_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Read Minigo game examples from a Bigtable.
"""
import bisect
import collections
import datetime
import math
import multiprocessing
import operator
import re
import struct
import time
import numpy as np
from tqdm import tqdm
from absl import flags
from google.cloud import bigtable
from google.cloud.bigtable import row_filters as bigtable_row_filters
from google.cloud.bigtable import column_family as bigtable_column_family
import tensorflow as tf
from tensorflow.contrib import cloud as contrib_cloud
import utils
flags.DEFINE_string('cbt_project', None,
'The project used to connect to the cloud bigtable ')
# cbt_instance: identifier of Cloud Bigtable instance in cbt_project.
flags.DEFINE_string('cbt_instance', None,
'The identifier of the cloud bigtable instance in cbt_project')
# cbt_table: identifier of Cloud Bigtable table in cbt_instance.
# The cbt_table is expected to be accompanied by one with an "-nr"
# suffix, for "no-resign".
flags.DEFINE_string('cbt_table', None,
'The table within the cloud bigtable instance to use')
FLAGS = flags.FLAGS
# Constants
ROW_PREFIX = 'g_{:0>10}_'
ROWCOUNT_PREFIX = 'ct_{:0>10}_'
# Model tabels (models, models_for_eval) row key
MODEL_PREFIX = "m_{run}_{num:0>10}"
# Name of model
MODEL_NAME = b'model'
# Maximum number of concurrent processes to use when issuing requests against
# Bigtable. Value taken from default in the load-testing tool described here:
#
# https://github.com/googleapis/google-cloud-go/blob/master/bigtable/cmd/loadtest/loadtest.go
MAX_BT_CONCURRENCY = 100
# Column family and qualifier constants.
# Column Families
METADATA = 'metadata'
TFEXAMPLE = 'tfexample'
# Column Qualifiers
# Note that in CBT, families are strings and qualifiers are bytes.
TABLE_STATE = b'table_state'
WAIT_CELL = b'wait_for_game_number'
GAME_COUNTER = b'game_counter'
MOVE_COUNT = b'move_count'
# Patterns
_game_row_key = re.compile(r'g_(\d+)_m_(\d+)')
_game_from_counter = re.compile(r'ct_(\d+)_')
# The string information needed to construct a client of a Bigtable table.
BigtableSpec = collections.namedtuple(
'BigtableSpec',
['project', 'instance', 'table'])
# Information needed to create a mix of two Game queues.
# r = resign/regular; c = calibration (no-resign)
GameMix = collections.namedtuple(
'GameMix',
['games_r', 'moves_r',
'games_c', 'moves_c',
'selection'])
def cbt_intvalue(value):
"""Decode a big-endian uint64.
Cloud Bigtable stores integers as big-endian uint64,
and performs this translation when integers are being
set. But when being read, the values need to be
decoded.
"""
return int(struct.unpack('>q', value)[0])
def make_single_array(ds, batch_size=8*1024):
"""Create a single numpy array from a dataset.
The dataset must have only one dimension, that is,
the length of its `output_shapes` and `output_types`
is 1, and its output shape must be `[]`, that is,
every tensor in the dataset must be a scalar.
Args:
ds: a TF Dataset.
batch_size: how many elements to read per pass
Returns:
a single numpy array.
"""
if isinstance(ds.output_types, tuple) or isinstance(ds.output_shapes, tuple):
raise ValueError('Dataset must have a single type and shape')
nshapes = len(ds.output_shapes)
if nshapes > 0:
raise ValueError('Dataset must be comprised of scalars (TensorShape=[])')
batches = []
with tf.Session() as sess:
ds = ds.batch(batch_size)
iterator = ds.make_initializable_iterator()
sess.run(iterator.initializer)
get_next = iterator.get_next()
with tqdm(desc='Elements', unit_scale=1) as pbar:
try:
while True:
batches.append(sess.run(get_next))
pbar.update(len(batches[-1]))
except tf.errors.OutOfRangeError:
pass
if batches:
return np.concatenate(batches)
return np.array([], dtype=ds.output_types.as_numpy_dtype)
def _histogram_move_keys_by_game(sess, ds, batch_size=8*1024):
"""Given dataset of key names, return histogram of moves/game.
Move counts are written by the game players, so
this is mostly useful for repair or backfill.
Args:
sess: TF session
ds: TF dataset containing game move keys.
batch_size: performance tuning parameter
"""
ds = ds.batch(batch_size)
# Turns 'g_0000001234_m_133' into 'g_0000001234'
ds = ds.map(lambda x: tf.strings.substr(x, 0, 12))
iterator = ds.make_initializable_iterator()
sess.run(iterator.initializer)
get_next = iterator.get_next()
h = collections.Counter()
try:
while True:
h.update(sess.run(get_next))
except tf.errors.OutOfRangeError:
pass
# NOTE: Cannot be truly sure the count is right till the end.
return h
def _game_keys_as_array(ds):
"""Turn keys of a Bigtable dataset into an array.
Take g_GGG_m_MMM and create GGG.MMM numbers.
Valuable when visualizing the distribution of a given dataset in
the game keyspace.
"""
ds = ds.map(lambda row_key, cell: row_key)
# want 'g_0000001234_m_133' is '0000001234.133' and so forth
ds = ds.map(lambda x:
tf.strings.to_number(tf.strings.substr(x, 2, 10) +
'.' +
tf.strings.substr(x, 15, 3),
out_type=tf.float64))
return make_single_array(ds)
def _delete_rows(args):
"""Delete the given row keys from the given Bigtable.
The args are (BigtableSpec, row_keys), but are passed
as a single argument in order to work with
multiprocessing.Pool.map. This is also the reason why this is a
top-level function instead of a method.
"""
btspec, row_keys = args
bt_table = bigtable.Client(btspec.project).instance(
btspec.instance).table(btspec.table)
rows = [bt_table.row(k) for k in row_keys]
for r in rows:
r.delete()
bt_table.mutate_rows(rows)
return row_keys
class GameQueue:
"""Queue of games stored in a Cloud Bigtable.
The state of the table is stored in the `table_state`
row, which includes the columns `metadata:game_counter`.
"""
def __init__(self, project_name, instance_name, table_name):
"""Constructor.
Args:
project_name: string name of GCP project having table.
instance_name: string name of CBT instance in project.
table_name: string name of CBT table in instance.
"""
self.btspec = BigtableSpec(project_name, instance_name, table_name)
self.bt_table = bigtable.Client(
self.btspec.project, admin=True).instance(
self.btspec.instance).table(self.btspec.table)
self.tf_table = contrib_cloud.BigtableClient(
self.btspec.project,
self.btspec.instance).table(self.btspec.table)
def create(self):
"""Create the table underlying the queue.
Create the 'metadata' and 'tfexample' column families
and their properties.
"""
if self.bt_table.exists():
utils.dbg('Table already exists')
return
max_versions_rule = bigtable_column_family.MaxVersionsGCRule(1)
self.bt_table.create(column_families={
METADATA: max_versions_rule,
TFEXAMPLE: max_versions_rule})
@property
def latest_game_number(self):
"""Return the number of the next game to be written."""
table_state = self.bt_table.read_row(
TABLE_STATE,
filter_=bigtable_row_filters.ColumnRangeFilter(
METADATA, GAME_COUNTER, GAME_COUNTER))
if table_state is None:
return 0
return cbt_intvalue(table_state.cell_value(METADATA, GAME_COUNTER))
@latest_game_number.setter
def latest_game_number(self, latest):
table_state = self.bt_table.row(TABLE_STATE)
table_state.set_cell(METADATA, GAME_COUNTER, int(latest))
table_state.commit()
def games_by_time(self, start_game, end_game):
"""Given a range of games, return the games sorted by time.
Returns [(time, game_number), ...]
The time will be a `datetime.datetime` and the game
number is the integer used as the basis of the row ID.
Note that when a cluster of self-play nodes are writing
concurrently, the game numbers may be out of order.
"""
move_count = b'move_count'
rows = self.bt_table.read_rows(
ROWCOUNT_PREFIX.format(start_game),
ROWCOUNT_PREFIX.format(end_game),
filter_=bigtable_row_filters.ColumnRangeFilter(
METADATA, move_count, move_count))
def parse(r):
rk = str(r.row_key, 'utf-8')
game = _game_from_counter.match(rk).groups()[0]
return (r.cells[METADATA][move_count][0].timestamp, game)
return sorted([parse(r) for r in rows], key=operator.itemgetter(0))
def delete_row_range(self, format_str, start_game, end_game):
"""Delete rows related to the given game range.
Args:
format_str: a string to `.format()` by the game numbers
in order to create the row prefixes.
start_game: the starting game number of the deletion.
end_game: the ending game number of the deletion.
"""
row_keys = make_single_array(
self.tf_table.keys_by_range_dataset(
format_str.format(start_game),
format_str.format(end_game)))
row_keys = list(row_keys)
if not row_keys:
utils.dbg('No rows left for games %d..%d' % (
start_game, end_game))
return
utils.dbg('Deleting %d rows: %s..%s' % (
len(row_keys), row_keys[0], row_keys[-1]))
# Reverse the keys so that the queue is left in a more
# sensible end state if you change your mind (say, due to a
# mistake in the timestamp) and abort the process: there will
# be a bit trimmed from the end, rather than a bit
# trimmed out of the middle.
row_keys.reverse()
total_keys = len(row_keys)
utils.dbg('Deleting total of %d keys' % total_keys)
concurrency = min(MAX_BT_CONCURRENCY,
multiprocessing.cpu_count() * 2)
with multiprocessing.Pool(processes=concurrency) as pool:
batches = []
with tqdm(desc='Keys', unit_scale=2, total=total_keys) as pbar:
for b in utils.iter_chunks(bigtable.row.MAX_MUTATIONS,
row_keys):
pbar.update(len(b))
batches.append((self.btspec, b))
if len(batches) >= concurrency:
pool.map(_delete_rows, batches)
batches = []
pool.map(_delete_rows, batches)
batches = []
def trim_games_since(self, t, max_games=500000):
"""Trim off the games since the given time.
Search back no more than max_games for this time point, locate
the game there, and remove all games since that game,
resetting the latest game counter.
If `t` is a `datetime.timedelta`, then the target time will be
found by subtracting that delta from the time of the last
game. Otherwise, it will be the target time.
"""
latest = self.latest_game_number
earliest = int(latest - max_games)
gbt = self.games_by_time(earliest, latest)
if not gbt:
utils.dbg('No games between %d and %d' % (earliest, latest))
return
most_recent = gbt[-1]
if isinstance(t, datetime.timedelta):
target = most_recent[0] - t
else:
target = t
i = bisect.bisect_right(gbt, (target,))
if i >= len(gbt):
utils.dbg('Last game is already at %s' % gbt[-1][0])
return
when, which = gbt[i]
utils.dbg('Most recent: %s %s' % most_recent)
utils.dbg(' Target: %s %s' % (when, which))
which = int(which)
self.delete_row_range(ROW_PREFIX, which, latest)
self.delete_row_range(ROWCOUNT_PREFIX, which, latest)
self.latest_game_number = which
def bleakest_moves(self, start_game, end_game):
"""Given a range of games, return the bleakest moves.
Returns a list of (game, move, q) sorted by q.
"""
bleak = b'bleakest_q'
rows = self.bt_table.read_rows(
ROW_PREFIX.format(start_game),
ROW_PREFIX.format(end_game),
filter_=bigtable_row_filters.ColumnRangeFilter(
METADATA, bleak, bleak))
def parse(r):
rk = str(r.row_key, 'utf-8')
g, m = _game_row_key.match(rk).groups()
q = r.cell_value(METADATA, bleak)
return int(g), int(m), float(q)
return sorted([parse(r) for r in rows], key=operator.itemgetter(2))
def require_fresh_games(self, number_fresh):
"""Require a given number of fresh games to be played.
Args:
number_fresh: integer, number of new fresh games needed
Increments the cell `table_state=metadata:wait_for_game_number`
by the given number of games. This will cause
`self.wait_for_fresh_games()` to block until the game
counter has reached this number.
"""
latest = self.latest_game_number
table_state = self.bt_table.row(TABLE_STATE)
table_state.set_cell(METADATA, WAIT_CELL, int(latest + number_fresh))
table_state.commit()
print("== Setting wait cell to ", int(latest + number_fresh), flush=True)
def wait_for_fresh_games(self, poll_interval=15.0):
"""Block caller until required new games have been played.
Args:
poll_interval: number of seconds to wait between checks
If the cell `table_state=metadata:wait_for_game_number` exists,
then block the caller, checking every `poll_interval` seconds,
until `table_state=metadata:game_counter is at least the value
in that cell.
"""
wait_until_game = self.read_wait_cell()
if not wait_until_game:
return
latest_game = self.latest_game_number
last_latest = latest_game
while latest_game < wait_until_game:
utils.dbg('Latest game {} not yet at required game {} '
'(+{}, {:0.3f} games/sec)'.format(
latest_game,
wait_until_game,
latest_game - last_latest,
(latest_game - last_latest) / poll_interval
))
time.sleep(poll_interval)
last_latest = latest_game
latest_game = self.latest_game_number
def read_wait_cell(self):
"""Read the value of the cell holding the 'wait' value,
Returns the int value of whatever it has, or None if the cell doesn't
exist.
"""
table_state = self.bt_table.read_row(
TABLE_STATE,
filter_=bigtable_row_filters.ColumnRangeFilter(
METADATA, WAIT_CELL, WAIT_CELL))
if table_state is None:
utils.dbg('No waiting for new games needed; '
'wait_for_game_number column not in table_state')
return None
value = table_state.cell_value(METADATA, WAIT_CELL)
if not value:
utils.dbg('No waiting for new games needed; '
'no value in wait_for_game_number cell '
'in table_state')
return None
return cbt_intvalue(value)
def count_moves_in_game_range(self, game_begin, game_end):
"""Count the total moves in a game range.
Args:
game_begin: integer, starting game
game_end: integer, ending game
Uses the `ct_` keyspace for rapid move summary.
"""
rows = self.bt_table.read_rows(
ROWCOUNT_PREFIX.format(game_begin),
ROWCOUNT_PREFIX.format(game_end),
filter_=bigtable_row_filters.ColumnRangeFilter(
METADATA, MOVE_COUNT, MOVE_COUNT))
return sum([int(r.cell_value(METADATA, MOVE_COUNT)) for r in rows])
def moves_from_games(self, start_game, end_game, moves, shuffle,
column_family, column):
"""Dataset of samples and/or shuffled moves from game range.
Args:
n: an integer indicating how many past games should be sourced.
moves: an integer indicating how many moves should be sampled
from those N games.
column_family: name of the column family containing move examples.
column: name of the column containing move examples.
shuffle: if True, shuffle the selected move examples.
Returns:
A dataset containing no more than `moves` examples, sampled
randomly from the last `n` games in the table.
"""
start_row = ROW_PREFIX.format(start_game)
end_row = ROW_PREFIX.format(end_game)
# NOTE: Choose a probability high enough to guarantee at least the
# required number of moves, by using a slightly lower estimate
# of the total moves, then trimming the result.
total_moves = self.count_moves_in_game_range(start_game, end_game)
probability = moves / (total_moves * 0.99)
utils.dbg('Row range: %s - %s; total moves: %d; probability %.3f; moves %d' % (
start_row, end_row, total_moves, probability, moves))
ds = self.tf_table.parallel_scan_range(start_row, end_row,
probability=probability,
columns=[(column_family, column)])
if shuffle:
utils.dbg('Doing a complete shuffle of %d moves' % moves)
ds = ds.shuffle(moves)
ds = ds.take(moves)
return ds
def moves_from_last_n_games(self, n, moves, shuffle,
column_family, column):
"""Randomly choose a given number of moves from the last n games.
Args:
n: number of games at the end of this GameQueue to source.
moves: number of moves to be sampled from `n` games.
shuffle: if True, shuffle the selected moves.
column_family: name of the column family containing move examples.
column: name of the column containing move examples.
Returns:
a dataset containing the selected moves.
"""
self.wait_for_fresh_games()
latest_game = self.latest_game_number
utils.dbg('Latest game in %s: %s' % (self.btspec.table, latest_game))
if latest_game == 0:
raise ValueError('Cannot find a latest game in the table')
start = int(max(0, latest_game - n))
ds = self.moves_from_games(start, latest_game, moves, shuffle,
column_family, column)
return ds
def _write_move_counts(self, sess, h):
"""Add move counts from the given histogram to the table.
Used to update the move counts in an existing table. Should
not be needed except for backfill or repair.
Args:
sess: TF session to use for doing a Bigtable write.
tf_table: TF Cloud Bigtable to use for writing.
h: a dictionary keyed by game row prefix ("g_0023561") whose values
are the move counts for each game.
"""
def gen():
for k, v in h.items():
# The keys in the histogram may be of type 'bytes'
k = str(k, 'utf-8')
vs = str(v)
yield (k.replace('g_', 'ct_') + '_%d' % v, vs)
yield (k + '_m_000', vs)
mc = tf.data.Dataset.from_generator(gen, (tf.string, tf.string))
wr_op = self.tf_table.write(mc,
column_families=[METADATA],
columns=[MOVE_COUNT])
sess.run(wr_op)
def update_move_counts(self, start_game, end_game, interval=1000):
"""Used to update the move_count cell for older games.
Should not be needed except for backfill or repair.
move_count cells will be updated in both g_<game_id>_m_000 rows
and ct_<game_id>_<move_count> rows.
"""
for g in range(start_game, end_game, interval):
with tf.Session() as sess:
start_row = ROW_PREFIX.format(g)
end_row = ROW_PREFIX.format(g + interval)
print('Range:', start_row, end_row)
start_time = time.time()
ds = self.tf_table.keys_by_range_dataset(start_row, end_row)
h = _histogram_move_keys_by_game(sess, ds)
self._write_move_counts(sess, h)
end_time = time.time()
elapsed = end_time - start_time
print(' games/sec:', len(h)/elapsed)
def set_fresh_watermark(game_queue, count_from, window_size,
fresh_fraction=0.05, minimum_fresh=20000):
"""Sets the metadata cell used to block until some quantity of games have been played.
This sets the 'freshness mark' on the `game_queue`, used to block training
until enough new games have been played. The number of fresh games required
is the larger of:
- The fraction of the total window size
- The `minimum_fresh` parameter
The number of games required can be indexed from the 'count_from' parameter.
Args:
game_queue: A GameQueue object, on whose backing table will be modified.
count_from: the index of the game to compute the increment from
window_size: an integer indicating how many past games are considered
fresh_fraction: a float in (0,1] indicating the fraction of games to wait for
minimum_fresh: an integer indicating the lower bound on the number of new
games.
"""
already_played = game_queue.latest_game_number - count_from
print("== already_played: ", already_played, flush=True)
if window_size > count_from: # How to handle the case when the window is not yet 'full'
game_queue.require_fresh_games(int(minimum_fresh * .9))
else:
num_to_play = max(0, math.ceil(window_size * .9 * fresh_fraction) - already_played)
print("== Num to play: ", num_to_play, flush=True)
game_queue.require_fresh_games(num_to_play)
def mix_by_decile(games, moves, deciles=9):
"""Compute a mix of regular and calibration games by decile.
deciles should be an integer between 0 and 10 inclusive.
"""
assert 0 <= deciles <= 10
# The prefixes and suffixes below have the following meanings:
# ct_: count
# fr_: fraction
# _r: resign (ordinary)
# _nr: no-resign
ct_total = 10
lesser = ct_total - math.floor(deciles)
greater = ct_total - lesser
ct_r, ct_nr = greater, lesser
fr_r = ct_r / ct_total
fr_nr = ct_nr / ct_total
games_r = math.ceil(games * fr_r)
moves_r = math.ceil(moves * fr_r)
games_c = math.floor(games * fr_nr)
moves_c = math.floor(moves * fr_nr)
selection = np.array([0] * ct_r + [1] * ct_nr, dtype=np.int64)
return GameMix(games_r, moves_r,
games_c, moves_c,
selection)
def get_unparsed_moves_from_last_n_games(games, games_nr, n,
moves=2**21,
shuffle=True,
column_family=TFEXAMPLE,
column='example',
values_only=True):
"""Get a dataset of serialized TFExamples from the last N games.
Args:
games, games_nr: GameQueues of the regular selfplay and calibration
(aka 'no resign') games to sample from.
n: an integer indicating how many past games should be sourced.
moves: an integer indicating how many moves should be sampled
from those N games.
column_family: name of the column family containing move examples.
column: name of the column containing move examples.
shuffle: if True, shuffle the selected move examples.
values_only: if True, return only column values, no row keys.
Returns:
A dataset containing no more than `moves` examples, sampled
randomly from the last `n` games in the table.
"""
mix = mix_by_decile(n, moves, 9)
resign = games.moves_from_last_n_games(
mix.games_r,
mix.moves_r,
shuffle,
column_family, column)
no_resign = games_nr.moves_from_last_n_games(
mix.games_c,
mix.moves_c,
shuffle,
column_family, column)
choice = tf.data.Dataset.from_tensor_slices(mix.selection).repeat().take(moves)
ds = tf.data.experimental.choose_from_datasets([resign, no_resign], choice)
if shuffle:
ds = ds.shuffle(len(mix.selection) * 2)
if values_only:
ds = ds.map(lambda row_name, s: s)
return ds
def get_unparsed_moves_from_games(games_r, games_c,
start_r, start_c,
mix,
shuffle=True,
column_family=TFEXAMPLE,
column='example',
values_only=True):
"""Get a dataset of serialized TFExamples from a given start point.
Args:
games_r, games_c: GameQueues of the regular selfplay and calibration
(aka 'no resign') games to sample from.
start_r: an integer indicating the game number to start at in games_r.
start_c: an integer indicating the game number to start at in games_c.
mix: the result of mix_by_decile()
shuffle: if True, shuffle the selected move examples.
column_family: name of the column family containing move examples.
column: name of the column containing move examples.
values_only: if True, return only column values, no row keys.
Returns:
A dataset containing no more than the moves implied by `mix`,
sampled randomly from the game ranges implied.
"""
resign = games_r.moves_from_games(
start_r, start_r + mix.games_r, mix.moves_r, shuffle, column_family, column)
calibrated = games_c.moves_from_games(
start_c, start_c + mix.games_c, mix.moves_c, shuffle, column_family, column)
moves = mix.moves_r + mix.moves_c
choice = tf.data.Dataset.from_tensor_slices(mix.selection).repeat().take(moves)
ds = tf.data.experimental.choose_from_datasets([resign, calibrated], choice)
if shuffle:
ds = ds.shuffle(len(mix.selection) * 2)
if values_only:
ds = ds.map(lambda row_name, s: s)
return ds
def count_elements_in_dataset(ds, batch_size=1*1024, parallel_batch=8):
"""Count and return all the elements in the given dataset.
Debugging function. The elements in a dataset cannot be counted
without enumerating all of them. By counting in batch and in
parallel, this method allows rapid traversal of the dataset.
Args:
ds: The dataset whose elements should be counted.
batch_size: the number of elements to count a a time.
parallel_batch: how many batches to count in parallel.
Returns:
The number of elements in the dataset.
"""
with tf.Session() as sess:
dsc = ds.apply(tf.data.experimental.enumerate_dataset())
dsc = dsc.apply(tf.data.experimental.map_and_batch(
lambda c, v: c, batch_size, num_parallel_batches=parallel_batch))
iterator = dsc.make_initializable_iterator()
sess.run(iterator.initializer)
get_next = iterator.get_next()
counted = 0
try:
while True:
# The numbers in the tensors are 0-based indicies,
# so add 1 to get the number counted.
counted = sess.run(tf.reduce_max(get_next)) + 1
utils.dbg('Counted so far: %d' % counted)
except tf.errors.OutOfRangeError:
pass
utils.dbg('Counted total: %d' % counted)
return counted