# swarm_wave.py
import time
import copy
import numpy as np
from typing import Callable
from fractalai.swarm import Swarm, DynamicTree


class SwarmWave(Swarm):

    tree = DynamicTree()

    def __init__(self, env, model, n_walkers: int=100, balance: float=1.,
                 reward_limit: float=None, samples_limit: int=None, render_every: int=1e10,
                 accumulate_rewards: bool=True, dt_mean: float=None, dt_std: float=None,
                 min_dt: int=1, custom_reward: Callable=None, custom_end: Callable=None,
                 process_obs: Callable=None, custom_skipframe: Callable=None,
                 keep_best: bool=False, can_win: bool=False,
                 save_data: bool=True, prune_tree: bool=True):
"""
:param env: Environment that will be sampled.
:param model: Model used for sampling actions from observations.
:param n_walkers: Number of walkers that the swarm will use
:param balance: Balance coefficient for the virtual reward formula.
:param reward_limit: Maximum reward that can be reached before stopping the swarm.
:param samples_limit: Maximum number of time the Swarm can sample the environment
befors stopping.
:param render_every: Number of iterations that will be performed before printing the Swarm
status.
:param accumulate_rewards: Use the accumulated reward when scoring the walkers.
False to use instantaneous reward.
:param dt_mean: Mean skipframe used for exploring.
:param dt_std: Standard deviation for the skipframe. Sampled from a normal distribution.
:param min_dt: Minimum skipframe to be used by the swarm.
:param custom_reward: Callable for calculating a custom reward function.
:param custom_end: Callable for calculating custom boundary conditions.
:param process_obs: Callable for doing custom observation processing.
:param custom_skipframe: Callable for sampling the skipframe values of the walkers.
:param keep_best: Keep track of the best accumulated reward found so far.
:param can_win: If the game can be won when a given score is achieved, set to True. Meant
to be used with Atari games like Boxing, Pong, IceHockey, etc.
:param save_data: Store data to construct a tree of paths.
:param prune_tree: Delete a path if no walker is expanding it.
"""
        super(SwarmWave, self).__init__(env=env, model=model, n_walkers=n_walkers,
                                        balance=balance, reward_limit=reward_limit,
                                        samples_limit=samples_limit, render_every=render_every,
                                        accumulate_rewards=accumulate_rewards, dt_mean=dt_mean,
                                        dt_std=dt_std, custom_end=custom_end,
                                        custom_reward=custom_reward, keep_best=keep_best,
                                        min_dt=min_dt, process_obs=process_obs, can_win=can_win,
                                        custom_skipframe=custom_skipframe)
        self.save_data = save_data
        self.prune_tree = prune_tree
        self.old_ids = np.zeros(self.n_walkers)
        self._current_index = None
        self._curr_states = []
        self._curr_actions = []
        self._curr_dts = []
        self._current_ix = -1

    def __str__(self):
        text = super(SwarmWave, self).__str__()
        # Guard against division by zero when printing before any sampling has happened.
        if self.save_data and self._n_samples_done > 0:
            efi = (len(self.tree.data.nodes) / self._n_samples_done) * 100
            sam_step = self._n_samples_done / len(self.tree.data.nodes)
            samples = len(self.tree.data.nodes)
        else:
            efi, samples, sam_step = 0, 0, 0
        new_text = "{}\n" \
                   "Efficiency {:.2f}%\n" \
                   "Generated {} Examples |" \
                   " {:.2f} samples per example.\n".format(text, efi, samples, sam_step)
        return new_text

    def init_swarm(self, state: np.ndarray=None, obs: np.ndarray=None):
        """Initialize the swarm and the root node of the tree."""
        super(SwarmWave, self).init_swarm(state=state, obs=obs)
        self.tree.data.nodes[0]["obs"] = obs if obs is not None else self.env.reset()[1]
        self.tree.data.nodes[0]["terminal"] = False

    def step_walkers(self):
        """Step the walkers and record each new transition as a leaf of the tree."""
        old_ids = self.walkers_id.copy()
        super(SwarmWave, self).step_walkers()
        if self.save_data:
            for i, idx in enumerate(self.walkers_id):
                self.tree.append_leaf(int(idx), parent_id=int(old_ids[i]),
                                      state=self.data.get_states([idx]).copy()[0],
                                      action=self.data.get_actions([idx]).copy()[0],
                                      dt=copy.deepcopy(self.dt[i]))

    def clone(self):
        """Clone the walkers and drop branches that no walker is expanding."""
        super(SwarmWave, self).clone()
        # Prune the tree to save memory
        if self.save_data and self.prune_tree:
            dead_leafs = list(set(self._pre_clone_ids) - set(self._post_clone_ids))
            self.tree.prune_tree(dead_leafs, self._post_clone_ids)

    def recover_game(self, index=None) -> tuple:
        """
        By default, returns the game sampled with the highest score.
        :param index: Id of the leaf where the returned game will finish.
        :return: A tuple (states, actions, dts) describing the target sampled game.
        """
        if index is None:
            index = self.walkers_id[self.rewards.argmax()]
        return self.tree.get_branch(index)

    def render_game(self, index=None, sleep: float=0.02):
        """Renders the game stored in the tree that ends in the node labeled as index."""
        states, actions, dts = self.recover_game(index)
        for state, action, dt in zip(states, actions, dts):
            _, _, _, end, _ = self._env.step(action, state=state, n_repeat_action=1)
            self._env.render()
            time.sleep(sleep)
            for _ in range(max(0, dt - 1)):
                self._env.step(action, n_repeat_action=1)
                self._env.render()
                time.sleep(sleep)

    def run_swarm(self, state: np.ndarray=None, obs: np.ndarray=None, print_swarm: bool=False):
        """Reset the tree and run the swarm until a stop condition is met."""
        self.tree.reset()
        super(SwarmWave, self).run_swarm(state=state, obs=obs, print_swarm=print_swarm)
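

# A minimal usage sketch, not part of the original file. It assumes the
# `AtariEnvironment` and `RandomDiscreteModel` helpers that ship with this
# repository; the constructor parameters shown here are illustrative and may
# need adjusting to the real signatures.
if __name__ == "__main__":
    from fractalai.environment import AtariEnvironment  # assumed import path
    from fractalai.model import RandomDiscreteModel  # assumed import path

    env = AtariEnvironment(name="MsPacman-v0", clone_seeds=True)  # illustrative parameters
    model = RandomDiscreteModel(env.env.action_space.n)  # illustrative constructor
    swarm = SwarmWave(env=env, model=model, n_walkers=100, balance=1.,
                      samples_limit=5000, save_data=True, prune_tree=True)
    swarm.run_swarm(print_swarm=True)
    swarm.render_game()  # replay the highest-scoring branch stored in the tree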