forked from kosoraYintai/PARL-Sample
TrainTestMaze.py
#coding:UTF-8
import numpy as np
from maze_unionFind.MazeAgent import MazeAgent
from maze_unionFind.MazeModel import MazeModel
from rpm.FcPolicyReplayMemory import Experience,FcPolicyReplayMemory
from parl.algorithms.dqn import DQN
from tqdm import tqdm
from maze_unionFind.MazeEnv import MazeEnv,meetWall
import time
import matplotlib.pyplot as plt

#Display the best path found through the maze
def seeBestPath(env,path):
    maze=np.zeros((env.row,env.col))
    for i in range(0,env.row):
        for j in range(0,env.col):
            pos=np.array([i,j])
            if meetWall(env.wallList, pos):
                maze[i][j]=1
    for pos in path:
        maze[pos[0]][pos[1]]=2
    for i in range(0,env.row):
        for j in range(0,env.col):
            if maze[i][j]==0:
                print('O',end=' ')
            elif maze[i][j]==2:
                print('●',end=' ')
            else:
                print('X',end=' ')
        print()

#Dimension of the state
StateShape=(2,)

#------hyper parameters start
#All of the hyper-parameters below can be fine-tuned
#Number of rows and columns of the maze to train on
mazeRow=5
mazeCol=6
#Replay memory size
MEMORY_SIZE = int(1e4)
#warm-up
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 4
#How often the network learns (every UPDATE_FREQ steps)
UPDATE_FREQ = 2
#Discount factor
GAMMA = 0.99
#Learning rate
LEARNING_RATE = 1e-3
#Total number of training steps
TOTAL=1e5
#batch-size
batchSize=64
#------hyper parameters end

#Running mean of the episode reward
meanReward=0
#Number of training episodes completed
trainEp=0
#Logging frequency (in episodes)
logFreq=10
#Learning curve: records the mean reward
learning_curve=[]
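
# Run one training episode: interact with the maze, store each transition in
# the replay memory, and (after warm-up) call agent.learn every UPDATE_FREQ
# steps; meanReward keeps a running average of the episode rewards.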
def run_train_episode(env, agent, rpm):
    global trainEp
    global meanReward
    global learning_curve
    total_reward = 0
    all_cost = []
    state= env.reset()
    step = 0
    trainFlag=False
    while True:
        step += 1
        action = agent.sample(state)
        next_state, reward, isOver,_ = env.step(action)
        rpm.append(Experience(state, action, reward, isOver,next_state))
        if rpm.size() > MEMORY_WARMUP_SIZE:
            trainFlag=True
            if step % UPDATE_FREQ == 0:
                batch_state, batch_action, batch_reward, batch_isOver,batch_next_state = rpm.sample_batch(
                    batchSize)
                cost = agent.learn(batch_state, batch_action, batch_reward,
                                   batch_next_state, batch_isOver)
                all_cost.append(float(cost))
        total_reward += reward
        state = next_state
        if isOver:
            break
    if trainFlag:
        trainEp+=1
        meanReward=meanReward+(total_reward-meanReward)/trainEp
        if trainEp%logFreq==0:
            learning_curve.append(meanReward)
            print('\n trainEpisode:{} total_reward: {:.3f}, meanReward:{:.3f} mean_cost: {:.3f}'.format(trainEp,total_reward, meanReward,np.mean(all_cost)))
    return total_reward, step
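
# Run one test episode: let the agent act in the maze, record the visited
# states, and print the resulting path with seeBestPath.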
def run_test_episode(env, agent, rpm):
    total_reward = 0
    state= env.reset()
    path=[]
    path.append(state.copy())
    while True:
        action = agent.sample(state)
        next_state, reward, isOver,_ = env.step(action)
        total_reward += reward
        state = next_state
        path.append(state.copy())
        if isOver:
            break
    seeBestPath(env, path)
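
# Full pipeline: build the maze environment, replay memory, DQN algorithm and
# agent; warm up the replay memory, train for TOTAL environment steps, then
# render the maze, show the learned path, and plot the learning curve.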
def trainTest():
    env = MazeEnv(m=mazeRow,n=mazeCol)
    time.sleep(5)
    rpm = FcPolicyReplayMemory(max_size=MEMORY_SIZE, state_shape=StateShape)
    action_dim = 4
    hyperparas = {
        'action_dim': action_dim,
        'lr': LEARNING_RATE,
        'gamma': GAMMA
    }
    model = MazeModel(act_dim=action_dim)
    algorithm = DQN(model, hyperparas)
    agent = MazeAgent(algorithm, action_dim)
    with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
        while rpm.size() < MEMORY_WARMUP_SIZE:
            __, step = run_train_episode(env, agent, rpm)
            pbar.update(step)
    print()
    print('Start training!')
    total_step = 0
    with tqdm(total=TOTAL) as pbar:
        while total_step <= TOTAL:
            __, step = run_train_episode(env, agent, rpm)
            total_step += step
            if trainEp%logFreq==0:
                pbar.set_description('totalStep:{},exploration:{:.3f}'.format(total_step,agent.exploration))
            pbar.update(step)
    print()
    print('Training finished\n')
    time.sleep(5)
    print("Randomly generated maze:")
    print()
    env.render()
    print()
    print("Show the best path:")
    print()
    run_test_episode(env, agent,rpm)
    #Learning curve of the mean reward
    X=np.arange(0,len(learning_curve))
    X*=logFreq
    plt.title('LearningCurve')
    plt.xlabel('TrainEpisode')
    plt.ylabel('AvgReward')
    plt.plot(X,learning_curve)
    plt.show()

if __name__ == '__main__':
    trainTest()