-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathself_play.py
37 lines (26 loc) · 942 Bytes
/
self_play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from tqdm import tqdm
from game import Game
from mcts import mcts
def self_play(net, num_games, num_simulations):
    """Generate training examples by playing games of MCTS-guided self-play.

    Args:
        net: Network used by ``mcts`` to evaluate positions.
        num_games: Number of complete games to play.
        num_simulations: MCTS simulations to run per move.

    Returns:
        A list of ``[state, pi, value]`` examples, where ``value`` is the
        final game outcome signed from the perspective of the player to
        move in that state.
    """
    # Tally of outcomes keyed by game.outcome(): +1 / -1 / 0 (draw).
    results = {+1: 0, -1: 0, 0: 0}
    self_play_data = []
    with tqdm(total=num_games, desc="Self play", unit="game") as prog_bar:
        for _ in range(num_games):
            game = Game()
            game_data = []
            root = None
            while not game.is_over():
                # Pass the previous search tree back in so statistics for the
                # chosen subtree are reused across moves.
                pi, action, root = mcts(net, game, num_simulations, root=root)
                root = root.children[action]
                # NOTE(review): the same pi is stored for every symmetry of
                # the state; if get_state_symmetries() rotates/reflects the
                # board, pi should be transformed accordingly — TODO confirm.
                for state in game.get_state_symmetries():
                    game_data.append([state, pi, game.to_play()])
                game.apply(action)
            z = game.outcome()
            results[z] += 1
            # TODO: Remove duplicates (average their policies and values)
            # Convert each stored player-to-move marker into the outcome from
            # that player's perspective (to_play * z).
            for example in game_data:
                example[2] *= z
            self_play_data.extend(game_data)
            prog_bar.set_postfix_str(f"Yellow={results[+1]} | Red={results[-1]} | Draw={results[0]}")
            prog_bar.update(1)
    return self_play_data