Add OpenAI Gym support
Wastack committed Oct 19, 2019
1 parent 654e3ae commit c0964e7
Showing 10 changed files with 162 additions and 52 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -45,6 +45,7 @@ nosetests.xml

# saved games
hivegame/saved_games/*
hivegame/model_saved/*

# AI results
hivegame/AI/temp/*
2 changes: 1 addition & 1 deletion hivegame/AI/environment.py
@@ -123,7 +123,7 @@ def getGameEnded_original(self, board, player_num):
        else:
            raise ValueError('Unexpected game status')

    def getValidMoves(self, board, player_num):
    def getValidMoves(self, board, player_num) -> List[int]:
        hive = Hive.load_state_with_player(board, self._player_to_inner_player(player_num))
        return represent.get_all_action_vector(hive)

59 changes: 59 additions & 0 deletions hivegame/AI/gym/HiveEnv.py
@@ -0,0 +1,59 @@
from gym import Env
from gym.spaces import Box

from AI.environment import Environment
from AI.gym.HiveSpace import HiveActionSpace
from AI.random_player import RandomPlayer
from hive_utils import GameStatus
from hivegame.hive import Hive
from hivegame import hive_representation as represent
import numpy as np

class HiveEnv(Env):
    def _state(self):
        return represent.string_representation(represent.two_dim_representation(represent.get_adjacency_state(self.env.hive)))

    def __init__(self):
        self.reward_range = (-1., 1.)
        self.env = Environment()
        self.action_space = HiveActionSpace(self.env.hive)
        self.observation_space = Box(low=0, high=9, shape=(12, 11), dtype=np.int32)

        # opponent
        self.opponent = RandomPlayer()

    def reset(self):
        self.env = Environment()
        self.action_space = HiveActionSpace(self.env.hive)
        return self._state()

    def _reward(self) -> (float, bool):
        reward = 0.
        if self.env.hive.check_victory() == GameStatus.WHITE_WIN:
            reward += 1.
        elif self.env.hive.check_victory() == GameStatus.BLACK_WIN:
            reward -= 1.
        done = self.env.hive.check_victory() != GameStatus.UNFINISHED
        return reward, done

    def step(self, action: int):
        inner_action = self.env.hive.action_from_vector(action)
        self.env.hive.action_piece_to(*inner_action)
        (reward, done) = self._reward()
        if not done:
            # opponent's turn
            # Let him play until I have available moves (pass)
            passed = True
            while passed:
                passed = False
                opponent_action = self.opponent.step(self.env)
                if opponent_action == 'pass':
                    self.env.hive.current_player = self.env.hive._toggle_player(self.env.current_player)
                    return self._state(), reward, done, {}
                self.env.hive.action_piece_to(*opponent_action)
                (reward, done) = self._reward()
                if not self.env.get_all_possible_actions():
                    self.env.hive.current_player = self.env.hive._toggle_player(self.env.current_player)
                    passed = True
            return self._state(), reward, done, {}
        return self._state(), reward, done, {}
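
HiveEnv follows the classic Gym interface: reset() returns the adjacency-state observation, and step() applies the agent's move, lets the RandomPlayer opponent reply (passing on the agent's behalf while it has no legal moves), and returns (observation, reward, done, info). The env is instantiated directly in the new playground.py rather than registered with Gym; a minimal registration sketch follows. The 'Hive-v0' id and the entry-point module path are illustrative assumptions, not part of this commit.

import gym
from gym.envs.registration import register

# Hypothetical registration; assumes hivegame/AI/gym/HiveEnv.py is importable as AI.gym.HiveEnv.
register(id='Hive-v0', entry_point='AI.gym.HiveEnv:HiveEnv')

env = gym.make('Hive-v0')
observation = env.reset()
observation, reward, done, info = env.step(env.action_space.sample())
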
32 changes: 32 additions & 0 deletions hivegame/AI/gym/HiveSpace.py
@@ -0,0 +1,32 @@
import random

from gym.spaces import Discrete, Box

from AI.environment import Environment
from hive import Hive
from hivegame import hive_representation as represent


class HiveActionSpace(Discrete):
    def _val_indices(self):
        # TODO currently only white player supported (which is indicated by 1)
        val_moves = represent.get_all_action_vector(self.env.hive)
        return [i for i, v in enumerate(val_moves) if v > 0]

    def __init__(self, hive: Hive):
        self.env = Environment()
        self.env.hive = hive
        super().__init__(self.env.getActionSize())

    def sample(self):
        val_indices = self._val_indices()
        if not val_indices:
            raise RuntimeError("Player is not able to move")
        return random.choice(val_indices)

    def contains(self, x):
        if not super().contains(x):
            return False
        if x not in self._val_indices():
            return False
        return True
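
HiveActionSpace subclasses Gym's Discrete space but masks it with the legal-move vector from hive_representation: sample() only returns indices of currently legal actions (and raises if there are none), while contains() rejects indices that are in range but not playable. A minimal usage sketch, assuming (as HiveEnv.__init__ does) that Environment() sets up a fresh game:

from AI.environment import Environment
from AI.gym.HiveSpace import HiveActionSpace

env = Environment()
space = HiveActionSpace(env.hive)

action = space.sample()        # index of a currently legal move
assert space.contains(action)  # legality check, not just an index-range check
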
44 changes: 44 additions & 0 deletions hivegame/elaborate_AI.py
@@ -0,0 +1,44 @@
import sys, os

import logging

from AI.alpha_player import AlphaPlayer
from AI.environment import Environment
from AI.random_player import RandomPlayer
from AI.utils.keras.NNet import NNetWrapper
from arena import Arena
from hive_utils import dotdict
from project import ROOT_DIR

args = dotdict({
    'numIters': 28,
    'numEps': 7,
    'tempThreshold': 15,
    'updateThreshold': 0.5,
    'maxlenOfQueue': 200000,
    'numMCTSSims': 2,
    'arenaCompare': 40,
    'cpuct': 0.3,

    'checkpoint': './temp/',
    'load_model': False,
    'load_folder_file': ('/dev/models/8x100x50','best.pth.tar'),
    'numItersForTrainExamplesHistory': 20,

})

def main():
    FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
    logging.basicConfig(level=logging.DEBUG, format=FORMAT)
    env = Environment()
    nnet = NNetWrapper(env)
    nnet.load_model(folder=os.path.join(ROOT_DIR, 'model_saved'), filename='model.h5')
    alphaPlayer = AlphaPlayer(env, nnet, args)
    randomPlayer = RandomPlayer()
    arena = Arena(alphaPlayer, randomPlayer, env)
    alpha_wins, random_wins, draws = arena.playGames(5000)
    print("alpha won: {} times, random won: {} times, number of draws: {}".format(alpha_wins, random_wins, draws))


if __name__ == '__main__':
    sys.exit(main())
1 change: 0 additions & 1 deletion hivegame/hive_representation.py
@@ -302,7 +302,6 @@ def get_all_possible_actions(hive: 'Hive') -> Set[Tuple[HivePiece, hexutil.Hex]]
    the target location of the action.
    """
    result = set()
    print(hive)

    # choose the current player's played pieces
    my_pieces = hive.level.get_played_pieces(hive.current_player)
2 changes: 2 additions & 0 deletions hivegame/pieces/ant_piece.py
@@ -70,6 +70,8 @@ def available_moves_vector(self, hive: 'Hive', pos: hexutil.Hex):
"""
It assumes that the ant can step onto a maximum of pre-specifihive.locate('wA1')ed number of cells
"""
if self.check_blocked(hive, pos):
return [0] * self.move_vector_size
available_moves_count = len(self.available_moves(hive, pos))
assert available_moves_count < AntPiece.MAX_STEP_COUNT
result = [1] * available_moves_count + [0] * (AntPiece.MAX_STEP_COUNT - available_moves_count)
10 changes: 4 additions & 6 deletions hivegame/pieces/beetle_piece.py
@@ -34,7 +34,7 @@ def available_moves(self, hive: 'Hive', pos: hexutil.Hex):

    def available_moves_vector(self, hive: 'Hive', pos: hexutil.Hex):
        if self.check_blocked(hive, pos):
            return [0] * 6
            return [0] * self.move_vector_size

        result = []
        aval_moves = self.available_moves(hive, pos)
@@ -60,8 +60,6 @@ def __repr__(self):
return "%s%s%s" % (self.color, "B", self.number)

def index_to_target_cell(self, hive: 'Hive', number: int, pos: hexutil.Hex):
aval_moves = self.available_moves(hive, pos)
# index of available moves, starting from 0
num_in_list = sum(self.available_moves_vector(hive, pos)[:number]) - 1
assert len(aval_moves) > num_in_list
return aval_moves[num_in_list]
aval_indexes = (i for i, v in enumerate(self.available_moves_vector(hive, pos)) if v > 0)
assert number in aval_indexes
return pos.neighbours()[number]
3 changes: 3 additions & 0 deletions hivegame/pieces/piece.py
@@ -52,6 +52,9 @@ def available_moves_vector(self, hive: 'Hive', pos: hexutil.Hex):
    def index_to_target_cell(self, hive: 'Hive', number: int, pos: hexutil.Hex) -> 'hexutil.Hex':
        aval_moves = self.available_moves(hive, pos)
        if len(aval_moves) <= number or number >= self.move_vector_size:
            print(self)
            print("check_blocked: {}".format(self.check_blocked(hive, pos)))
            print("len aval moves: {}, number: {}, move_vector_size: {}".format(len(aval_moves), number, self.move_vector_size))
            raise HiveException("moving piece with action number is out of bounds", 10001)
        return aval_moves[number]

60 changes: 16 additions & 44 deletions hivegame/playground.py
@@ -1,44 +1,16 @@
import sys, os

import logging

from AI.alpha_player import AlphaPlayer
from AI.environment import Environment
from AI.random_player import RandomPlayer
from AI.utils.keras.NNet import NNetWrapper
from arena import Arena
from hive_utils import dotdict
from project import ROOT_DIR

args = dotdict({
    'numIters': 28,
    'numEps': 7,
    'tempThreshold': 15,
    'updateThreshold': 0.5,
    'maxlenOfQueue': 200000,
    'numMCTSSims': 2,
    'arenaCompare': 40,
    'cpuct': 0.3,

    'checkpoint': './temp/',
    'load_model': False,
    'load_folder_file': ('/dev/models/8x100x50','best.pth.tar'),
    'numItersForTrainExamplesHistory': 20,

})

def main():
    FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
    logging.basicConfig(level=logging.DEBUG, format=FORMAT)
    env = Environment()
    nnet = NNetWrapper(env)
    nnet.load_model(folder=os.path.join(ROOT_DIR, 'model_saved'), filename='model.h5')
    alphaPlayer = AlphaPlayer(env, nnet, args)
    randomPlayer = RandomPlayer()
    arena = Arena(alphaPlayer, randomPlayer, env)
    alpha_wins, random_wins, draws = arena.playGames(5000)
    print("alpha won: {} times, random won: {} times, number of draws: {}".format(alpha_wins, random_wins, draws))


if __name__ == '__main__':
    sys.exit(main())
import gym

from AI.gym.HiveEnv import HiveEnv

env = HiveEnv()
for i_episode in range(20):
    observation = env.reset()
    for t in range(100):
        #env.render()
        #print(observation)
        action = env.action_space.sample()
        observation, reward, done, info = env.step(action)
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            break
env.close()
