Add PER
maliesa96 committed Nov 6, 2020
1 parent cdda5fc commit 4e708fd
Showing 6 changed files with 311 additions and 13 deletions.
16 changes: 12 additions & 4 deletions examples/torch/dqn_atari.py
@@ -23,7 +23,7 @@
from garage.envs.wrappers.stack_frames import StackFrames
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.replay_buffer import PERReplayBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.torch import set_gpu_mode
from garage.torch.algos import DQN
@@ -40,6 +40,9 @@
n_train_steps=125,
target_update_freq=2,
buffer_batch_size=32,
double_q=True,
per_beta_init=0.4,
per_alpha=0.6,
max_epsilon=1.0,
min_epsilon=0.01,
decay_ratio=0.1,
@@ -104,7 +107,7 @@ def main(env=None,


# pylint: disable=unused-argument
@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=30)
@wrap_experiment(snapshot_mode='none')
def dqn_atari(ctxt=None,
env=None,
seed=24,
@@ -150,8 +153,12 @@ def dqn_atari(ctxt=None,
steps_per_epoch = hyperparams['steps_per_epoch']
sampler_batch_size = hyperparams['sampler_batch_size']
num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
replay_buffer = PathBuffer(
capacity_in_transitions=hyperparams['buffer_size'])

replay_buffer = PERReplayBuffer(hyperparams['buffer_size'],
num_timesteps,
env.spec,
alpha=hyperparams['per_alpha'],
beta_init=hyperparams['per_beta_init'])

qf = DiscreteCNNQFunction(
env_spec=env.spec,
@@ -179,6 +186,7 @@ def dqn_atari(ctxt=None,
replay_buffer=replay_buffer,
steps_per_epoch=steps_per_epoch,
qf_lr=hyperparams['lr'],
double_q=hyperparams['double_q'],
clip_gradient=hyperparams['clip_gradient'],
discount=hyperparams['discount'],
min_buffer_size=hyperparams['min_buffer_size'],
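
Note that the launcher passes num_timesteps (n_epochs * steps_per_epoch * sampler_batch_size) as the buffer's total_timesteps, so the importance-sampling exponent beta is fully annealed to 1 by roughly the end of training. A minimal sketch of the same wiring outside this Atari launcher; the environment id and hyperparameter values below are illustrative, not taken from the commit:

# Sketch (not part of this commit): building a PERReplayBuffer in a user launcher.
from garage.envs import GymEnv
from garage.replay_buffer import PERReplayBuffer

env = GymEnv('CartPole-v1')

n_epochs, steps_per_epoch, sampler_batch_size = 100, 10, 512
num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

replay_buffer = PERReplayBuffer(capacity_in_transitions=int(1e4),
                                total_timesteps=num_timesteps,
                                env_spec=env.spec,
                                alpha=0.6,
                                beta_init=0.4)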
1 change: 0 additions & 1 deletion src/garage/envs/gym_env.py
@@ -13,7 +13,6 @@
# entry points don't close their viewer windows.
KNOWN_GYM_NOT_CLOSE_VIEWER = [
# Please keep alphabetized
'gym.envs.atari',
'gym.envs.box2d',
'gym.envs.classic_control'
]
3 changes: 2 additions & 1 deletion src/garage/replay_buffer/__init__.py
@@ -4,6 +4,7 @@
"""
from garage.replay_buffer.her_replay_buffer import HERReplayBuffer
from garage.replay_buffer.path_buffer import PathBuffer
from garage.replay_buffer.per_replay_buffer import PERReplayBuffer
from garage.replay_buffer.replay_buffer import ReplayBuffer

__all__ = ['ReplayBuffer', 'HERReplayBuffer', 'PathBuffer']
__all__ = ['PERReplayBuffer', 'ReplayBuffer', 'HERReplayBuffer', 'PathBuffer']
145 changes: 145 additions & 0 deletions src/garage/replay_buffer/per_replay_buffer.py
@@ -0,0 +1,145 @@
"""Prioritized Experience Replay."""

import numpy as np

from garage import StepType, TimeStepBatch
from garage.replay_buffer.path_buffer import PathBuffer


class PERReplayBuffer(PathBuffer):
"""Replay buffer for PER (Prioritized Experience Replay).
PER assigns priorities to transitions in the buffer. Typically
the priority of each transition is proportional to the loss
computed for it at the most recent update step. The priorities are
then used to form a probability distribution at sampling time, so
that higher-priority transitions are sampled more frequently. For
details, see https://arxiv.org/abs/1511.05952.
Args:
capacity_in_transitions (int): Total number of transitions that
can be stored in the buffer.
env_spec (EnvSpec): Environment specification.
total_timesteps (int): Total timesteps the experiment will run for.
This is used to calculate the beta parameter when sampling.
alpha (float): Hyperparameter that controls the degree of
prioritization. Typically in [0, 1], where 0 corresponds to
no prioritization (uniform sampling).
beta_init (float): Initial value of beta exponent in importance
sampling. Beta is linearly annealed from beta_init to 1
over total_timesteps.
"""

def __init__(self,
capacity_in_transitions,
total_timesteps,
env_spec,
alpha=0.6,
beta_init=0.5):
self._alpha = alpha
self._beta_init = beta_init
self._total_timesteps = total_timesteps
self._curr_timestep = 0
self._priorities = np.zeros((capacity_in_transitions, ), np.float32)
self._rng = np.random.default_rng()
super().__init__(capacity_in_transitions, env_spec)

def sample_timesteps(self, batch_size):
"""Sample a batch of timesteps from the buffer.
Args:
batch_size (int): Number of timesteps to sample.
Returns:
TimeStepBatch: The batch of timesteps.
np.ndarray: Weights of the timesteps.
np.ndarray: Indices of sampled timesteps
in the replay buffer.
"""
samples, w, idx = self.sample_transitions(batch_size)
step_types = np.array([
StepType.TERMINAL if terminal else StepType.MID
for terminal in samples['terminals'].reshape(-1)
],
dtype=StepType)
return TimeStepBatch(env_spec=self._env_spec,
observations=samples['observations'],
actions=samples['actions'],
rewards=samples['rewards'],
next_observations=samples['next_observations'],
step_types=step_types,
env_infos={},
agent_infos={}), w, idx

def sample_transitions(self, batch_size):
"""Sample a batch of transitions from the buffer.
Args:
batch_size (int): Number of transitions to sample.
Returns:
dict: A dict of arrays of shape (batch_size, flat_dim).
np.ndarray: Weights of the timesteps.
np.ndarray: Indices of sampled timesteps
in the replay buffer.
"""
priorities = self._priorities
if self._transitions_stored < self._capacity:
priorities = self._priorities[:self._transitions_stored]
probs = priorities**self._alpha
probs /= probs.sum()
idx = self._rng.choice(self._transitions_stored,
size=batch_size,
p=probs)

beta = self._beta_init + self._curr_timestep * (
1.0 - self._beta_init) / self._total_timesteps
beta = min(1.0, beta)
transitions = {
key: buf_arr[idx]
for key, buf_arr in self._buffer.items()
}

w = (self._transitions_stored * probs[idx])**(-beta)
w /= w.max()
w = np.array(w)

return transitions, w, idx

def update_priorities(self, indices, priorities):
"""Update priorities of timesteps.
Args:
indices (np.ndarray): Array of indices corresponding to the
timesteps/priorities to update.
priorities (list[float]): new priorities to set.
"""
for idx, priority in zip(indices, priorities):
self._priorities[int(idx)] = priority

def add_path(self, path):
"""Add a path to the buffer.
This differs from the underlying buffer's add_path method
in that the priorities for the new samples are set to
the maximum of all priorities in the buffer.
Args:
path (dict): A dict of arrays of shape (path_len, flat_dim).
"""
path_len = len(path['observations'])
self._curr_timestep += path_len

# find the indices where the path will be stored
first_seg, second_seg = self._next_path_segments(path_len)

# set priorities for new timesteps = max(self._priorities)
# or 1 if buffer is empty
max_priority = self._priorities.max() or 1.
self._priorities[list(first_seg)] = max_priority
if second_seg != range(0, 0):
self._priorities[list(second_seg)] = max_priority
super().add_path(path)
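
For reference, the math in sample_transitions above can be reproduced standalone: probabilities are the priorities raised to alpha and normalized, beta is linearly annealed toward 1, and the importance-sampling weights are (N * P(i))^(-beta), normalized by their maximum. A small numpy sketch with made-up buffer state:

# Standalone numpy sketch of the prioritized sampling math; values are illustrative.
import numpy as np

alpha, beta_init = 0.6, 0.5
total_timesteps, curr_timestep = 10000, 2500
stored = 8                                  # transitions currently in the buffer
priorities = np.array([1.0, 0.5, 2.0, 0.1, 1.0, 3.0, 0.2, 0.7])

probs = priorities**alpha
probs /= probs.sum()                        # sampling distribution P(i)

beta = min(1.0, beta_init + curr_timestep * (1.0 - beta_init) / total_timesteps)

rng = np.random.default_rng(0)
idx = rng.choice(stored, size=4, p=probs)   # prioritized indices

w = (stored * probs[idx])**(-beta)          # importance-sampling weights
w /= w.max()                                # normalize so the largest weight is 1
print(idx, beta, w)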
39 changes: 32 additions & 7 deletions src/garage/torch/algos/dqn.py
@@ -10,8 +10,9 @@
from garage import _Default, log_performance, make_optimizer
from garage._functions import obtain_evaluation_episodes
from garage.np.algos import RLAlgorithm
from garage.replay_buffer import PERReplayBuffer
from garage.sampler import FragmentWorker
from garage.torch import global_device, np_to_torch
from garage.torch import global_device, np_to_torch, torch_to_np


class DQN(RLAlgorithm):
@@ -122,6 +123,9 @@ def __init__(
self._qf_optimizer = make_optimizer(qf_optimizer,
module=self._qf,
lr=qf_lr)

self._prioritized_replay = isinstance(self.replay_buffer,
PERReplayBuffer)
self._eval_env = eval_env

def train(self, trainer):
@@ -192,10 +196,18 @@ def _train_once(self, itr, episodes):
for _ in range(self._n_train_steps):
if (self.replay_buffer.n_transitions_stored >=
self._min_buffer_size):
timesteps = self.replay_buffer.sample_timesteps(
self._buffer_batch_size)
qf_loss, y, q = tuple(v.cpu().numpy()
for v in self._optimize_qf(timesteps))
if self._prioritized_replay:
timesteps, weights, indices = (
self.replay_buffer.sample_timesteps(
self._buffer_batch_size))
qf_loss, y, q = tuple(v.cpu().numpy()
for v in self._optimize_qf(
timesteps, weights, indices))
else:
timesteps = self.replay_buffer.sample_timesteps(
self._buffer_batch_size)
qf_loss, y, q = tuple(
v.cpu().numpy() for v in self._optimize_qf(timesteps))

self._episode_qf_losses.append(qf_loss)
self._epoch_ys.append(y)
@@ -228,11 +240,16 @@ def _log_eval_results(self, epoch):
tabular.record('QFunction/AverageAbsY',
np.mean(np.abs(self._epoch_ys)))

def _optimize_qf(self, timesteps):
def _optimize_qf(self, timesteps, weights=None, indices=None):
"""Perform algorithm optimizing.
Args:
timesteps (TimeStepBatch): Processed batch data.
weights (np.ndarray[float]): Weights used by PER when updating
the network. Should be None if PER is not being used.
indices (list[int or float]): Indices of the sampled
timesteps in the replay buffer. Should be None
if PER is not being used.
Returns:
qval_loss: Loss of Q-value predicted by the Q-network.
@@ -274,7 +291,15 @@ def _optimize_qf(self, timesteps):
# optimize qf
qvals = self._qf(inputs)
selected_qs = torch.sum(qvals * actions, axis=1)
qval_loss = F.smooth_l1_loss(selected_qs, y_target)
qval_loss = F.smooth_l1_loss(selected_qs, y_target, reduction='none')

if self._prioritized_replay:
qval_loss *= np_to_torch(weights)
priorities = qval_loss + 1e-5 # offset to avoid 0 priorities
priorities = torch_to_np(priorities.data.cpu())
self.replay_buffer.update_priorities(indices, priorities)

qval_loss = qval_loss.mean()

self._qf_optimizer.zero_grad()
qval_loss.backward()
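
The PER branch of _optimize_qf keeps the Huber loss per-sample, scales it by the importance-sampling weights, and feeds the offset per-sample losses back to the buffer as new priorities before reducing to a scalar for the optimizer. A standalone torch sketch of that step, using random stand-ins for the selected Q-values, targets, and weights:

# Sketch (not part of this commit) of the weighted loss and priority update.
import numpy as np
import torch
import torch.nn.functional as F

batch_size = 4
selected_qs = torch.randn(batch_size, requires_grad=True)   # stand-in Q(s, a)
y_target = torch.randn(batch_size)                          # stand-in TD targets
weights = np.random.rand(batch_size).astype(np.float32)     # IS weights from the buffer

qval_loss = F.smooth_l1_loss(selected_qs, y_target, reduction='none')
qval_loss = qval_loss * torch.from_numpy(weights)

# New priorities: per-sample weighted loss plus a small offset to avoid zeros.
new_priorities = (qval_loss + 1e-5).detach().cpu().numpy()
# replay_buffer.update_priorities(indices, new_priorities) would be called here.

qval_loss = qval_loss.mean()
qval_loss.backward()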