✨ feat: save buffer
SigureMo committed Nov 24, 2021
1 parent c758b65 commit 9f06174
Showing 4 changed files with 78 additions and 58 deletions.
5 changes: 3 additions & 2 deletions config.py
@@ -22,11 +22,12 @@

# Training hyperparameters
MODEL_FILE = f"data/model-{WIDTH}x{HEIGHT}#{N_IN_ROW}.h5"
BUFFER_FILE = f"data/buffer-{WIDTH}x{HEIGHT}#{N_IN_ROW}.h5"
LEARNING_RATE = 2e-3
MAX_EPISODE = 10000
REWARD_GAMMA = 0.99
BUFFER_LENGTH = 10000
ENTROPY_BETA = 0.01
BATCH_SIZE = 512
EPOCHS = 5
BATCH_SIZE = 256
EPOCHS = 10
CHECK_FREQ = 50
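
The new BUFFER_FILE constant mirrors the MODEL_FILE naming scheme, so each board configuration gets its own replay-buffer file next to its weights file. A minimal sketch of how the names resolve, assuming WIDTH = 15, HEIGHT = 15 and N_IN_ROW = 5 (these values are defined earlier in config.py and are not part of this hunk):

WIDTH, HEIGHT, N_IN_ROW = 15, 15, 5  # assumed values, defined earlier in config.py
MODEL_FILE = f"data/model-{WIDTH}x{HEIGHT}#{N_IN_ROW}.h5"    # -> "data/model-15x15#5.h5"
BUFFER_FILE = f"data/buffer-{WIDTH}x{HEIGHT}#{N_IN_ROW}.h5"  # -> "data/buffer-15x15#5.h5"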
21 changes: 19 additions & 2 deletions play.py
@@ -1,11 +1,12 @@
import argparse
from collections import deque

import h5py
import numpy as np
from board import Board
from mcts import MCTS

from board import Board
from config import *
from mcts import MCTS
from policy import PolicyValueModelResNet as PolicyValueModel
from policy import mean_policy_value_fn
from ui import GUI, HeadlessUI, TerminalUI
@@ -67,6 +68,22 @@ def end_episode(self, winner):
self.extend(play_data)
self.clear_cache()

def save(self, filename):
states, mcts_probs, rewards = zip(*self)
f = h5py.File(filename, "w")
f["states"] = states
f["mcts_probs"] = mcts_probs
f["rewards"] = rewards
f.close()

def load(self, filename):
f = h5py.File(filename, "r")
states = f["states"]
mcts_probs = f["mcts_probs"]
rewards = f["rewards"]
self.extend(zip(states, mcts_probs, rewards))
f.close()


class Player:
"""玩家基类"""
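The new save/load pair persists the self-play buffer (board states, MCTS move probabilities, rewards) to a single HDF5 file, so training can resume without discarding previously collected data. A minimal round-trip sketch of the same idea, assuming the buffer behaves like a deque of (state, mcts_probs, reward) tuples as in play.py; the DataBuffer name, file name, and array shapes below are illustrative, not taken from this commit:

from collections import deque

import h5py
import numpy as np

class DataBuffer(deque):  # hypothetical stand-in for the buffer class in play.py
    def save(self, filename):
        states, mcts_probs, rewards = zip(*self)
        with h5py.File(filename, "w") as f:  # h5py turns the tuples into datasets
            f["states"] = np.asarray(states)
            f["mcts_probs"] = np.asarray(mcts_probs)
            f["rewards"] = np.asarray(rewards)

    def load(self, filename):
        with h5py.File(filename, "r") as f:
            # read eagerly ([:]) so the data survives closing the file
            self.extend(zip(f["states"][:], f["mcts_probs"][:], f["rewards"][:]))

buffer = DataBuffer(maxlen=10000)
buffer.append((np.zeros((15, 15, 4)), np.zeros(15 * 15), 1.0))  # illustrative shapes
buffer.save("buffer-demo.h5")
buffer.load("buffer-demo.h5")
print(len(buffer))  # 2: the original item plus the reloaded copy
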
4 changes: 2 additions & 2 deletions policy.py
@@ -78,7 +78,7 @@ def __init__(self):
)

@tf.function
def call(self, inputs, training=None):
def call(self, inputs):
x = self.base_net(inputs)
policy = self.policy(x)
values = self.values(x)
@@ -150,7 +150,7 @@ def __init__(self):
)

@tf.function
def call(self, inputs, training=None):
def call(self, inputs):
x = inputs
x = self.preprocess(x)
x = tf.keras.layers.add([x, self.res_1(x)])
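Both models drop the unused training parameter from their @tf.function-decorated call methods. This should stay compatible with call sites like self.model(states_batch, training=True) in train.py: in TF 2.x Keras, Model.__call__ handles the training keyword itself and only forwards it to call when call declares it. A minimal sketch illustrating that behavior (Tiny is a made-up model, not from this repository):

import tensorflow as tf

class Tiny(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    @tf.function
    def call(self, inputs):  # no explicit `training` parameter, as in this commit
        return self.dense(inputs)

model = Tiny()
out = model(tf.zeros((1, 8)), training=True)  # still accepted by Model.__call__
print(out.shape)  # (1, 4)
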
106 changes: 54 additions & 52 deletions train.py
@@ -1,23 +1,25 @@
import argparse
import gc
import os
import random
import time

import numpy as np
import tensorflow as tf

from policy import AlphaZeroError
from config import *
from play import Game, MCTSAlphaZeroPlayer, MCTSPlayer
from policy import AlphaZeroError
from ui import HeadlessUI


parser = argparse.ArgumentParser(description='Gomoku AlphaZero')
parser.add_argument('--resume', action='store_true', help='resume training from the saved model')
parser = argparse.ArgumentParser(description="Gomoku AlphaZero")
parser.add_argument("--resume", action="store_true", help="恢复模型继续训练")
args = parser.parse_args()


class DataAugmentor():
""" Data augmenter
Rotates and reflects the original data, eight augmentation variants in total """
class DataAugmentor:
"""Data augmenter
Rotates and reflects the original data, eight augmentation variants in total"""

def __init__(self, rotate=True, flip=True):
self.rotate = rotate
@@ -40,8 +42,8 @@ def __call__(self, data_batch):
return data_batch_aug


class AlphaZeroMetric():
""" AlphaZero performance evaluator """
class AlphaZeroMetric:
"""AlphaZero performance evaluator"""

def __init__(self, n_playout=400):
self.n_playout = n_playout
@@ -51,28 +53,27 @@ def __init__(self, n_playout=400):
def __call__(self, weights, episode=0, n_games=10):
assert n_games % 2 == 0

mcts_alphazero_player = MCTSAlphaZeroPlayer(
c_puct=5, n_playout=self.n_playout)
mcts_alphazero_player.model.build(
input_shape=(None, WIDTH, HEIGHT, CHANNELS))
mcts_alphazero_player = MCTSAlphaZeroPlayer(c_puct=5, n_playout=self.n_playout)
mcts_alphazero_player.model.build(input_shape=(None, WIDTH, HEIGHT, CHANNELS))
mcts_alphazero_player.model.set_weights(weights)
mcts_player = MCTSPlayer(c_puct=5, n_playout=self.n_playout_mcts)
game = Game(mcts_alphazero_player, mcts_player, HeadlessUI())
scores = {WIN: 0, LOSE: 0, TIE: 0}
score = 0.
score = 0.0
for idx in range(n_games):
winner = game.play(is_selfplay=False)
res = winner * mcts_alphazero_player.color
scores[res] += 1
game.switch_players()
print('[Testing] Episode: {:5d}, Game: {:2d}, Score: {:2d} '.format(
episode + 1, idx, res
), end='\r')
print("[Testing] Episode: {:5d}, Game: {:2d}, Score: {:2d} ".format(episode + 1, idx, res), end="\r")
for key in scores:
score += key * scores[key]
print('[Test] Episode: {:5d}, MCTS n_playout: {:6d}, Win: {:2d}, Lose: {:2d}, Tie: {:2d}, Score: {:.2f} '.format(
episode + 1, self.n_playout_mcts, scores[WIN], scores[LOSE], scores[TIE], score
))
now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(
"[Test] Episode: {:5d}, MCTS n_playout: {:6d}, Win: {:2d}, Lose: {:2d}, Tie: {:2d}, Score: {:.2f} @{} ".format(
episode + 1, self.n_playout_mcts, scores[WIN], scores[LOSE], scores[TIE], score, now
)
)
if score > self.best_score:
self.best_score = score
if score == n_games:
@@ -82,73 +83,74 @@ def __call__(self, weights, episode=0, n_games=10):
return False


class Worker():

class Worker:
def __init__(self):
self.player = MCTSAlphaZeroPlayer(c_puct=5, n_playout=400)
self.model = self.player.model
self.model.build(input_shape=(None, WIDTH, HEIGHT, CHANNELS))
self.model.summary()
if args.resume:
self.model.load_weights(MODEL_FILE)

self.opt = tf.keras.optimizers.Adam(LEARNING_RATE)
self.loss_object = AlphaZeroError()
self.mean_loss = tf.keras.metrics.Mean(name='train_loss')
self.mean_loss = tf.keras.metrics.Mean(name="train_loss")
self.game = Game(self.player, self.player, HeadlessUI())
self.data_aug = DataAugmentor(rotate=True, flip=True)
self.metric = AlphaZeroMetric(n_playout=400)

if args.resume:
self.model.load_weights(MODEL_FILE)
print("Loaded model successfully.")
if os.path.exists(BUFFER_FILE):
self.game.data_buffer.load(BUFFER_FILE)
print("Loaded buffer ({} items) successfully.".format(len(self.game.data_buffer)))

def run(self):
for episode in range(MAX_EPISODE):
winner = self.game.play(is_selfplay=True)
gc.collect()

total_loss = tf.constant(0)

for epoch in range(EPOCHS):
mini_batch = random.sample(self.game.data_buffer, min(
BATCH_SIZE, len(self.game.data_buffer)//2))
mini_batch = random.sample(self.game.data_buffer, min(BATCH_SIZE, len(self.game.data_buffer) // 2))
mini_batch = self.data_aug(mini_batch)
states_batch = tf.convert_to_tensor(
[data[0] for data in mini_batch], dtype=tf.float32)
mcts_probs_batch = tf.convert_to_tensor(
[data[1] for data in mini_batch], dtype=tf.float32)
rewards_batch = tf.convert_to_tensor(np.expand_dims(
[data[2] for data in mini_batch], axis=-1), dtype=tf.float32)
states_batch, mcts_probs_batch, rewards_batch = zip(*mini_batch)
states_batch = tf.convert_to_tensor(states_batch, dtype=tf.float32)
mcts_probs_batch = tf.convert_to_tensor(mcts_probs_batch, dtype=tf.float32)
rewards_batch = tf.convert_to_tensor(np.expand_dims(rewards_batch, axis=-1), dtype=tf.float32)

with tf.GradientTape() as tape:
policy, values = self.model(states_batch, training=True)

total_loss = self.loss_object(
mcts_probs=mcts_probs_batch,
policy=policy,
rewards=rewards_batch,
values=values)
mcts_probs=mcts_probs_batch, policy=policy, rewards=rewards_batch, values=values
)

grads = tape.gradient(
total_loss, self.model.trainable_weights)
self.opt.apply_gradients(
zip(grads, self.model.trainable_weights))
grads = tape.gradient(total_loss, self.model.trainable_weights)
self.opt.apply_gradients(zip(grads, self.model.trainable_weights))

self.mean_loss(total_loss)

print('[Training] Episode: {:5d}, Epoch: {:2d}, Winner: {:5s}, Loss: {} '.format(
episode+1,
epoch+1,
COLOR[winner],
total_loss.numpy()), end='\r')
print(
"[Training] Episode: {:5d}, Epoch: {:2d}, Winner: {:5s}, Loss: {} ".format(
episode + 1, epoch + 1, COLOR[winner], total_loss.numpy()
),
end="\r",
)

if (episode + 1) % CHECK_FREQ == 0:
print('[Train] Episode: {:5d}, Loss: {} '.format(
episode+1,
self.mean_loss.result()
))
now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(
"[Train] Episode: {:5d}, Loss: {} @{} ".format(
episode + 1, self.mean_loss.result(), now
),
)
self.mean_loss.reset_states()
is_best_score = self.metric(self.model.get_weights(), episode)
if is_best_score:
self.model.save_weights(MODEL_FILE)
self.game.data_buffer.save(BUFFER_FILE)


if __name__ == '__main__':
if __name__ == "__main__":
worker = Worker()
worker.run()
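
Besides adding buffer persistence, this commit also simplifies mini-batch preparation in train.py: the three per-field list comprehensions are replaced by a single zip(*mini_batch), which transposes the list of (state, mcts_probs, reward) tuples into three parallel sequences before they are converted to tensors. A tiny stand-alone illustration of that transpose (the placeholder strings are not real training data):

# toy illustration of the zip(*batch) unpacking used in the new train.py
batch = [("s1", "p1", 1.0), ("s2", "p2", -1.0), ("s3", "p3", 0.0)]
states, probs, rewards = zip(*batch)
print(states)   # ('s1', 's2', 's3')
print(rewards)  # (1.0, -1.0, 0.0)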
