diff --git a/tests/connect4/test_connect4_game.py b/tests/connect4/test_connect4_game.py index 67b3e25..7c6f6d7 100644 --- a/tests/connect4/test_connect4_game.py +++ b/tests/connect4/test_connect4_game.py @@ -5,15 +5,10 @@ def test_create_board(): - expected_spaces = np.array([[0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0]]) + expected_spaces = np.zeros((2, 6, 7)) board = Connect4State() - assert np.array_equal(board.get_spaces(), expected_spaces) + assert np.array_equal(board.spaces, expected_spaces) # noinspection DuplicatedCode @@ -26,15 +21,22 @@ def test_create_board_from_text(): ....X.. ...XO.. """ - expected_spaces = np.array([[0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, -1, 0, 0]]) + expected_spaces = np.array([[[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0]] + ]) board = Connect4State(text) - assert np.array_equal(board.get_spaces(), expected_spaces) + assert np.array_equal(board.spaces, expected_spaces) # noinspection DuplicatedCode @@ -48,25 +50,26 @@ def test_create_board_with_coordinates(): ....X.. ...XO.. """ - expected_board = np.array([[0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, -1, 0, 0]]) + expected_spaces = np.array([[[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0]] + ]) board = Connect4State(text) - assert np.array_equal(board.get_spaces(), expected_board) + assert np.array_equal(board.spaces, expected_spaces) # noinspection DuplicatedCode def test_display(): - spaces = np.array([[0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, -1, 0, 0]]) expected_text = """\ ....... ....... @@ -75,19 +78,13 @@ def test_display(): ....X.. ...XO.. """ - text = Connect4State(spaces=spaces).display() + text = Connect4State(expected_text).display() assert text == expected_text # noinspection DuplicatedCode def test_display_coordinates(): - spaces = np.array([[0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, -1, 0, 0]]) expected_text = """\ 1234567 ....... @@ -97,7 +94,7 @@ def test_display_coordinates(): ....X.. ...XO.. """ - text = Connect4State(spaces=spaces).display(show_coordinates=True) + text = Connect4State(expected_text).display(show_coordinates=True) assert expected_text == text @@ -306,6 +303,7 @@ def test_longer_winner(): def test_vertical_winner(): text = """\ ....... +....... .....O. .....O. ....XO. @@ -321,6 +319,7 @@ def test_vertical_winner(): def test_diagonal1_winner(): text = """\ ....... +....... ..O.... ..XO... ..OXO.. @@ -335,6 +334,7 @@ def test_diagonal1_winner(): def test_diagonal2_winner(): text = """\ +....... ......X .....XO ..XOXOX diff --git a/tests/connect4/training_data.json b/tests/connect4/training_data.json index 37b42d9..1f9abab 100644 --- a/tests/connect4/training_data.json +++ b/tests/connect4/training_data.json @@ -1,50 +1,75 @@ -[[[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, -1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, -1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [-1, -1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [-1, -1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [-1, -1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, -1, 0]], - [[0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [-1, -1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, -1, 0]], - [[0, -1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [-1, -1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], [1, -1, 0, 0, 0, -1, 0]], - [[0, -1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [-1, -1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 1, 0], [1, -1, 0, 0, 0, -1, 0]], - [[0, -1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [-1, 1, 0, 0, 0, 0, 0], [-1, -1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 1, 0], [1, -1, 0, 0, 0, -1, 0]], - [[0, -1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [-1, 1, 0, 0, 0, 0, 0], [-1, -1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 1, 0], [1, -1, 0, 0, 0, -1, 0]], - [[0, -1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [-1, 1, 0, 0, 0, -1, 0], [-1, -1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 1, 0], [1, -1, 0, 0, 0, -1, 0]], - [[0, -1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 1, 0], [-1, 1, 0, 0, 0, -1, 0], [-1, -1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 1, 0], [1, -1, 0, 0, 0, -1, 0]], - [[0, -1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 1, 0], [-1, 1, 0, 0, 0, -1, 0], [-1, -1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 1, 0], [1, -1, 0, 0, 0, -1, -1]], - [[0, -1, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1, 0], [-1, 1, 0, 0, 0, -1, 0], [-1, -1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 1, 0], [1, -1, 0, 0, 0, -1, -1]], - [[0, -1, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1, 0], [-1, 1, 0, 0, 0, -1, 0], [-1, -1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 1, 0], [1, -1, -1, 0, 0, -1, -1]], - [[0, -1, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1, 0], [-1, 1, 0, 0, 0, -1, 0], [-1, -1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 1, 1], [1, -1, -1, 0, 0, -1, -1]], - [[0, -1, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1, 0], [-1, 1, 0, 0, 0, -1, 0], [-1, -1, 0, 0, 0, 1, 0], [1, 1, -1, 0, 0, 1, 1], [1, -1, -1, 0, 0, -1, -1]], - [[0, -1, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1, 0], [-1, 1, 0, 0, 0, -1, 0], [-1, -1, 1, 0, 0, 1, 0], [1, 1, -1, 0, 0, 1, 1], [1, -1, -1, 0, 0, -1, -1]], - [[0, -1, 0, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1, 0], [-1, 1, -1, 0, 0, -1, 0], [-1, -1, 1, 0, 0, 1, 0], [1, 1, -1, 0, 0, 1, 1], [1, -1, -1, 0, 0, -1, -1]], - [[0, -1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 1, 0], [-1, 1, -1, 0, 0, -1, 0], [-1, -1, 1, 0, 0, 1, 0], [1, 1, -1, 0, 0, 1, 1], [1, -1, -1, 0, 0, -1, -1]], - [[0, -1, 0, 0, 0, 1, 0], [1, 1, -1, 0, 0, 1, 0], [-1, 1, -1, 0, 0, -1, 0], [-1, -1, 1, 0, 0, 1, 0], [1, 1, -1, 0, 0, 1, 1], [1, -1, -1, 0, 0, -1, -1]]], - [[0.925, 0.0125, 0.0125, 0.0125, 0.0125, 0.0125, 0.0125, -1.0], - [0.006535947712418301, 0.803921568627451, 0.006535947712418301, 0.0915032679738562, 0.0, 0.0915032679738562, 0.0, 1.0], - [0.0049504950495049506, 0.9900990099009901, 0.0049504950495049506, 0.0, 0.0, 0.0, 0.0, -1.0], - [0.24372759856630824, 0.7526881720430108, 0.0, 0.0, 0.0, 0.0035842293906810036, 0.0, 1.0], - [0.31141868512110726, 0.25259515570934254, 0.0, 0.0, 0.0, 0.4359861591695502, 0.0, -1.0], - [0.4260355029585799, 0.30177514792899407, 0.0, 0.0, 0.0, 0.27218934911242604, 0.0, 1.0], - [0.10596026490066225, 0.609271523178808, 0.0, 0.0, 0.0, 0.2847682119205298, 0.0, -1.0], - [0.005847953216374269, 0.6432748538011696, 0.0, 0.0, 0.0, 0.3508771929824561, 0.0, 1.0], - [0.0, 0.6474820143884892, 0.0, 0.0, 0.0, 0.35251798561151076, 0.0, -1.0], - [0.0, 0.5325443786982249, 0.0, 0.0, 0.0, 0.46745562130177515, 0.0, 1.0], - [0.23076923076923078, 0.0, 0.0, 0.0, 0.0, 0.7692307692307693, 0.0, -1.0], - [0.1339712918660287, 0.0, 0.0, 0.0, 0.0, 0.8660287081339713, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0], - [0.8602150537634409, 0.0, 0.0, 0.0, 0.0, 0.13978494623655913, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.13043478260869565, 0.8695652173913043, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.5857740585774058, 0.41422594142259417, -1.0], - [0.0, 0.0, 0.8812785388127854, 0.0, 0.0, 0.0, 0.1187214611872146, 1.0], - [0.0, 0.0, 0.025735294117647058, 0.0, 0.0, 0.0, 0.9742647058823529, -1.0], - [0.0, 0.0, 0.997093023255814, 0.0, 0.0, 0.0, 0.0029069767441860465, 1.0], - [0.0, 0.0, 0.9597156398104265, 0.0, 0.0, 0.0, 0.04028436018957346, -1.0], - [0.02066115702479339, 0.0, 0.9793388429752066, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.9439421338155516, 0.0, 0.0325497287522604, 0.0, 0.0, 0.0, 0.023508137432188065, -1.0], - [0.39933444259567386, 0.0, 0.5990016638935108, 0.0, 0.0, 0.0, 0.0016638935108153079, 1.0], - [0.022779043280182234, 0.0, 0.9339407744874715, 0.0, 0.0, 0.0, 0.04328018223234624, -1.0]]] \ No newline at end of file +[[[[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 0, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 0, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]], + [[[0, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 0, 1, 0, 1]], + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0], [1, 0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 1, 0]]]], + [[0.004, 0.008, 0.674, 0.004, 0.024, 0.232, 0.054, 1.000], + [0.007, 0.030, 0.778, 0.005, 0.043, 0.122, 0.016, -1.000], + [0.507, 0.010, 0.013, 0.232, 0.218, 0.003, 0.018, 1.000], + [0.015, 0.029, 0.059, 0.014, 0.017, 0.848, 0.019, -1.000], + [0.051, 0.121, 0.513, 0.297, 0.006, 0.005, 0.006, 1.000], + [0.010, 0.021, 0.008, 0.946, 0.007, 0.005, 0.003, -1.000], + [0.007, 0.008, 0.022, 0.014, 0.905, 0.034, 0.010, 1.000], + [0.701, 0.006, 0.142, 0.103, 0.027, 0.004, 0.017, -1.000], + [0.003, 0.086, 0.477, 0.307, 0.014, 0.103, 0.009, 1.000], + [0.033, 0.028, 0.018, 0.726, 0.018, 0.006, 0.173, -1.000], + [0.021, 0.012, 0.088, 0.854, 0.012, 0.005, 0.008, 1.000], + [0.034, 0.008, 0.742, 0.032, 0.019, 0.046, 0.119, -1.000], + [0.005, 0.683, 0.007, 0.040, 0.120, 0.050, 0.094, 1.000], + [0.345, 0.041, 0.092, 0.427, 0.080, 0.009, 0.006, -1.000], + [0.874, 0.011, 0.006, 0.003, 0.032, 0.072, 0.003, 1.000], + [0.004, 0.169, 0.113, 0.004, 0.703, 0.004, 0.004, -1.000], + [0.097, 0.104, 0.007, 0.015, 0.007, 0.015, 0.755, 1.000], + [0.934, 0.007, 0.010, 0.010, 0.008, 0.020, 0.010, -1.000], + [0.944, 0.009, 0.009, 0.013, 0.009, 0.009, 0.009, 1.000], + [0.022, 0.012, 0.089, 0.010, 0.505, 0.352, 0.010, -1.000], + [0.030, 0.015, 0.007, 0.121, 0.812, 0.007, 0.007, 1.000], + [0.009, 0.008, 0.009, 0.028, 0.923, 0.012, 0.011, -1.000], + [0.029, 0.022, 0.043, 0.064, 0.681, 0.085, 0.077, 1.000], + [0.006, 0.003, 0.006, 0.953, 0.000, 0.026, 0.007, -1.000], + [0.052, 0.041, 0.041, 0.052, 0.000, 0.750, 0.064, 1.000]]] diff --git a/tests/othello/test_othello_game.py b/tests/othello/test_othello_game.py index 1840fdf..889d1eb 100644 --- a/tests/othello/test_othello_game.py +++ b/tests/othello/test_othello_game.py @@ -1,3 +1,5 @@ +from textwrap import dedent + import numpy as np import pytest @@ -8,23 +10,34 @@ def test_create_board(): - x, o = OthelloState.X_PLAYER, OthelloState.O_PLAYER - # 6x6 grid of spaces, plus next player. - expected_board = [0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, o, x, 0, 0, - 0, 0, x, o, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - x] + expected_spaces = np.array([[[0, 0, 0, 0, 0, 0], # X pieces + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0], # O pieces + [0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]]) board = OthelloState() - assert board.board.tolist() == expected_board + assert np.array_equal(board.get_spaces(), expected_spaces) + assert board.get_active_player() == board.X_PLAYER + + +def test_repr(): + expected_repr = (r"OthelloState(" + r"'......\n......\n..OX..\n..XO..\n......\n......\n>X\n')") + board = OthelloState() + + assert repr(board) == expected_repr # noinspection DuplicatedCode def test_create_board_from_text(): - x, o = OthelloState.X_PLAYER, OthelloState.O_PLAYER text = """\ ...... ...... @@ -34,21 +47,26 @@ def test_create_board_from_text(): ...XO. >O """ - expected_board = [0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, x, 0, - 0, 0, 0, x, o, 0, - o] + expected_spaces = np.array([[[0, 0, 0, 0, 0, 0], # X pieces + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 0]], + [[0, 0, 0, 0, 0, 0], # O pieces + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0]]]) board = OthelloState(text) - assert board.board.tolist() == expected_board + assert np.array_equal(board.spaces, expected_spaces) + assert board.get_active_player() == board.O_PLAYER # noinspection DuplicatedCode def test_create_board_with_coordinates(): - x, o = OthelloState.X_PLAYER, OthelloState.O_PLAYER text = """\ ABCDEF 1 ...... @@ -59,28 +77,26 @@ def test_create_board_with_coordinates(): 6 ...XO. >X """ - expected_board = [0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, x, 0, - 0, 0, 0, x, o, 0, - x] + expected_spaces = np.array([[[0, 0, 0, 0, 0, 0], # X pieces + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 0]], + [[0, 0, 0, 0, 0, 0], # O pieces + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0]]]) board = OthelloState(text) - assert board.board.tolist() == expected_board + assert np.array_equal(board.spaces, expected_spaces) + assert board.get_active_player() == board.X_PLAYER # noinspection DuplicatedCode def test_display(): - x, o = OthelloState.X_PLAYER, OthelloState.O_PLAYER - board = np.array([0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, x, 0, - 0, 0, 0, x, o, 0, - x]) expected_text = """\ ...... ...... @@ -90,21 +106,13 @@ def test_display(): ...XO. >X """ - text = OthelloState(spaces=board).display() + text = OthelloState(expected_text).display() assert text == expected_text # noinspection DuplicatedCode def test_display_coordinates(): - x, o = OthelloState.X_PLAYER, OthelloState.O_PLAYER - board = np.array([0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, x, 0, - 0, 0, 0, x, o, 0, - x]) expected_text = """\ ABCDEF 1 ...... @@ -115,7 +123,7 @@ def test_display_coordinates(): 6 ...XO. >X """ - text = OthelloState(spaces=board).display(show_coordinates=True) + text = OthelloState(expected_text).display(show_coordinates=True) assert text == expected_text @@ -457,12 +465,6 @@ def test_no_moves_for_either(): assert board.get_winner() == board.O_PLAYER -def test_create_from_array(): - board = OthelloState(spaces=np.zeros(65, dtype=np.int8)) - - assert board.board_width == 8 - - def test_training_data(): state = OthelloState() neural_net = NeuralNet(state) @@ -472,3 +474,26 @@ def test_training_data(): iterations=10, data_size=10) neural_net.train(boards, outputs) + + +def test_equality(): + state1 = OthelloState(dedent("""\ + ...... + ..O... + ..OO.. + ..OX.. + ...... + ...... + >X + """)) + state2 = OthelloState(dedent("""\ + ...... + ..O... + ..OO.. + ..OX.. + ...... + ...... + >O + """)) + + assert not state1 == state2 diff --git a/tests/test_mcts_player.py b/tests/test_mcts_player.py index 9bc0793..28ec552 100644 --- a/tests/test_mcts_player.py +++ b/tests/test_mcts_player.py @@ -59,7 +59,7 @@ def test_repr(): ... """ board = TicTacToeState(board_text) - expected_repr = "SearchNode(TicTacToeState(spaces=array([[0, -1, 0], [0, 1, 0], [0, 0, 0]])))" + expected_repr = r"SearchNode(TicTacToeState('.O.\n.X.\n...\n'))" node = SearchNode(board) node_repr = repr(node) @@ -358,38 +358,38 @@ def test_create_training_data(): start_state = TicTacToeState() manager = SearchManager(start_state, FirstChoiceHeuristic()) expected_boards, expected_outputs = zip(*[ - [start_state.get_spaces(), + [start_state.spaces, np.array([1., 0., 0., 0., 0., 0., 0., 0., 0., -1.])], [TicTacToeState("""\ X.. ... ... -""").get_spaces(), np.array([0., 1., 0., 0., 0., 0., 0., 0., 0., 1.])], +""").spaces, np.array([0., 1., 0., 0., 0., 0., 0., 0., 0., 1.])], [TicTacToeState("""\ XO. ... ... -""").get_spaces(), np.array([0., 0., 1., 0., 0., 0., 0., 0., 0., -1.])], +""").spaces, np.array([0., 0., 1., 0., 0., 0., 0., 0., 0., -1.])], [TicTacToeState("""\ XOX ... ... -""").get_spaces(), np.array([0., 0., 0., 1., 0., 0., 0., 0., 0., 1.])], +""").spaces, np.array([0., 0., 0., 1., 0., 0., 0., 0., 0., 1.])], [TicTacToeState("""\ XOX O.. ... -""").get_spaces(), np.array([0., 0., 0., 0., 1., 0., 0., 0., 0., -1.])], +""").spaces, np.array([0., 0., 0., 0., 1., 0., 0., 0., 0., -1.])], [TicTacToeState("""\ XOX OX. ... -""").get_spaces(), np.array([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.])], +""").spaces, np.array([0., 0., 0., 0., 0., 1., 0., 0., 0., 1.])], [TicTacToeState("""\ XOX OXO ... -""").get_spaces(), np.array([0., 0., 0., 0., 0., 0., 1., 0., 0., -1.])]]) +""").spaces, np.array([0., 0., 0., 0., 0., 0., 1., 0., 0., -1.])]]) expected_boards = np.stack(expected_boards) expected_outputs = np.stack(expected_outputs) @@ -399,6 +399,15 @@ def test_create_training_data(): assert repr(outputs) == repr(expected_outputs) +def test_create_more_training_data(): + start_state = TicTacToeState() + manager = SearchManager(start_state, FirstChoiceHeuristic()) + + boards, outputs = manager.create_training_data(iterations=1, data_size=11) + + assert boards.shape == (11, 2, 3, 3) + + def test_win_scores_one(): """ Expose bug where search continues after a game-ending position. """ state1 = TicTacToeState("""\ diff --git a/tests/test_play_controller.py b/tests/test_play_controller.py index cca5baf..e1c78d8 100644 --- a/tests/test_play_controller.py +++ b/tests/test_play_controller.py @@ -16,13 +16,17 @@ def __init__(self, super().__init__(board_height, board_width, spaces=spaces) def is_win(self, player: int) -> bool: - if (self.board == 0).sum() != 0: + spaces = self.spaces + empty_spaces = spaces.sum(axis=0) == 0 + if empty_spaces.sum() != 0: return False - x_count = (self.board == self.X_PLAYER).sum() - o_count = (self.board == self.O_PLAYER).sum() + piece_type = self.piece_types.index(player) + opponent_type = 1 - piece_type + player_count = spaces[piece_type].sum() + opponent_count = spaces[opponent_type].sum() if player == self.X_PLAYER: - return o_count < x_count - return x_count <= o_count + return opponent_count < player_count + return opponent_count <= player_count class SecondPlayerWinsGame(FirstPlayerWinsGame): diff --git a/tests/test_playout.py b/tests/test_playout.py index b8815e1..24aeff0 100644 --- a/tests/test_playout.py +++ b/tests/test_playout.py @@ -41,7 +41,7 @@ def display_move(self, move: int) -> str: def get_move_count(self) -> int: return self.move_count - def get_spaces(self) -> np.ndarray: + def spaces(self) -> np.ndarray: return np.ndarray([self.value, self.move_count+1]) def parse_move(self, text: str) -> int: diff --git a/tests/tictactoe/test_tictactoe_game.py b/tests/tictactoe/test_tictactoe_game.py index 2d1d825..9d48880 100644 --- a/tests/tictactoe/test_tictactoe_game.py +++ b/tests/tictactoe/test_tictactoe_game.py @@ -1,3 +1,5 @@ +from textwrap import dedent + import numpy as np import pytest @@ -5,28 +7,47 @@ def test_create_board(): - expected_spaces = np.array([[0, 0, 0], - [0, 0, 0], - [0, 0, 0]]) + expected_spaces = np.array([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]) board = TicTacToeState() assert np.array_equal(board.get_spaces(), expected_spaces) +# noinspection DuplicatedCode def test_create_board_from_text(): text = """\ X.. .O. ... """ - expected_spaces = np.array([[1, 0, 0], - [0, -1, 0], - [0, 0, 0]]) + expected_spaces = np.array([[[1, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]]) board = TicTacToeState(text) assert np.array_equal(board.get_spaces(), expected_spaces) +def test_repr(): + text = dedent("""\ + X.. + .O. + ... + """) + state = TicTacToeState(text) + + assert repr(state) == r"TicTacToeState('X..\n.O.\n...\n')" + + +# noinspection DuplicatedCode def test_create_board_with_coordinates(): text = """\ ABC @@ -34,38 +55,37 @@ def test_create_board_with_coordinates(): 2 .O. 3 ... """ - expected_spaces = np.array([[1, 0, 0], - [0, -1, 0], - [0, 0, 0]]) + expected_spaces = np.array([[[1, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]]) board = TicTacToeState(text) assert np.array_equal(board.get_spaces(), expected_spaces) def test_display(): - board = TicTacToeState(spaces=np.array([[1, 0, 0], - [0, -1, 0], - [0, 0, 0]])) - expected_text = """\ -X.. -.O. -... -""" + expected_text = dedent("""\ + X.. + .O. + ... + """) + board = TicTacToeState(expected_text) text = board.display() assert text == expected_text def test_display_coordinates(): - board = TicTacToeState(spaces=np.array([[1, 0, 0], - [0, -1, 0], - [0, 0, 0]])) - expected_text = """\ - ABC -1 X.. -2 .O. -3 ... -""" + expected_text = dedent("""\ + ABC + 1 X.. + 2 .O. + 3 ... + """) + board = TicTacToeState(expected_text) text = board.display(show_coordinates=True) assert text == expected_text @@ -154,7 +174,7 @@ def test_make_move(): board2 = board1.make_move(move) display = board2.display() - assert expected_display == display + assert display == expected_display def test_make_move_o(): diff --git a/zero_play/connect4/game.py b/zero_play/connect4/game.py index 09592e8..9d7c613 100644 --- a/zero_play/connect4/game.py +++ b/zero_play/connect4/game.py @@ -1,3 +1,5 @@ +from copy import copy + import numpy as np from zero_play.game_state import GridGameState @@ -24,7 +26,7 @@ def get_valid_moves(self) -> np.ndarray: if self.get_winner() != self.NO_PLAYER: return np.zeros(self.board_width, dtype=bool) # Any zero value in top row is a valid move - return self.board[0] == 0 + return self.spaces[:, 0].sum(axis=0) == 0 def display(self, show_coordinates: bool = False) -> str: header = '1234567\n' if show_coordinates else '' @@ -38,17 +40,23 @@ def parse_move(self, text: str) -> int: def make_move(self, move: int) -> 'Connect4State': moving_player = self.get_active_player() - new_board: np.ndarray = self.board.copy() - available_idx, = np.where(new_board[:, move] == 0) + new_board = copy(self) + spaces = new_board.spaces + empty_spaces = spaces.sum(axis=0) == 0 + available_idx, = np.where(empty_spaces[:, move]) - new_board[available_idx[-1]][move] = moving_player - return Connect4State(spaces=new_board) + piece_type = self.piece_types.index(moving_player) + spaces[piece_type, available_idx[-1], move] = 1 + new_board.spaces = spaces + return new_board def is_win(self, player: int) -> bool: """ Has the given player collected four in a row in any direction? """ - row_count, column_count = self.board.shape + row_count, column_count = self.board_height, self.board_width win_count = 4 - player_pieces = self.board == player + spaces = self.spaces + piece_type = self.piece_types.index(player) + player_pieces = spaces[piece_type] if self.is_horizontal_win(player_pieces, win_count): return True if self.is_horizontal_win(player_pieces.transpose(), win_count): @@ -58,10 +66,10 @@ def is_win(self, player: int) -> bool: for start_column in range(column_count - win_count + 1): count1 = count2 = 0 for d in range(win_count): - if self.board[start_row + d, start_column + d] == player: + if player_pieces[start_row + d, start_column + d]: count1 += 1 - if self.board[start_row + d, - start_column + win_count - d - 1] == player: + if player_pieces[start_row + d, + start_column + win_count - d - 1]: count2 += 1 if count1 == win_count or count2 == win_count: return True diff --git a/zero_play/connect4/neural_net.py b/zero_play/connect4/neural_net.py index 9e81b9e..d536c1f 100644 --- a/zero_play/connect4/neural_net.py +++ b/zero_play/connect4/neural_net.py @@ -45,7 +45,7 @@ def __init__(self, start_state: GameState) -> None: num_channels = 64 kernel_size = [3, 3] regularizer = regularizers.l2(0.0001) - input_shape = (self.board_height, self.board_width, 1) + input_shape = (2, self.board_height, self.board_width, 1) model = Sequential( [Conv2D(num_channels, kernel_size, @@ -68,8 +68,8 @@ def analyse(self, board: GameState) -> typing.Tuple[float, np.ndarray]: if board.is_ended(): return self.analyse_end_game(board) - outputs = self.model(board.get_spaces().reshape( - (1, self.board_height, self.board_width, 1))).numpy() + outputs = self.model(board.spaces.reshape( + (1, 2, self.board_height, self.board_width, 1))).numpy() policy = outputs[0, :-1] value = outputs[0, -1] diff --git a/zero_play/game_state.py b/zero_play/game_state.py index 60fb7d7..0c09c15 100644 --- a/zero_play/game_state.py +++ b/zero_play/game_state.py @@ -1,6 +1,6 @@ +import math import typing from abc import ABC, abstractmethod -from io import StringIO import numpy as np @@ -66,13 +66,22 @@ def get_players(self) -> typing.Iterable[int]: def get_move_count(self) -> int: """ The number of moves that have already been made in the game. """ + @property @abstractmethod - def get_spaces(self) -> np.ndarray: + def spaces(self) -> np.ndarray: """ Extract the board spaces from the complete game state. Useful for teaching machine learning models. """ + @spaces.setter + @abstractmethod + def spaces(self, spaces: np.ndarray): + """ Set pieces on the board spaces. """ + + def get_spaces(self) -> np.ndarray: + return self.spaces + @abstractmethod def parse_move(self, text: str) -> int: """ Parse a human-readable description into a move index. @@ -97,8 +106,8 @@ def get_active_player(self) -> int: PLAYER_O. """ board = self.get_spaces() - x_count = (board == self.X_PLAYER).sum() - y_count = (board == self.O_PLAYER).sum() + x_count = board[0].sum() + y_count = board[1].sum() return self.X_PLAYER if x_count == y_count else self.O_PLAYER @abstractmethod @@ -132,23 +141,35 @@ def is_win(self, player: int) -> bool: # noinspection PyAbstractClass class GridGameState(GameState): + """ Game state for a simple grid with pieces on it. """ def __init__(self, board_height: int, board_width: int, text: str | None = None, lines: typing.Sequence[str] | None = None, - spaces: np.ndarray | None = None, - extra_count: int = 0): + spaces: np.ndarray | None = None): + """ Initialize a new instance. + + :param board_height: number of rows in the grid + :param board_width: number of columns in the grid + :param text: text representation of the game state, like that returned + by display() + :param lines: equivalent to text, but already split into lines + :param spaces: 3-dimensional boolean array 1 when a piece type is + in a grid space, 0 when it isn't, with shape + (piece_type_count, board_height, board_width) + """ self.board_height = board_height self.board_width = board_width - if spaces is None: - self.board = np.zeros(self.board_height*self.board_width + extra_count, - dtype=int) - else: - self.board = spaces - spaces = self.get_spaces() - if extra_count == 0: - self.board = spaces + if spaces is not None: + assert text is None + assert lines is None + self.spaces = spaces + return + type_count = len(self.piece_types) + packed = np.zeros(math.ceil(board_height*board_width*type_count/8), + dtype=np.uint8) + self.packed = packed if text: lines = text.splitlines() if lines: @@ -156,47 +177,72 @@ def __init__(self, # Trim off coordinates. lines = lines[1:] lines = [line[2:] for line in lines] - for i, line in enumerate(lines): - spaces[i] = [self.DISPLAY_CHARS.index(c) - 1 for c in line] + line_array = np.array(lines, dtype=str) + chars = line_array.view('U1').reshape(self.board_height, + self.board_width) + spaces = self.get_spaces() + for layer, display_char in enumerate(self.piece_displays): + spaces[layer] = chars == display_char + if spaces is not None: + self.spaces = spaces def __repr__(self): - board_repr = " ".join(repr(self.board).split()) - board_repr = board_repr.replace('[ ', '[') - return f'{self.__class__.__name__}(spaces={board_repr})' + board_text = self.display() + return f'{self.__class__.__name__}({board_text!r})' def __eq__(self, other): if not isinstance(other, GridGameState): return False - return np.array_equal(self.board, other.board) + return np.array_equal(self.spaces, other.spaces) + + @property + def piece_types(self): + return self.X_PLAYER, self.O_PLAYER + + @property + def piece_displays(self): + return 'XO' def get_move_count(self) -> int: - return int((self.get_spaces() != GameState.NO_PLAYER).sum()) + return self.spaces.sum() - def get_spaces(self) -> np.ndarray: - return self.board[:self.board_height*self.board_width].reshape( - self.board_height, - self.board_width) + @property + def spaces(self) -> np.ndarray: + type_count = len(self.piece_types) + trimmed_size = self.board_height * self.board_width * type_count + trimmed = np.unpackbits(self.packed)[:trimmed_size] + return trimmed.reshape(type_count, + self.board_height, + self.board_width) + + @spaces.setter + def spaces(self, spaces): + self.packed = np.packbits(spaces) def get_valid_moves(self) -> np.ndarray: spaces = self.get_spaces() - return spaces.reshape(self.board_height * - self.board_width) == GameState.NO_PLAYER + full_spaces = np.logical_or.accumulate(spaces)[-1] + empty_spaces = np.logical_not(full_spaces) + return empty_spaces.reshape(self.board_height * self.board_width) def display(self, show_coordinates: bool = False) -> str: - result = StringIO() + spaces = self.get_spaces().astype(bool) + display_grid = np.full((self.board_height, self.board_width), '.') + for level, char in enumerate(self.piece_displays): + np.copyto(display_grid, char, where=spaces[level]) + lines = np.full(self.board_height, '') if show_coordinates: - result.write(' ') - for i in range(65, 65+self.board_width): - result.write(chr(i)) - result.write('\n') - spaces = self.get_spaces() - for i in range(self.board_height): - if show_coordinates: - result.write(chr(49+i) + ' ') - for j in range(self.board_width): - result.write(self.DISPLAY_CHARS[spaces[i, j]+1]) - result.write('\n') - return result.getvalue() + lines = np.char.add(lines, + [chr(49+i) + ' ' + for i in range(self.board_height)]) + for j in range(self.board_width): + lines = np.char.add(lines, display_grid[:, j]) + text = '\n'.join(lines) + '\n' + if show_coordinates: + header = ' ' + ''.join(chr(i) + for i in range(65, 65 + self.board_width)) + text = header + '\n' + text + return text def display_move(self, move: int) -> str: row = move // self.board_width @@ -221,10 +267,11 @@ def parse_move(self, text: str) -> int: def make_move(self, move: int) -> 'GridGameState': moving_player = self.get_active_player() - new_board: np.ndarray = self.board.copy() + piece_type = self.piece_types.index(moving_player) + new_spaces = self.get_spaces() # always an unpacked copy i, j = move // self.board_width, move % self.board_width - new_board[i, j] = moving_player + new_spaces[piece_type, i, j] = 1 return self.__class__(board_height=self.board_height, board_width=self.board_width, - spaces=new_board) + spaces=new_spaces) diff --git a/zero_play/grid_display.py b/zero_play/grid_display.py index 1d77dfb..512a98d 100644 --- a/zero_play/grid_display.py +++ b/zero_play/grid_display.py @@ -1,6 +1,7 @@ import itertools import typing +import numpy as np from PySide6.QtGui import QColor, QBrush, QFont, QResizeEvent, QPixmap, Qt, QPainter, QPen from PySide6.QtWidgets import QGraphicsEllipseItem, \ QGraphicsSceneHoverEvent, QGraphicsSceneMouseEvent, QGraphicsScene @@ -155,12 +156,16 @@ def update_board(self, state: GameState): self.current_state = state self.valid_moves = self.current_state.get_valid_moves() is_ended = self.current_state.is_ended() - spaces = self.current_state.get_spaces() + state_spaces = self.current_state.spaces for i in range(self.current_state.board_height): for j in range(self.current_state.board_width): - player = spaces[i][j] + piece_types = np.nonzero(state_spaces[:, i, j])[0] + if piece_types.size: + piece_type = piece_types[0] + else: + piece_type = None piece = self.spaces[i][j] - if player == self.current_state.NO_PLAYER: + if piece_type is None: if is_ended: piece.setVisible(False) else: @@ -168,6 +173,7 @@ def update_board(self, state: GameState): piece.setBrush(self.background_colour) piece.setPen(self.background_colour) else: + player = self.current_state.piece_types[piece_type] piece.setVisible(True) piece.setBrush(self.get_player_brush(player)) piece.setPen(self.line_colour) @@ -240,9 +246,9 @@ def calculate_move(self, row, column): return move def is_piece_played(self, piece_item): - current_spaces = self.current_state.get_spaces() - hovered_player = current_spaces[piece_item.row][piece_item.column] - return hovered_player != self.start_state.NO_PLAYER + current_spaces = self.current_state.spaces + hovered_space = current_spaces[:, piece_item.row, piece_item.column] + return bool(hovered_space.sum()) def close(self): super().close() diff --git a/zero_play/mcts_player.py b/zero_play/mcts_player.py index 7812170..60caf25 100644 --- a/zero_play/mcts_player.py +++ b/zero_play/mcts_player.py @@ -239,6 +239,7 @@ def check_tasks(self, timeout, return_when): def get_best_move(self) -> int: best_children = self.current_node.find_best_children() + # noinspection PyTypeChecker self.current_node = child = np.random.choice(best_children) assert child.move is not None return child.move @@ -301,10 +302,10 @@ def create_training_data( the final value of this position for the active player. """ game_states: typing.List[typing.Tuple[GameState, np.ndarray]] = [] - self.search(self.current_node.game_state, milliseconds=1) # One extra to start. + self.search(self.current_node.game_state, iterations=1) # One extra to start. report_size = 0 - board_shape = self.current_node.game_state.get_spaces().shape - boards = np.zeros((data_size,) + board_shape, int) + board_shape = self.current_node.game_state.spaces.shape + boards = np.zeros((data_size,) + board_shape, np.uint8) move_count = self.current_node.game_state.get_valid_moves().size outputs = np.zeros((data_size, move_count + 1)) data_count = 0 @@ -347,6 +348,9 @@ def create_training_data( game_states.clear() self.reset() + # One extra to start the next game. + self.search(self.current_node.game_state, iterations=1) + class MctsPlayer(Player): """ Use Monte Carlo Tree Search to choose moves in a game. diff --git a/zero_play/othello/display.py b/zero_play/othello/display.py index c401b9d..eedabf5 100644 --- a/zero_play/othello/display.py +++ b/zero_play/othello/display.py @@ -16,10 +16,9 @@ def __init__(self, board_height: int = 8, board_width: int = 8): # noinspection DuplicatedCode def update_count_text(self): assert isinstance(self.current_state, OthelloState) - black_count = self.current_state.get_piece_count( - self.current_state.X_PLAYER) - white_count = self.current_state.get_piece_count( - self.current_state.O_PLAYER) + spaces = self.current_state.spaces + black_count = spaces[0].sum() + white_count = spaces[1].sum() self.ui.black_count.setText(f'{black_count}') self.ui.white_count.setText(f'{white_count}') diff --git a/zero_play/othello/game.py b/zero_play/othello/game.py index 165acff..cbbecd2 100644 --- a/zero_play/othello/game.py +++ b/zero_play/othello/game.py @@ -1,5 +1,5 @@ -import math import typing +from copy import copy import numpy as np @@ -12,12 +12,7 @@ class OthelloState(GridGameState): def __init__(self, text: str | None = None, board_height: int = 6, - board_width: int = 6, - spaces: np.ndarray | None = None): - if spaces is not None: - size = spaces.size - board_width = board_height = int(math.sqrt(size-1)) - assert text is None + board_width: int = 6): if text is None: lines = None next_player_line = None @@ -26,49 +21,53 @@ def __init__(self, next_player_line = lines.pop() super().__init__(board_height, board_width, - lines=lines, - extra_count=1, - spaces=spaces) - if spaces is not None: - return - spaces = self.get_spaces() + lines=lines) + spaces = self.spaces if text: assert next_player_line and next_player_line.startswith('>') - self.board[-1] = (self.X_PLAYER - if next_player_line.endswith('X') - else self.O_PLAYER) + self.active_player = (self.X_PLAYER + if next_player_line.endswith('X') + else self.O_PLAYER) else: - self.board[-1] = self.X_PLAYER + self.active_player = self.X_PLAYER for i in range(self.board_height//2-1, self.board_height//2+1): for j in range(self.board_width//2-1, self.board_width//2+1): player = self.X_PLAYER if (i+j) % 2 else self.O_PLAYER - spaces[i, j] = player + piece_type = self.piece_types.index(player) + spaces[piece_type, i, j] = 1 + self.spaces = spaces + + def __eq__(self, other): + return super().__eq__(other) and self.active_player == other.active_player def get_valid_moves(self) -> np.ndarray: spaces = self.get_spaces() moves = np.zeros(self.board_height * self.board_width + 1, bool) move_spaces = moves[:-1].reshape(self.board_width, self.board_height) player = self.get_active_player() - for i, j in self.find_moves(spaces, player): + piece_type = self.piece_types.index(player) + for i, j in self.find_moves(spaces, piece_type): move_spaces[i, j] = True if moves.sum() == 0: # No moves for this player, check opponent. - for _ in self.find_moves(spaces, -player): + for _ in self.find_moves(spaces, 1-piece_type): # Opponent has a move, pass is allowed. moves[-1] = True break return moves - def find_moves(self, spaces: np.ndarray, player: int): + def find_moves(self, spaces: np.ndarray, piece_type: int): for i in range(self.board_height): for j in range(self.board_width): - piece = spaces[i, j] - if piece == player: - yield from self.find_moves_from_space(spaces, i, j, player) + if spaces[piece_type, i, j]: + yield from self.find_moves_from_space(spaces, + i, + j, + piece_type) - def find_moves_from_space(self, spaces, start_row, start_column, player): + def find_moves_from_space(self, spaces, start_row, start_column, piece_type): for di in range(-1, 2): for dj in range(-1, 2): if not (di or dj): @@ -77,10 +76,10 @@ def find_moves_from_space(self, spaces, start_row, start_column, player): i = start_row + di j = start_column + dj while 0 <= i < self.board_height and 0 <= j < self.board_width: - piece = spaces[i, j] - if piece == player: + if spaces[piece_type, i, j]: break - if piece == self.NO_PLAYER: + if not spaces[1-piece_type, i, j]: + # empty space if has_flipped: yield i, j break @@ -91,7 +90,7 @@ def find_moves_from_space(self, spaces, start_row, start_column, player): def display(self, show_coordinates: bool = False) -> str: result = super().display(show_coordinates) - next_player = self.board[-1] + next_player = self.active_player return result + f'>{self.DISPLAY_CHARS[next_player+1]}\n' def display_move(self, move: int) -> str: @@ -106,17 +105,17 @@ def parse_move(self, text: str) -> int: return super().parse_move(trimmed) def make_move(self, move: int) -> 'OthelloState': - new_board: np.ndarray = self.board.copy() - player = new_board[-1] - new_board[-1] = -player + new_state = copy(self) + new_state.active_player = -self.active_player - new_state = OthelloState(spaces=new_board) if move == self.board_width * self.board_height: return new_state # It's a pass. - spaces = new_state.get_spaces() + spaces = new_state.spaces start_row = move // self.board_width start_column = move % self.board_width + piece_type = self.piece_types.index(self.active_player) + opponent_piece_type = 1 - piece_type for di in range(-1, 2): for dj in range(-1, 2): if not (di or dj): @@ -125,44 +124,47 @@ def make_move(self, move: int) -> 'OthelloState': i = start_row + di j = start_column + dj while 0 <= i < self.board_height and 0 <= j < self.board_width: - piece = spaces[i, j] - if piece == player: - for i, j in to_flip: - spaces[i, j] *= -1 + if spaces[piece_type, i, j]: + for i2, j2 in to_flip: + spaces[opponent_piece_type, i2, j2] = 0 + spaces[piece_type, i2, j2] = 1 break - if piece == self.NO_PLAYER: + elif not spaces[opponent_piece_type, i, j]: + # empty space break else: to_flip.append((i, j)) i += di j += dj - spaces[start_row, start_column] = player + spaces[piece_type, start_row, start_column] = 1 + new_state.spaces = spaces return new_state def get_active_player(self): - return self.board[-1] + return self.active_player def is_ended(self): - spaces = self.get_spaces() - player = self.board[-1] - for _ in self.find_moves(spaces, player): + spaces = self.spaces + piece_type = self.piece_types.index(self.active_player) + for _ in self.find_moves(spaces, piece_type): + # Player has a move, not ended. return False - for _ in self.find_moves(spaces, -player): + for _ in self.find_moves(spaces, 1-piece_type): + # Opponent has a move, not ended. return False return True def get_winner(self): if not self.is_ended(): return self.NO_PLAYER - total = self.board[:-1].sum() - if total > 0: + spaces = self.spaces + x_total = spaces[0].sum() + o_total = spaces[1].sum() + if x_total > o_total: return self.X_PLAYER - if total < 0: + if x_total < o_total: return self.O_PLAYER return self.NO_PLAYER - def get_piece_count(self, player: int): - return (self.board[:-1] == player).sum() - def is_win(self, player: int) -> bool: return self.get_winner() == player diff --git a/zero_play/process_display.py b/zero_play/process_display.py index e836784..8741579 100644 --- a/zero_play/process_display.py +++ b/zero_play/process_display.py @@ -15,6 +15,7 @@ def __init__(self) -> None: def close(self): self.stop_workers() + super().close() def stop_workers(self): if self.worker_thread is not None: diff --git a/zero_play/tictactoe/state.py b/zero_play/tictactoe/state.py index 7a4ca3b..bfeccf4 100644 --- a/zero_play/tictactoe/state.py +++ b/zero_play/tictactoe/state.py @@ -18,13 +18,14 @@ def __init__(self, def is_win(self, player: int) -> bool: """ Has the given player collected a triplet in any direction? """ + piece_type = self.piece_types.index(player) size = self.board_width - spaces = self.get_spaces() + player_pieces = self.get_spaces()[piece_type] # check horizontal lines for i in range(size): count = 0 for j in range(size): - if spaces[i, j] == player: + if player_pieces[i, j]: count += 1 if count == size: return True @@ -32,16 +33,16 @@ def is_win(self, player: int) -> bool: for j in range(size): count = 0 for i in range(size): - if spaces[i, j] == player: + if player_pieces[i, j]: count += 1 if count == size: return True # check two diagonal strips count1 = count2 = 0 for d in range(size): - if spaces[d, d] == player: + if player_pieces[d, d]: count1 += 1 - if spaces[d, size-d-1] == player: + if player_pieces[d, size - d - 1]: count2 += 1 if count1 == size or count2 == size: return True