train.py
import glob
import sys
import os
import numpy as np
from temperature_observation import TemperatureObservation
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv, RailEnvActions
from dqn.agent import Agent
from flatland.envs.agent_utils import RailAgentStatus
from temperature_observation.utils import normalize_tree_observation, normalize_temperature_observation
from temperature_observation.utils import format_action_prob
import wandb
# Flatland environment parameters; the random seed is fixed so the same map is generated
# across runs and different algorithms can be compared on equal terms
seed = 69 # nice
width = 15
height = 15
# number of trains and tree-observation parameters
num_agents = 3
tree_depth = 2
radius_observation = 10
# Weights & Biases configuration
wandb.init(project='flatlands', entity='fatlads', tags=["cdddqn_parallel", "cdddqn", "prio_exp_rpl", "temp"])
config = wandb.config
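# Optional sketch (not part of the original setup): record the environment and observation
# hyperparameters defined above in the wandb run config, so runs with different settings can
# be compared in the dashboard. wandb.config.update() is standard wandb API; the keys are
# simply the module-level constants from this file.
config.update({
    "seed": seed,
    "width": width,
    "height": height,
    "num_agents": num_agents,
    "tree_depth": tree_depth,
    "radius_observation": radius_observation
})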
random_rail_generator = complex_rail_generator(
    nr_start_goal=10,  # number of start/goal connections; the higher, the easier for the trains
    nr_extra=10,       # extra connections (useful for alternate paths); the higher, the easier
    min_dist=10,
    max_dist=99999,
    seed=seed
)
env = RailEnv(
    width=width,
    height=height,
    rail_generator=random_rail_generator,
    obs_builder_object=TemperatureObservation(tree_depth),
    number_of_agents=num_agents
)
obs, info = env.reset()
# the first agent's observation is normalized once to determine the shape of the state vector
normalized_temp = normalize_temperature_observation(obs[0][0]).flatten()
normalized_tree = normalize_tree_observation(obs[0][1], tree_depth, radius_observation)
state_shape = np.concatenate((normalized_temp, normalized_tree)).shape
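# the state fed to the network is the agent's flattened, normalized temperature map
# concatenated with its normalized tree observation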
action_shape = (5,)
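# the five discrete RailEnvActions: DO_NOTHING, MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING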
# select the algorithm to use and its hyperparameters
method = "cdddqn"
agent007 = Agent(state_shape,
                 action_shape[0],
                 (width, height),
                 gamma=0.99,
                 replace=100,
                 lr=0.001,
                 epsilon_decay=1e-3,
                 decay_type="flat",
                 initial_epsilon=1.0,
                 min_epsilon=0.01,
                 batch_size=64,
                 method=method)
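# Note: the parameter names above belong to the custom dqn.agent.Agent class; their exact
# semantics are an assumption based on common DQN conventions (gamma = discount factor,
# replace = target-network update interval, lr = learning rate, epsilon decaying from
# initial_epsilon to min_epsilon, with decay_type="flat" presumably meaning a linear decay).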
# load a previously trained model, if one exists, to initialize the weights
if glob.glob(f"{method}*"):
    agent007.load_model()
# checkpointing and logging parameters
saving_interval = 50
max_steps = env._max_episode_steps
smoothed_normalized_score = -1.0
smoothed_completion = 0.0
smoothing = 0.99
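# the smoothed metrics below are exponential moving averages:
# smoothed_x <- smoothing * smoothed_x + (1 - smoothing) * x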
action_count = [0] * action_shape[0]
# per-agent variables storing information about the current timestep
action_dict = dict()
agent_obs = [None] * num_agents
agent_prev_obs = [None] * num_agents
agent_prev_action = [2] * num_agents
update_values = [False] * num_agents
# train for 3000 episodes
for episode in range(3000):
    try:
        # initialize the episode and its utility variables
        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
        done = {i: False for i in range(0, num_agents)}
        done["__all__"] = False
        scores = 0
        step_counter = 0
        # store the normalized observation of each agent
        for agent in env.get_agent_handles():
            if obs[agent] is not None:
                norm_temp = normalize_temperature_observation(obs[agent][0]).flatten()
                norm_tree = normalize_tree_observation(obs[agent][1], tree_depth, radius_observation)
                agent_obs[agent] = np.concatenate((norm_temp, norm_tree))
                agent_prev_obs[agent] = agent_obs[agent].copy()
        for step in range(max_steps - 1):
            actions = {}
            agents_obs = {}
            # collect actions for each agent
            for agent in env.get_agent_handles():
                if info['action_required'][agent]:
                    update_values[agent] = True
                    # perform action masking, telling the agent which actions can be taken from the
                    # current tile. Even though the agent could learn the legal moves by itself, this
                    # speeds up learning, since Q-values stay low for actions that are never needed.
                    legal_moves = np.array([1 for i in range(0, 5)])
                    for action in RailEnvActions:
                        if info["status"][agent] == RailAgentStatus.ACTIVE:
                            # the last value returned by _check_action_on_agent is used here as the
                            # validity flag for that move
                            legal_moves[int(action)] = int(env._check_action_on_agent(action, env.agents[agent])[-1])
                    action = agent007.act(agent_obs[agent], legal_moves)
                    action_count[action] += 1
                else:
                    # An action is not required if the train hasn't joined the railway network yet,
                    # has already reached its target, or is currently malfunctioning.
                    # In that case just execute DO_NOTHING.
                    update_values[agent] = False
                    action = 0  # RailEnvActions.DO_NOTHING
                action_dict.update({agent: action})
            next_obs, all_rewards, done, info = env.step(action_dict)
            # update the replay buffer and train the agent
            for agent in env.get_agent_handles():
                if update_values[agent] or done['__all__']:
                    # only learn from timesteps where something happened
                    agent007.update_mem(agent_prev_obs[agent],
                                        agent_prev_action[agent],
                                        all_rewards[agent],
                                        agent_obs[agent],
                                        done[agent])
                    agent007.train()
                    agent_prev_obs[agent] = agent_obs[agent].copy()
                    agent_prev_action[agent] = action_dict[agent]
                # preprocess the new observations (use next_obs here, not the stale obs from reset)
                if next_obs[agent] is not None:
                    norm_temp = normalize_temperature_observation(next_obs[agent][0]).flatten()
                    norm_tree = normalize_tree_observation(next_obs[agent][1], tree_depth, radius_observation)
                    agent_obs[agent] = np.concatenate((norm_temp, norm_tree))
                scores += all_rewards[agent]
            if done['__all__']:
                break
        # logging information
        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
        if step_counter < max_steps - 1:
            completion = tasks_finished / max(1, env.get_num_agents())
            normalized_score = scores / (max_steps * env.get_num_agents())
            smoothed_normalized_score = smoothed_normalized_score * smoothing + normalized_score * (1.0 - smoothing)
            smoothed_completion = smoothed_completion * smoothing + completion * (1.0 - smoothing)
            action_probs = action_count / np.sum(action_count)
            # reset the counts to 1 so the next normalization never divides by zero
            action_count = [1] * action_shape[0]
            step_counter += 1
            wandb.log({
                "normalized_score": normalized_score,
                "smoothed_normalized_score": smoothed_normalized_score,
                "completion": 100 * completion,
                "smoothed_completion": 100 * smoothed_completion
            })
            print(
                '\r🚂 Episode {}'
                '\t 🏆 Score: {:.3f}'
                ' Avg: {:.3f}'
                '\t 💯 Done: {:.2f}%'
                ' Avg: {:.2f}%'
                '\t 🔀 Action Probs: {}'
                '\n'.format(
                    episode,
                    normalized_score,
                    smoothed_normalized_score,
                    100 * completion,
                    100 * smoothed_completion,
                    format_action_prob(action_probs)
                ), end=" ")
        if episode % saving_interval == 0:
            agent007.save_model()
    except KeyboardInterrupt:
        print('Interrupted')
        agent007.save_model()
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)