-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplay.py
110 lines (83 loc) · 2.32 KB
/
play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from gui import GUI
from game import Game
from mcts import mcts
from network import AlphaZeroNet
import time
import torch
import random
import time
import sys
from tqdm import tqdm
def play_move_player(game: Game, gui: GUI):
    """Wait for the human to click a legal action in the GUI, then apply it.

    The GUI is polled in a loop. Whenever ``gui.clicked`` holds a value, its
    first component is treated as the action; if ``game.get_actions()``
    contains that action it is passed to ``game.apply`` and the function
    returns. Otherwise the GUI is redrawn and its events are processed again.

    Args:
        game: Game state that receives the chosen action.
        gui: GUI object providing ``clicked``, ``draw()`` and
            ``handle_events()``.
    """
    # NOTE(review): this loop never looks at gui.running. If handle_events()
    # reports a closed window through that flag, this function would spin
    # forever -- confirm against the GUI class.
    gui.handle_events()
    while True:
        click = gui.clicked
        if click is not None:
            # Only the first component of the click selects the action.
            click_x, _click_y = click
            if click_x in game.get_actions():
                game.apply(click_x)
                return
        # No usable click yet: keep the window refreshed and poll again.
        gui.draw()
        gui.handle_events()
def play_move_ai_mcts(game: Game, net: AlphaZeroNet):
    """Pick a move for the side to play with ``mcts`` and apply it to ``game``.

    ``mcts`` is invoked as ``mcts(net, game, 100, eval=True)`` and must return
    three values; only the second one (the action) is used here.

    Args:
        game: Game state that receives the selected action.
        net: Network handed to ``mcts``.
    """
    # NOTE(review): 100 is presumably the search budget (number of
    # simulations) -- confirm against the mcts signature.
    search_result = mcts(net, game, 100, eval=True)
    _pi, action, _ = search_result
    game.apply(action)
def play_move_random(game: Game):
    """Apply one action chosen at random from ``game.get_actions()``.

    Args:
        game: Game state that receives the randomly selected action.
    """
    candidate_actions = game.get_actions()
    game.apply(random.choice(candidate_actions))
def play(net: AlphaZeroNet):
    """Run one game in the GUI: MCTS with ``net`` against a human player.

    The side for which ``game.to_play()`` returns 1 (Yellow) is moved by
    ``play_move_ai_mcts``; the other side (Red) is moved by the human via
    ``play_move_player``. To change who controls a side, call a different
    ``play_move_*`` function (player, MCTS or random) in the matching branch.

    Once the game is over its outcome is printed, and the window keeps being
    redrawn for as long as ``gui.running`` is true.

    Args:
        net: Network used by the MCTS-controlled side.
    """
    game = Game()
    gui = GUI(game.board)
    gui.draw()

    while not game.is_over():
        if game.to_play() != 1:
            # Red: the human chooses through the GUI.
            play_move_player(game, gui)
        else:
            # Yellow: the network chooses through MCTS.
            play_move_ai_mcts(game, net)

        # Show the new position and process pending GUI events after each move.
        gui.draw()
        gui.handle_events()

    print(f"Outcome: {game.outcome()}")

    # Keep the final position on screen until the GUI stops running.
    while gui.running:
        gui.draw()
        gui.handle_events()
def test(net: AlphaZeroNet, net2: AlphaZeroNet, num_games):
    """Play ``num_games`` MCTS-vs-MCTS games and print the result shares.

    In every game ``net`` moves whenever ``game.to_play()`` returns 1 (Yellow)
    and ``net2`` moves otherwise (Red). Outcomes are tallied under the keys
    ``1``, ``-1`` and ``0`` as returned by ``game.outcome()``, shown live in a
    tqdm progress bar, and finally printed as percentages of ``num_games``.

    Args:
        net: Network playing Yellow.
        net2: Network playing Red.
        num_games: Number of games to play. For a non-positive value nothing
            is played and a short notice is printed instead of the summary,
            which would otherwise divide by zero.
    """
    if num_games <= 0:
        # Without this guard the percentage summary raises ZeroDivisionError.
        print("No games played.")
        return

    results = {+1: 0, -1: 0, 0: 0}
    with tqdm(total=num_games, desc="Playing games", unit="game") as prog_bar:
        for _ in range(num_games):
            game = Game()
            while not game.is_over():
                # Yellow (to_play() == 1) is driven by net, Red by net2.
                mover = net if game.to_play() == 1 else net2
                play_move_ai_mcts(game, mover)
            results[game.outcome()] += 1
            prog_bar.set_postfix_str(f"Yellow = {results[1]} | Red = {results[-1]} | Draw = {results[0]}")
            prog_bar.update(1)

    print()
    for label, outcome_key in (("Yellow", 1), ("Red", -1), ("Draw", 0)):
        print(f"{label}:\t{100 * results[outcome_key] / num_games}%")
    print()
if __name__ == "__main__":
    # Usage:
    #   python play.py               -> interactive game against the network
    #   python play.py test [N]      -> N network-vs-network games (default 100)

    # One checkpoint path per network; both currently load the same file.
    loaded_nets = []
    for checkpoint_path in ("data/model.pt", "data/model.pt"):
        model = AlphaZeroNet()
        model.cuda()
        # Alternative to loading a checkpoint: model.initialize_parameters()
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["state_dict"])
        model.eval()
        loaded_nets.append(model)
    net, net2 = loaded_nets

    cli_args = sys.argv[1:]
    # Both modes run with gradient tracking disabled.
    with torch.no_grad():
        if cli_args and cli_args[0] == "test":
            num_games = int(cli_args[1]) if len(cli_args) > 1 else 100
            test(net, net2, num_games)
        else:
            play(net)