-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplotter_callback.py
51 lines (46 loc) · 1.67 KB
/
plotter_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import keras
import numpy as np
try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
class Plotter(keras.callbacks.Callback):
    """Keras callback that plots a monitored metric against the epoch count.

    At the end of every epoch the training value of ``monitor`` and its
    validation counterpart (``"val_" + monitor``) are read from ``logs`` and
    appended to the running history, which is drawn on a Matplotlib figure
    (training curve in blue, validation curve in green).

    Args:
        monitor: Key of the metric to read from ``logs`` (for example
            ``"loss"``).
        scale: Y-axis scale, passed to ``Axes.set_yscale``.
        plot_during_train: If true, switch Matplotlib to interactive mode and
            redraw the figure after every epoch. If false, the history is
            still recorded but the figure is only drawn when training ends.
        save_to_file: Optional path. When set, the figure is written there
            with ``Figure.savefig`` when training ends.

    Raises:
        ValueError: If Matplotlib could not be imported.
    """

    def __init__(self, monitor, scale='linear', plot_during_train=True,
                 save_to_file=None):
        super().__init__()
        if plt is None:
            raise ValueError(
                "Must be able to import Matplotlib to use the Plotter.")
        self.scale = scale
        self.monitor = monitor
        self.plot_during_train = plot_during_train
        self.save_to_file = save_to_file

        # Interactive mode is only needed for live updates; leave the global
        # Matplotlib state alone when the caller asked for no live plotting.
        if self.plot_during_train:
            plt.ion()

        self.fig = plt.figure()
        self.title = "{} per Epoch".format(self.monitor)
        self.xlabel = "Epoch"
        self.ylabel = self.monitor
        self.ax = self.fig.add_subplot(
            111, title=self.title, xlabel=self.xlabel, ylabel=self.ylabel)
        self.ax.set_yscale(self.scale)

        # Running history: one entry per finished epoch.
        self.x = []
        self.y_train = []
        self.y_val = []

    def _redraw(self):
        """Clear the axes and draw the recorded history from scratch."""
        self.ax.clear()
        # Axes.clear() resets the labels and the y-scale, so they have to be
        # restored on every redraw or they disappear after the first epoch.
        self.fig.suptitle(self.title)
        self.ax.set_xlabel(self.xlabel)
        self.ax.set_ylabel(self.ylabel)
        self.ax.set_yscale(self.scale)
        self.ax.plot(self.x, self.y_train, 'b-', self.x, self.y_val, 'g-')
        self.fig.canvas.draw()
        # Let the GUI backend process the draw so an interactive window
        # actually refreshes while training keeps the main thread busy.
        self.fig.canvas.flush_events()

    def on_train_end(self, logs=None):
        """Draw the final figure if needed and save it when a path was given."""
        if not self.plot_during_train:
            # Nothing was drawn during training; render the full history now.
            self._redraw()
        if self.save_to_file:
            self.fig.savefig(self.save_to_file)

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metric values and, if enabled, refresh the plot.

        Args:
            epoch: Epoch index supplied by Keras. It is not used; the x value
                is the number of epochs this callback has recorded so far.
            logs: Metric dict for the epoch; missing keys are stored as None.
        """
        logs = logs or {}
        self.x.append(len(self.x))
        self.y_train.append(logs.get(self.monitor))
        self.y_val.append(logs.get("val_" + self.monitor))
        if self.plot_during_train:
            self._redraw()