Skip to content

Commit

Permalink
Fix storage of training data.
Browse files Browse the repository at this point in the history
  • Loading branch information
donkirkby committed Nov 9, 2023
1 parent 967b3cd commit b38c432
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 8 deletions.
4 changes: 4 additions & 0 deletions docs/journal/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ positive / negative position data, as well as the one-hot position data.
It looks like we don't get much improvement past 50,000 positions, and I don't
see much difference between the two data formats.

I tried changing the think time for each move from 0.1s to 2.0s and generated
50,000 positions. It took 28 hours, and the model had a validation loss of
0.161 — worse than the quick searches!

[Training with positive/negative]: 2023/training-pos-neg.png
[Training with one hot]: 2023/training-one-hot.png
[Training +/- on 100,000]: 2023/training-100_000-pos-neg.png
Expand Down
2 changes: 1 addition & 1 deletion zero_play/connect4/neural_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def train(self, boards: np.ndarray, outputs: np.ndarray):

self.checkpoint_name += ' + training'

callbacks = [EarlyStopping(patience=5)]
callbacks = [EarlyStopping(patience=10)]

history = self.model.fit(
np.expand_dims(boards, -1),
Expand Down
3 changes: 0 additions & 3 deletions zero_play/play_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ def play(self, games: int = 1, flip: bool = False, display: bool = False):
original_o.player_number = o_number
self.players[x_number] = original_x
self.players[o_number] = original_o
for player_results in self.results:
print(player_results.get_summary())
print(ties, 'ties')
x_results = self.get_player_results(original_x)
o_results = self.get_player_results(original_o)

Expand Down
11 changes: 8 additions & 3 deletions zero_play/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,24 @@ def train(search_milliseconds: int,
milliseconds=search_milliseconds,
data_size=training_size)

boards_df = pd.DataFrame.from_records(boards)
flattened_boards = boards.reshape(
training_size,
start_state.board_height*start_state.board_width)
boards_df = pd.DataFrame.from_records(flattened_boards)
outputs_df = pd.DataFrame.from_records(outputs)
boards_df.to_csv(boards_path)
outputs_df.to_csv(outputs_path)

boards = boards.reshape(training_size, 6, 7)
boards = boards.reshape(training_size,
start_state.board_height,
start_state.board_width)

start = datetime.now()
filename = f'checkpoint-{i:02d}.h5'
logger.info('Training for %s.', filename)
history = training_net.train(boards, outputs)
training_time = datetime.now() - start
print(f'Trained for {training_time}.')
logger.info(f'Trained for {training_time}.')

if is_reprocessing:
plot_loss(history)
Expand Down
6 changes: 5 additions & 1 deletion zero_play/zero_play.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import logging
import logging.config
import math
import os
import sys
Expand Down Expand Up @@ -64,6 +64,10 @@
DEFAULT_SEARCH_MILLISECONDS = 500
logging.basicConfig(level=logging.INFO,
format="%(asctime)s %(levelname)s:%(name)s: %(message)s")
logging.config.dictConfig(dict(
version=1,
incremental=True,
loggers={'zero_play.mcts_player': dict(level=logging.INFO)}))


class AboutDialog(QDialog):
Expand Down

0 comments on commit b38c432

Please sign in to comment.