Fix a typo in Coach.py #322

Open · wants to merge 9 commits into base: master

10 changes: 10 additions & 0 deletions .gitignore
@@ -11,3 +11,13 @@ checkpoints/
# For PyCharm users
.idea/

# environment
myenv/
env/

# handy tests
test.py

*.ipynb
activate.sh
.ipynb_checkpoints/
1 change: 1 addition & 0 deletions Arena.py
@@ -76,6 +76,7 @@ def playGame(self, verbose=False):
assert self.display
print("Game over: Turn ", str(it), "Result ", str(self.game.getGameEnded(board, 1)))
self.display(board)

return curPlayer * self.game.getGameEnded(board, curPlayer)

def playGames(self, num, verbose=False):
42 changes: 27 additions & 15 deletions Coach.py
@@ -4,7 +4,8 @@
from collections import deque
from pickle import Pickler, Unpickler
from random import shuffle

from concurrent.futures import ThreadPoolExecutor
import concurrent
import numpy as np
from tqdm import tqdm

@@ -23,7 +24,7 @@ class Coach():
def __init__(self, game, nnet, args):
self.game = game
self.nnet = nnet
self.pnet = self.nnet.__class__(self.game) # the competitor network
self.pnet = self.nnet.__class__(self.game, input_channels = self.nnet.args.input_channels, num_channels = self.nnet.args.num_channels) # the competitor network
self.args = args
self.mcts = MCTS(self.game, self.nnet, self.args)
self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations
@@ -41,32 +42,37 @@ def executeEpisode(self):
uses temp=0.

Returns:
trainExamples: a list of examples of the form (canonicalBoard, currPlayer, pi,v)
trainExamples: a list of examples of the form (canonicalBoard, pi, v)
pi is the MCTS informed policy vector, v is +1 if
the player eventually won the game, else -1.
"""
trainExamples = []
board = self.game.getInitBoard()
self.curPlayer = 1
curPlayer = 1
episodeStep = 0

mcts = MCTS(self.game, self.nnet, self.args) # reset search tree TODO: do we really need to reset?

while True:
episodeStep += 1
canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
canonicalBoard = self.game.getCanonicalForm(board, curPlayer)
temp = int(episodeStep < self.args.tempThreshold)

pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
pi = mcts.getActionProb(canonicalBoard, temp=temp)
sym = self.game.getSymmetries(canonicalBoard, pi)
for b, p in sym:
trainExamples.append([b, self.curPlayer, p, None])
trainExamples.append([b, curPlayer, p, None])

action = np.random.choice(len(pi), p=pi)
board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)
board, curPlayer = self.game.getNextState(board, curPlayer, action)

r = self.game.getGameEnded(board, self.curPlayer)
r = self.game.getGameEnded(board, curPlayer)

if r != 0:
return [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer))) for x in trainExamples]
if r == 2:
# the game ended in a draw, so no reward was collected.
# shall we drop these training examples?
r = 0
return [(x[0], x[2], r * ((-1) ** (x[1] != curPlayer))) for x in trainExamples]

def learn(self):
"""
@@ -84,10 +90,15 @@ def learn(self):
if not self.skipFirstSelfPlay or i > 1:
iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

for _ in tqdm(range(self.args.numEps), desc="Self Play"):
self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree
iterationTrainExamples += self.executeEpisode()
with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor:
# Launch async simulations
futures = [executor.submit(self.executeEpisode) for _ in range(self.args.numEps)]

with tqdm(total=self.args.numEps, desc=f"Self Play with {self.args.num_workers} workers") as pbar:
for future in concurrent.futures.as_completed(futures):
iterationTrainExamples += future.result()
pbar.update(1)

# save the iteration examples to the history
self.trainExamplesHistory.append(iterationTrainExamples)

@@ -119,7 +130,7 @@ def learn(self):
pwins, nwins, draws = arena.playGames(self.args.arenaCompare)

log.info('NEW/PREV WINS : %d / %d ; DRAWS : %d' % (nwins, pwins, draws))
if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold:
if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold or (pwins + nwins) < self.args.arenaCompare * 0.15:
log.info('REJECTING NEW MODEL')
self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
else:
@@ -135,6 +146,7 @@ def saveTrainExamples(self, iteration):
if not os.path.exists(folder):
os.makedirs(folder)
filename = os.path.join(folder, self.getCheckpointFile(iteration) + ".examples")
log.info(f"saving train examples: {filename}")
with open(filename, "wb+") as f:
Pickler(f).dump(self.trainExamplesHistory)
f.closed
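
For reference, here is a minimal, self-contained sketch of the thread-pool self-play pattern introduced in `learn()` above. The episode function, episode count, and worker count are placeholders rather than names from this repo; in CPython the threads only overlap usefully while the episode work releases the GIL (e.g. during neural-network inference).

```python
# Standalone sketch of the thread-pool collection pattern (placeholder names).
from concurrent.futures import ThreadPoolExecutor, as_completed

NUM_EPISODES = 8   # stand-in for args.numEps
NUM_WORKERS = 4    # stand-in for args.num_workers

def run_episode(episode_id):
    # Stand-in for Coach.executeEpisode: returns a list of (board, pi, v) examples.
    return [(f"board-{episode_id}", [0.5, 0.5], 1)]

examples = []
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
    futures = [executor.submit(run_episode, i) for i in range(NUM_EPISODES)]
    for future in as_completed(futures):  # collect episodes as they finish
        examples.extend(future.result())

print(len(examples))  # 8
```
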
111 changes: 73 additions & 38 deletions MCTS.py
@@ -1,5 +1,7 @@
import logging
import math
import sys


import numpy as np

@@ -24,33 +26,46 @@ def __init__(self, game, nnet, args):

self.Es = {} # stores game.getGameEnded ended for board s
self.Vs = {} # stores game.getValidMoves for board s

def getActionProb(self, canonicalBoard, temp=1):
"""
This function performs numMCTSSims simulations of MCTS starting from
canonicalBoard.
Performs numMCTSSims MCTS simulations starting from canonicalBoard.

Returns:
probs: a policy vector where the probability of the ith action is
proportional to Nsa[(s,a)]**(1./temp)
"""
for i in range(self.args.numMCTSSims):

for _ in range(self.args.numMCTSSims):
self.search(canonicalBoard)

# Compute action probabilities
s = self.game.stringRepresentation(canonicalBoard)
counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]

counts = np.array(
[self.Nsa.get((s, a), 0) for a in range(self.game.getActionSize())],
dtype=np.float32
)

if 'verbose' in self.args and self.args.verbose == 1:
total_counts = counts.sum()
probs = counts.reshape(canonicalBoard.shape)
MCTS.display(probs)
MCTS.display(probs / (total_counts + EPS))
s = self.game.stringRepresentation(canonicalBoard)
probs = np.array(self.Ps[s]).reshape(canonicalBoard.shape)
MCTS.display(probs)

if temp == 0:
bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
bestA = np.random.choice(bestAs)
probs = [0] * len(counts)
bestA = np.random.choice(np.flatnonzero(counts == counts.max()))
probs = np.zeros_like(counts, dtype=np.float32)
probs[bestA] = 1
return probs

counts = [x ** (1. / temp) for x in counts]
counts_sum = float(sum(counts))
probs = [x / counts_sum for x in counts]
return probs
else:
counts = counts ** (1. / temp)
probs = counts / (counts.sum() + EPS)
return probs

def search(self, canonicalBoard):
"""
@@ -74,17 +89,24 @@ def search(self, canonicalBoard):

s = self.game.stringRepresentation(canonicalBoard)

# Check terminal state
if s not in self.Es:
self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)

if self.Es[s] != 0:
# terminal node
if self.Es[s] == 2:
# draw
return 0
return -self.Es[s]

# Expand the leaf node
if s not in self.Ps:
# leaf node

self.Ps[s], v = self.nnet.predict(canonicalBoard)

valids = self.game.getValidMoves(canonicalBoard, 1)
self.Ps[s] = self.Ps[s] * valids # masking invalid moves
self.Ps[s] *= valids # masking invalid moves
sum_Ps_s = np.sum(self.Ps[s])
if sum_Ps_s > 0:
self.Ps[s] /= sum_Ps_s # renormalize
@@ -94,43 +116,56 @@
# NB! All valid moves may be masked if either your NNet architecture is insufficient or you've got overfitting or something else.
# If you have got dozens or hundreds of these messages you should pay attention to your NNet and/or training process.
log.error("All valid moves were masked, doing a workaround.")
self.Ps[s] = self.Ps[s] + valids
self.Ps[s] /= np.sum(self.Ps[s])
self.Ps[s] = valids / valids.sum()

self.Vs[s] = valids
self.Ns[s] = 0
return -v

valids = self.Vs[s]
cur_best = -float('inf')
best_act = -1
sqrt_Ns = math.sqrt(self.Ns[s] + EPS)

# Vectorized UCB calculation
ucb_values = np.array([
self.Qsa.get((s, a), 0) +
self.args.cpuct * self.Ps[s][a] * sqrt_Ns / (1 + self.Nsa.get((s, a), 0))
if valids[a] else -float('inf')
for a in range(self.game.getActionSize())
])

# pick the action with the highest upper confidence bound
for a in range(self.game.getActionSize()):
if valids[a]:
if (s, a) in self.Qsa:
u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
1 + self.Nsa[(s, a)])
else:
u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS) # Q = 0 ?

if u > cur_best:
cur_best = u
best_act = a

a = best_act
next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
best_act = np.argmax(ucb_values)
next_s, next_player = self.game.getNextState(canonicalBoard, 1, best_act)
next_s = self.game.getCanonicalForm(next_s, next_player)

v = self.search(next_s)

if (s, a) in self.Qsa:
self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
self.Nsa[(s, a)] += 1
if (s, best_act) in self.Qsa:
self.Qsa[(s, best_act)] = (self.Nsa[(s, best_act)] * self.Qsa[(s, best_act)] + v) / (self.Nsa[(s, best_act)] + 1)
self.Nsa[(s, best_act)] += 1

else:
self.Qsa[(s, a)] = v
self.Nsa[(s, a)] = 1
self.Qsa[(s, best_act)] = v
self.Nsa[(s, best_act)] = 1

self.Ns[s] += 1
return -v


@staticmethod
def display(board):
n = board.shape[0]
print(" ", end="")
for y in range(n):
print(y, end=" ")
print("")
print("-----------------------")
for y in range(n):
print(y, "|", end="") # print the row #
for x in range(n):
piece = board[x][y] # get the piece to print
print(f"{piece:.2f}", end=" ")
print("|")

print("-----------------------")

30 changes: 30 additions & 0 deletions README.md
@@ -1,3 +1,27 @@
# Enhanced Alpha Zero General framework

This repo is forked from the [Alpha Zero General](https://github.com/suragnair/alpha-zero-general) repo with the following enhancements:

- Mac Mini GPU support
- Split the board into two planes, one for black stones and one for white, to align with the AlphaGo Zero paper (see the sketch after this list).
- Added gomoku implementation.
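
A rough sketch of the two-plane idea, assuming a single-plane board where `+1` marks black stones and `-1` marks white (the actual encoder in this fork may differ in detail):

```python
# Illustrative two-plane encoding of a single-plane board representation.
import numpy as np

board = np.array([[ 1,  0, -1],
                  [ 0,  1,  0],
                  [-1,  0,  0]])

black_plane = (board == 1).astype(np.float32)
white_plane = (board == -1).astype(np.float32)
planes = np.stack([black_plane, white_plane])   # shape (2, n, n) network input

print(planes.shape)   # (2, 3, 3)
```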


Run the unit tests with `pytest` (see the Testing section below).


TODO

- Use Numba to optimize the speed on CPU.
- Implement asynchronous MCTS.
- Fine-tune the Gomoku training algorithm. Some exploration directions:
  - Increase the number of simulation steps
  - Build a deeper CNN



The following content is from the original README file.

# Alpha Zero General (any game, any framework!)
A simplified, highly flexible, commented and (hopefully) easy to understand implementation of self-play based reinforcement learning based on the AlphaGo Zero paper (Silver et al). It is designed to be easy to adopt for any two-player turn-based adversarial game and any deep learning framework of your choice. A sample implementation has been provided for the game of Othello in PyTorch and Keras. An accompanying tutorial can be found [here](https://suragnair.github.io/posts/alphazero.html). We also have implementations for many other games like GoBang and TicTacToe.

@@ -40,6 +64,12 @@ If you found this work useful, feel free to cite it as
}
```

### Testing

```
pytest
```

### Contributing
While the current code is fairly functional, we could benefit from the following contributions:
* Game logic files for more games that follow the specifications in ```Game.py```, along with their neural networks