Commit
Showing 10 changed files with 162 additions and 52 deletions.
@@ -45,6 +45,7 @@ nosetests.xml

# saved games
hivegame/saved_games/*
hivegame/model_saved/*

# AI results
hivegame/AI/temp/*
@@ -0,0 +1,59 @@
from gym import Env
from gym.spaces import Box

from AI.environment import Environment
from AI.gym.HiveSpace import HiveActionSpace
from AI.random_player import RandomPlayer
from hive_utils import GameStatus
from hivegame.hive import Hive
from hivegame import hive_representation as represent
import numpy as np


class HiveEnv(Env):
    def _state(self):
        return represent.string_representation(represent.two_dim_representation(represent.get_adjacency_state(self.env.hive)))

    def __init__(self):
        self.reward_range = (-1., 1.)
        self.env = Environment()
        self.action_space = HiveActionSpace(self.env.hive)
        self.observation_space = Box(low=0, high=9, shape=(12, 11), dtype=np.int32)

        # opponent that answers the agent's moves inside step()
        self.opponent = RandomPlayer()

    def reset(self):
        self.env = Environment()
        self.action_space = HiveActionSpace(self.env.hive)
        return self._state()

    def _reward(self) -> (float, bool):
        reward = 0.
        if self.env.hive.check_victory() == GameStatus.WHITE_WIN:
            reward += 1.
        elif self.env.hive.check_victory() == GameStatus.BLACK_WIN:
            reward -= 1.
        done = self.env.hive.check_victory() != GameStatus.UNFINISHED
        return reward, done

    def step(self, action: int):
        inner_action = self.env.hive.action_from_vector(action)
        self.env.hive.action_piece_to(*inner_action)
        (reward, done) = self._reward()
        if not done:
            # Opponent's turn: let the opponent keep playing until the agent
            # has available moves again (i.e. while the agent would have to pass).
            passed = True
            while passed:
                passed = False
                opponent_action = self.opponent.step(self.env)
                if opponent_action == 'pass':
                    self.env.hive.current_player = self.env.hive._toggle_player(self.env.current_player)
                    return self._state(), reward, done, {}
                self.env.hive.action_piece_to(*opponent_action)
                (reward, done) = self._reward()
                if not self.env.get_all_possible_actions():
                    self.env.hive.current_player = self.env.hive._toggle_player(self.env.current_player)
                    passed = True
            return self._state(), reward, done, {}
        return self._state(), reward, done, {}
@@ -0,0 +1,32 @@
import random

from gym.spaces import Discrete, Box

from AI.environment import Environment
from hive import Hive
from hivegame import hive_representation as represent


class HiveActionSpace(Discrete):
    def _val_indices(self):
        # TODO: currently only the white player is supported (indicated by 1)
        val_moves = represent.get_all_action_vector(self.env.hive)
        return [i for i, v in enumerate(val_moves) if v > 0]

    def __init__(self, hive: Hive):
        self.env = Environment()
        self.env.hive = hive
        super().__init__(self.env.getActionSize())

    def sample(self):
        val_indices = self._val_indices()
        if not val_indices:
            raise RuntimeError("Player is not able to move")
        return random.choice(val_indices)

    def contains(self, x):
        if not super().contains(x):
            return False
        if x not in self._val_indices():
            return False
        return True
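For reference, here is a minimal usage sketch of this masked action space on its own, mirroring what HiveEnv.__init__ and reset() do above. It assumes Environment() sets up a fresh game in which the player to move has at least one legal action; it is an illustration, not part of this commit:

from AI.environment import Environment
from AI.gym.HiveSpace import HiveActionSpace

env = Environment()                    # fresh game state, as in HiveEnv.__init__
space = HiveActionSpace(env.hive)      # Discrete space restricted to currently legal moves
action = space.sample()                # raises RuntimeError if no legal move is available
assert space.contains(action)          # sample() only returns indices of legal moves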
@@ -0,0 +1,44 @@
import sys, os

import logging

from AI.alpha_player import AlphaPlayer
from AI.environment import Environment
from AI.random_player import RandomPlayer
from AI.utils.keras.NNet import NNetWrapper
from arena import Arena
from hive_utils import dotdict
from project import ROOT_DIR

args = dotdict({
    'numIters': 28,
    'numEps': 7,
    'tempThreshold': 15,
    'updateThreshold': 0.5,
    'maxlenOfQueue': 200000,
    'numMCTSSims': 2,
    'arenaCompare': 40,
    'cpuct': 0.3,

    'checkpoint': './temp/',
    'load_model': False,
    'load_folder_file': ('/dev/models/8x100x50', 'best.pth.tar'),
    'numItersForTrainExamplesHistory': 20,
})


def main():
    FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
    logging.basicConfig(level=logging.DEBUG, format=FORMAT)
    env = Environment()
    nnet = NNetWrapper(env)
    nnet.load_model(folder=os.path.join(ROOT_DIR, 'model_saved'), filename='model.h5')
    alphaPlayer = AlphaPlayer(env, nnet, args)
    randomPlayer = RandomPlayer()
    arena = Arena(alphaPlayer, randomPlayer, env)
    alpha_wins, random_wins, draws = arena.playGames(5000)
    print("alpha won: {} times, random won: {} times, number of draws: {}".format(alpha_wins, random_wins, draws))


if __name__ == '__main__':
    sys.exit(main())
@@ -1,44 +1,16 @@
-import sys, os
-
-import logging
-
-from AI.alpha_player import AlphaPlayer
-from AI.environment import Environment
-from AI.random_player import RandomPlayer
-from AI.utils.keras.NNet import NNetWrapper
-from arena import Arena
-from hive_utils import dotdict
-from project import ROOT_DIR
-
-args = dotdict({
-    'numIters': 28,
-    'numEps': 7,
-    'tempThreshold': 15,
-    'updateThreshold': 0.5,
-    'maxlenOfQueue': 200000,
-    'numMCTSSims': 2,
-    'arenaCompare': 40,
-    'cpuct': 0.3,
-
-    'checkpoint': './temp/',
-    'load_model': False,
-    'load_folder_file': ('/dev/models/8x100x50','best.pth.tar'),
-    'numItersForTrainExamplesHistory': 20,
-
-})
-
-def main():
-    FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
-    logging.basicConfig(level=logging.DEBUG, format=FORMAT)
-    env = Environment()
-    nnet = NNetWrapper(env)
-    nnet.load_model(folder=os.path.join(ROOT_DIR, 'model_saved'), filename='model.h5')
-    alphaPlayer = AlphaPlayer(env, nnet, args)
-    randomPlayer = RandomPlayer()
-    arena = Arena(alphaPlayer, randomPlayer, env)
-    alpha_wins, random_wins, draws = arena.playGames(5000)
-    print("aplha won: {} times, random won: {} times, number of draws: {}".format(alpha_wins, random_wins, draws))
-
-
-if __name__ == '__main__':
-    sys.exit(main())
+import gym
+
+from AI.gym.HiveEnv import HiveEnv
+
+env = HiveEnv()
+for i_episode in range(20):
+    observation = env.reset()
+    for t in range(100):
+        # env.render()
+        # print(observation)
+        action = env.action_space.sample()
+        observation, reward, done, info = env.step(action)
+        if done:
+            print("Episode finished after {} timesteps".format(t + 1))
+            break
+env.close()
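A possible follow-up, sketched here only as an assumption and not included in this commit, is registering the environment with gym's registry so the test script could also create it through gym.make. The id 'Hive-v0' is a hypothetical name; the entry point string is taken from the import path used above:

from gym.envs.registration import register

register(
    id='Hive-v0',                           # hypothetical id, not defined anywhere in this commit
    entry_point='AI.gym.HiveEnv:HiveEnv',   # module path matching the import above
)

# afterwards: env = gym.make('Hive-v0')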