-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_training_curve.py
68 lines (52 loc) · 1.83 KB
/
plot_training_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from tensorboard.backend.event_processing import event_accumulator
def read_data(load_dir, tag="perf/avg_reward_100"):
events = os.listdir(load_dir)
for event in events:
path = os.path.join(load_dir, event)
ea = event_accumulator.EventAccumulator(path, size_guidance={
event_accumulator.COMPRESSED_HISTOGRAMS: 0,
event_accumulator.IMAGES: 0,
event_accumulator.AUDIO: 0,
event_accumulator.SCALARS: 2500,
event_accumulator.HISTOGRAMS: 0,
})
ea.Reload()
tags = ea.Tags()
if tag not in tags["scalars"]: continue
if len(ea.Scalars(tag)) == 2500:
return np.array([s.value for s in ea.Scalars(tag)])
return None
def plot_rewards_curve(save_path,
                       load_path_lstm,
                       n_seeds=8,
                       n_workers=8,
                       ):
    """Plot the reward curve averaged over workers, with spread across seeds.

    For every seed, the series of each worker is read from the directory
    ``f"{load_path_lstm}_{seed}_{worker}"`` (seeds are numbered from 1,
    workers from 0) and the workers are averaged into one curve per seed.
    The per-seed curves are then drawn with seaborn as a mean line with a
    standard-deviation band.

    Args:
        save_path: File the figure is written to before it is shown. A falsy
            value skips saving.
        load_path_lstm: Common prefix of the per-run log directories.
        n_seeds: Number of seeds to load.
        n_workers: Number of workers per seed.
    """
    # Length of the series that read_data returns with its default settings.
    n_episodes = 2500

    # NaN rows mark seeds for which no worker produced a usable series.
    lstm_data = np.full((n_seeds, n_episodes), np.nan)
    n_missing = 0
    for seed_idx in tqdm(range(n_seeds)):
        worker_curves = []
        for worker in range(n_workers):
            curve = read_data(
                load_dir=load_path_lstm + f"_{seed_idx+1}_{worker}"
            )
            if curve is None:
                n_missing += 1
            else:
                worker_curves.append(curve)
        # Averaging an empty list would only emit a warning and yield NaN;
        # leave the pre-filled NaN row in place instead.
        if worker_curves:
            lstm_data[seed_idx] = np.mean(worker_curves, axis=0)

    if n_missing:
        print(
            f"{n_missing} of {n_seeds * n_workers} worker runs "
            "had no usable data and were skipped"
        )

    # Long-form table, one row per (seed, episode), seed-major order.
    df = pd.DataFrame({
        "Episode": np.tile(np.arange(n_episodes), n_seeds),
        "Reward": lstm_data.ravel(),
        "RNN Type": "LSTM",
    })

    # NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of
    # `errorbar="sd"`; kept as-is for compatibility with older versions.
    sns.lineplot(x="Episode", y="Reward", data=df, ci="sd")
    if save_path:
        # Save before show(): the figure may be empty once the window closes.
        plt.savefig(save_path)
    plt.show()
if __name__ == "__main__":
    # Script entry point: run the plot with the default output file and
    # log-directory prefix.
    plot_rewards_curve(
        save_path='./harlow_final_training.png',
        load_path_lstm='./logs_final/Harlow_Final_LSTM',
    )