Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
robert-lieck committed Jan 20, 2025
1 parent f434410 commit 7048dda
Showing 1 changed file with 62 additions and 56 deletions.
118 changes: 62 additions & 56 deletions rldurham/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ class Recorder(gym.Wrapper, gym.utils.RecordConstructorArgs):
# see RecordEpisodeStatistics for inspiration

def __init__(self, env, info=True, video=False, logs=False, key="recorder",
video_folder="videos", name_prefix="xxxx00-agent-video",
video_folder="videos", video_prefix="xxxx00-agent-video",
full_stats=False, smoothing=None):
gym.utils.RecordConstructorArgs.__init__(self)
if video:
env = VideoRecorder(env,
video_folder=video_folder,
name_prefix=name_prefix,
name_prefix=video_prefix,
episode_trigger=self._video_episode_trigger,
name_func=self._video_name_func)
gym.Wrapper.__init__(self, env)
Expand Down Expand Up @@ -223,10 +223,10 @@ def write_log(self, folder="logs", file="xxxx00-agent-log.txt"):
os.makedirs(folder)
path = os.path.join(folder, file)
df = pd.DataFrame({
"episode_count": self._episode_count_log,
"episode_reward_sum": self._episode_reward_sum_log,
"episode_squared_reward_sum": self._episode_squared_reward_sum_log,
"episode_length": self._episode_length_log,
"count": self._episode_count_log,
"reward_sum": self._episode_reward_sum_log,
"squared_reward_sum": self._episode_squared_reward_sum_log,
"length": self._episode_length_log,
})
df.to_csv(path, index=False, sep='\t')

Expand Down Expand Up @@ -268,54 +268,60 @@ def _update(self, new_info, tracked_info):
else:
tracked_info[k].append(new_info[k])

def plot(
self, show=True, ax=None, key='recorder', ignore_empty=True,
length=False, r_sum=False, r_mean=False, r_std=False,
length_=False, r_sum_=False, r_mean_=False, r_std_=False
):
if not self.info and ignore_empty:
return
fig = None
if ax is None:
fig, ax = plt.subplots(1, 1)

def plot_tracked_info(tracked_info, show=True, ax=None, key='recorder', ignore_empty=True,
length=False, r_sum=False, r_mean=False, r_std=False,
length_=False, r_sum_=False, r_mean_=False, r_std_=False,
):
if not tracked_info and ignore_empty:
return
fig = None
if ax is None:
fig, ax = plt.subplots(1, 1)

def get_kwargs(flag, **kwargs):
if isinstance(flag, dict):
return {**kwargs, **flag}
else:
return kwargs
idx = tracked_info[key]['idx']

_length = tracked_info[key]['length']
_r_sum = np.array(tracked_info[key]['r_sum'])
_r_mean = np.array(tracked_info[key]['r_mean'])
_r_std = np.array(tracked_info[key]['r_std'])
if length:
ax.plot(idx, _length, **get_kwargs(length, label='length'))
if r_sum:
ax.plot(idx, _r_sum, **get_kwargs(r_sum, label='r_sum'))
if r_mean:
ax.plot(idx, _r_mean, **get_kwargs(r_mean, label='r_mean'))
if r_std:
ax.fill_between(idx, _r_mean - _r_std, _r_mean + _r_std, **get_kwargs(r_std, label='r_std', alpha=0.2, color='tab:grey'))

_length_ = tracked_info[key]['length_']
_r_sum_ = np.array(tracked_info[key]['r_sum_'])
_r_mean_ = np.array(tracked_info[key]['r_mean_'])
_r_std_ = np.array(tracked_info[key]['r_std_'])
if length_:
ax.plot(idx, _length_, **get_kwargs(length_, label='length_'))
if r_sum_:
ax.plot(idx, _r_sum_, **get_kwargs(r_sum_, label='r_sum_'))
if r_mean_:
ax.plot(idx, _r_mean_, **get_kwargs(r_mean_, label='r_mean_'))
if r_std_:
ax.fill_between(idx, _r_mean_ - _r_std_, _r_mean_ + _r_std_, **get_kwargs(r_std_, label='r_std_', alpha=0.2, color='tab:grey'))

ax.set_xlabel('episode index')
ax.legend()
if show:
plt.show()
disp.clear_output(wait=True)
if fig is not None:
return fig, ax
def get_kwargs(flag, **kwargs):
if isinstance(flag, dict):
return {**kwargs, **flag}
else:
return kwargs
idx = self.info[key]['idx']

if length:
_length = self.info[key]['length']
ax.plot(idx, _length, **get_kwargs(length, label='length'))
if r_sum:
_r_sum = np.array(self.info[key]['r_sum'])
ax.plot(idx, _r_sum, **get_kwargs(r_sum, label='r_sum'))
_r_mean = None
if r_mean:
_r_mean = np.array(self.info[key]['r_mean'])
ax.plot(idx, _r_mean, **get_kwargs(r_mean, label='r_mean'))
if r_std:
if _r_mean is None:
_r_mean = np.array(self.info[key]['r_mean'])
_r_std = np.array(self.info[key]['r_std'])
ax.fill_between(idx, _r_mean - _r_std, _r_mean + _r_std, **get_kwargs(r_std, label='r_std', alpha=0.2, color='tab:grey'))

if length_:
_length_ = self.info[key]['length_']
ax.plot(idx, _length_, **get_kwargs(length_, label='length_'))
if r_sum_:
_r_sum_ = np.array(self.info[key]['r_sum_'])
ax.plot(idx, _r_sum_, **get_kwargs(r_sum_, label='r_sum_'))
_r_mean_ = None
if r_mean_:
_r_mean_ = np.array(self.info[key]['r_mean_'])
ax.plot(idx, _r_mean_, **get_kwargs(r_mean_, label='r_mean_'))
if r_std_:
if _r_mean_ is None:
_r_mean_ = np.array(self.info[key]['r_mean_'])
_r_std_ = np.array(self.info[key]['r_std_'])
ax.fill_between(idx, _r_mean_ - _r_std_, _r_mean_ + _r_std_, **get_kwargs(r_std_, label='r_std_', alpha=0.2, color='tab:grey'))

ax.set_xlabel('episode index')
ax.legend()
if show:
plt.show()
disp.clear_output(wait=True)
if fig is not None:
return fig, ax

0 comments on commit 7048dda

Please sign in to comment.