Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PER #2159

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Add PER #2159

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions examples/torch/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from garage.envs.wrappers.stack_frames import StackFrames
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.replay_buffer import PERReplayBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.torch import set_gpu_mode
from garage.torch.algos import DQN
Expand All @@ -40,6 +40,9 @@
n_train_steps=125,
target_update_freq=2,
buffer_batch_size=32,
double_q=True,
per_beta_init=0.4,
per_alpha=0.6,
max_epsilon=1.0,
min_epsilon=0.01,
decay_ratio=0.1,
Expand Down Expand Up @@ -104,7 +107,7 @@ def main(env=None,


# pylint: disable=unused-argument
@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=30)
@wrap_experiment(snapshot_mode='none')
def dqn_atari(ctxt=None,
env=None,
seed=24,
Expand Down Expand Up @@ -150,8 +153,12 @@ def dqn_atari(ctxt=None,
steps_per_epoch = hyperparams['steps_per_epoch']
sampler_batch_size = hyperparams['sampler_batch_size']
num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
replay_buffer = PathBuffer(
capacity_in_transitions=hyperparams['buffer_size'])

replay_buffer = PERReplayBuffer(hyperparams['buffer_size'],
num_timesteps,
env.spec,
alpha=hyperparams['per_alpha'],
beta_init=hyperparams['per_beta_init'])

qf = DiscreteCNNQFunction(
env_spec=env.spec,
Expand Down Expand Up @@ -179,6 +186,7 @@ def dqn_atari(ctxt=None,
replay_buffer=replay_buffer,
steps_per_epoch=steps_per_epoch,
qf_lr=hyperparams['lr'],
double_q=hyperparams['double_q'],
clip_gradient=hyperparams['clip_gradient'],
discount=hyperparams['discount'],
min_buffer_size=hyperparams['min_buffer_size'],
Expand Down
1 change: 0 additions & 1 deletion src/garage/envs/gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# entry points don't close their viewer windows.
KNOWN_GYM_NOT_CLOSE_VIEWER = [
# Please keep alphabetized
'gym.envs.atari',
'gym.envs.box2d',
'gym.envs.classic_control'
]
Expand Down
3 changes: 2 additions & 1 deletion src/garage/replay_buffer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
from garage.replay_buffer.her_replay_buffer import HERReplayBuffer
from garage.replay_buffer.path_buffer import PathBuffer
from garage.replay_buffer.per_replay_buffer import PERReplayBuffer
from garage.replay_buffer.replay_buffer import ReplayBuffer

__all__ = ['ReplayBuffer', 'HERReplayBuffer', 'PathBuffer']
__all__ = ['PERReplayBuffer', 'ReplayBuffer', 'HERReplayBuffer', 'PathBuffer']
14 changes: 11 additions & 3 deletions src/garage/replay_buffer/path_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,15 @@ def sample_transitions(self, batch_size):

Returns:
dict: A dict of arrays of shape (batch_size, flat_dim).
np.ndarray: Weights of the timesteps.
np.ndarray: Indices of sampled timesteps
in the replay buffer.

"""
idx = np.random.randint(self._transitions_stored, size=batch_size)
return {key: buf_arr[idx] for key, buf_arr in self._buffer.items()}
w = np.ones(batch_size)
data = {key: buf_arr[idx] for key, buf_arr in self._buffer.items()}
return data, w, idx

def sample_timesteps(self, batch_size):
"""Sample a batch of timesteps from the buffer.
Expand All @@ -132,9 +137,12 @@ def sample_timesteps(self, batch_size):

Returns:
TimeStepBatch: The batch of timesteps.
np.ndarray: Weights of the timesteps.
np.ndarray: Indices of sampled timesteps
in the replay buffer.

"""
samples = self.sample_transitions(batch_size)
samples, w, idx = self.sample_transitions(batch_size)
step_types = np.array([
StepType.TERMINAL if terminal else StepType.MID
for terminal in samples['terminals'].reshape(-1)
Expand All @@ -147,7 +155,7 @@ def sample_timesteps(self, batch_size):
next_observations=samples['next_observations'],
step_types=step_types,
env_infos={},
agent_infos={})
agent_infos={}), w, idx

def _next_path_segments(self, n_indices):
"""Compute where the next path should be stored.
Expand Down
145 changes: 145 additions & 0 deletions src/garage/replay_buffer/per_replay_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Prioritized Experience Replay."""

import numpy as np

from garage import StepType, TimeStepBatch
from garage.replay_buffer.path_buffer import PathBuffer


class PERReplayBuffer(PathBuffer):
    """Replay buffer for PER (Prioritized Experience Replay).

    PER assigns a priority to every transition in the buffer. Typically
    the priority of a transition is proportional to the corresponding
    loss computed at each update step. The priorities are then used to
    create a probability distribution when sampling such that higher
    priority transitions are sampled more frequently. For more see
    https://arxiv.org/abs/1511.05952.

    Args:
        capacity_in_transitions (int): total size of transitions in the
            buffer.
        total_timesteps (int): Total timesteps the experiment will run for.
            This is used to calculate the beta parameter when sampling.
        env_spec (EnvSpec): Environment specification.
        alpha (float): hyperparameter that controls the degree of
            prioritization. Typically between [0, 1], where 0 corresponds to
            no prioritization (uniform sampling).
        beta_init (float): Initial value of beta exponent in importance
            sampling. Beta is linearly annealed from beta_init to 1
            over total_timesteps.

    """

    def __init__(self,
                 capacity_in_transitions,
                 total_timesteps,
                 env_spec,
                 alpha=0.6,
                 beta_init=0.5):
        self._alpha = alpha
        self._beta_init = beta_init
        self._total_timesteps = total_timesteps
        # Number of timesteps added so far; drives the annealing of beta.
        self._curr_timestep = 0
        # One priority per buffer slot. Zero marks a never-written slot.
        self._priorities = np.zeros((capacity_in_transitions, ), np.float32)
        self._rng = np.random.default_rng()
        super().__init__(capacity_in_transitions, env_spec)

    def _current_beta(self):
        """Compute the current importance-sampling exponent.

        Beta grows linearly from beta_init with the number of timesteps
        added to the buffer and is capped at 1.

        Returns:
            float: The current value of beta.

        """
        if self._total_timesteps <= 0:
            # No annealing horizon was given: use full correction instead
            # of dividing by zero.
            return 1.0
        beta = self._beta_init + self._curr_timestep * (
            1.0 - self._beta_init) / self._total_timesteps
        return min(1.0, beta)

    def sample_timesteps(self, batch_size):
        """Sample a batch of timesteps from the buffer.

        Args:
            batch_size (int): Number of timesteps to sample.

        Returns:
            TimeStepBatch: The batch of timesteps.
            np.ndarray: Weights of the timesteps.
            np.ndarray: Indices of sampled timesteps
                in the replay buffer.

        """
        samples, w, idx = self.sample_transitions(batch_size)
        step_types = np.array([
            StepType.TERMINAL if terminal else StepType.MID
            for terminal in samples['terminals'].reshape(-1)
        ],
                              dtype=StepType)
        return TimeStepBatch(env_spec=self._env_spec,
                             observations=samples['observations'],
                             actions=samples['actions'],
                             rewards=samples['rewards'],
                             next_observations=samples['next_observations'],
                             step_types=step_types,
                             env_infos={},
                             agent_infos={}), w, idx

    def sample_transitions(self, batch_size):
        """Sample a batch of transitions from the buffer.

        Transitions are drawn with probability proportional to
        priority**alpha. The returned weights are the importance-sampling
        corrections (N * P(i))**(-beta), normalized by their maximum.

        Args:
            batch_size (int): Number of transitions to sample.

        Returns:
            dict: A dict of arrays of shape (batch_size, flat_dim).
            np.ndarray: Weights of the timesteps.
            np.ndarray: Indices of sampled timesteps
                in the replay buffer.

        Raises:
            ValueError: If the buffer is empty or the stored priorities
                do not form a valid probability distribution.

        """
        n_stored = self._transitions_stored
        if n_stored == 0:
            raise ValueError('Cannot sample from an empty replay buffer.')

        # Only slots that hold data take part in sampling. The power
        # creates a new array, so self._priorities is never modified here.
        scaled = self._priorities[:n_stored]**self._alpha
        total = scaled.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError(
                'Priorities must be finite with a positive sum to sample.')
        probs = scaled / total

        idx = self._rng.choice(n_stored, size=batch_size, p=probs)
        transitions = {
            key: buf_arr[idx]
            for key, buf_arr in self._buffer.items()
        }

        w = (n_stored * probs[idx])**(-self._current_beta())
        # Normalize so that weights only ever scale the update downwards.
        w /= w.max()

        return transitions, w, idx

    def update_priorities(self, indices, priorities):
        """Update priorities of timesteps.

        Args:
            indices (np.ndarray): Array of indices corresponding to the
                timesteps/priorities to update.
            priorities (list[float]): new priorities to set.

        Raises:
            ValueError: If indices and priorities differ in length.

        """
        flat_idx = np.asarray(indices).astype(np.intp).reshape(-1)
        new_values = np.asarray(priorities,
                                dtype=self._priorities.dtype).reshape(-1)
        if flat_idx.shape != new_values.shape:
            raise ValueError(
                'indices and priorities must have the same length.')
        # A single vectorized write; for a repeated index the last value
        # is the one that is kept.
        self._priorities[flat_idx] = new_values

    def add_path(self, path):
        """Add a path to the buffer.

        This differs from the underlying buffer's add_path method
        in that the priorities for the new samples are set to
        the maximum of all priorities in the buffer.

        Args:
            path (dict): A dict of array of shape (path_len, flat_dim).

        """
        path_len = len(path['observations'])

        # Find the indices where the path will be stored. This has to be
        # done before the path is handed to the underlying buffer.
        first_seg, second_seg = self._next_path_segments(path_len)

        # New timesteps get max(self._priorities), or 1 if no priority has
        # been set yet, so each of them is likely to be sampled at least
        # once.
        max_priority = self._priorities.max()
        if not max_priority:
            max_priority = 1.

        # Store the path first: if the underlying buffer rejects it,
        # neither the priorities nor the beta schedule are touched.
        super().add_path(path)

        # Assigning to an empty slice is a no-op, so an empty second
        # segment needs no special handling.
        for seg in (first_seg, second_seg):
            self._priorities[seg.start:seg.stop] = max_priority
        self._curr_timestep += path_len
3 changes: 3 additions & 0 deletions src/garage/replay_buffer/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def sample(self, batch_size):

Args:
batch_size(int): The number of transitions to be sampled.
np.ndarray: Weights of the timesteps.
np.ndarray: Indices of sampled timesteps
in the replay buffer.

"""
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion src/garage/tf/algos/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def _optimize_policy(self):
float: Q value predicted by the q network.

"""
timesteps = self._replay_buffer.sample_timesteps(
timesteps, _, _ = self._replay_buffer.sample_timesteps(
self._buffer_batch_size)

observations = timesteps.observations
Expand Down
2 changes: 1 addition & 1 deletion src/garage/tf/algos/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _optimize_policy(self):
numpy.float64: Loss of policy.

"""
timesteps = self._replay_buffer.sample_timesteps(
timesteps, _, _ = self._replay_buffer.sample_timesteps(
self._buffer_batch_size)

observations = timesteps.observations
Expand Down
2 changes: 1 addition & 1 deletion src/garage/tf/algos/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def _optimize_policy(self, itr):
float: Q value predicted by the q network.

"""
timesteps = self._replay_buffer.sample_timesteps(
timesteps, _, _ = self._replay_buffer.sample_timesteps(
self._buffer_batch_size)

observations = timesteps.observations
Expand Down
6 changes: 2 additions & 4 deletions src/garage/torch/algos/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import numpy as np
import torch

from garage import (_Default,
log_performance,
make_optimizer,
from garage import (_Default, log_performance, make_optimizer,
obtain_evaluation_episodes)
from garage.np.algos import RLAlgorithm
from garage.sampler import FragmentWorker, LocalSampler
Expand Down Expand Up @@ -188,7 +186,7 @@ def train_once(self, itr, episodes):
for _ in range(self._n_train_steps):
if (self.replay_buffer.n_transitions_stored >=
self._min_buffer_size):
samples = self.replay_buffer.sample_transitions(
samples, _, _ = self.replay_buffer.sample_transitions(
self._buffer_batch_size)
samples['rewards'] *= self._reward_scale
qf_loss, y, q, policy_loss = torch_to_np(
Expand Down
32 changes: 25 additions & 7 deletions src/garage/torch/algos/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from garage import _Default, log_performance, make_optimizer
from garage._functions import obtain_evaluation_episodes
from garage.np.algos import RLAlgorithm
from garage.replay_buffer import PERReplayBuffer
from garage.sampler import FragmentWorker
from garage.torch import global_device, np_to_torch
from garage.torch import global_device, np_to_torch, torch_to_np


class DQN(RLAlgorithm):
Expand Down Expand Up @@ -122,6 +123,9 @@
self._qf_optimizer = make_optimizer(qf_optimizer,
module=self._qf,
lr=qf_lr)

self._prioritized_replay = isinstance(self.replay_buffer,
PERReplayBuffer)
self._eval_env = eval_env

def train(self, trainer):
Expand Down Expand Up @@ -192,10 +196,12 @@
for _ in range(self._n_train_steps):
if (self.replay_buffer.n_transitions_stored >=
self._min_buffer_size):
timesteps = self.replay_buffer.sample_timesteps(
self._buffer_batch_size)
qf_loss, y, q = tuple(v.cpu().numpy()
for v in self._optimize_qf(timesteps))
timesteps, weights, indices = (
self.replay_buffer.sample_timesteps(
self._buffer_batch_size))
qf_loss, y, q = tuple(
v.cpu().numpy()
for v in self._optimize_qf(timesteps, weights, indices))

self._episode_qf_losses.append(qf_loss)
self._epoch_ys.append(y)
Expand Down Expand Up @@ -228,11 +234,15 @@
tabular.record('QFunction/AverageAbsY',
np.mean(np.abs(self._epoch_ys)))

def _optimize_qf(self, timesteps):
def _optimize_qf(self, timesteps, weights=None, indices=None):
"""Perform algorithm optimizing.

Args:
timesteps (TimeStepBatch): Processed batch data.
weights (np.ndarray[float]): Weights used by PER when updating
the network.
indices (list[int or float]): Indices of the sampled
timesteps in the replay buffer.

Returns:
qval_loss: Loss of Q-value predicted by the Q-network.
Expand Down Expand Up @@ -274,7 +284,15 @@
# optimize qf
qvals = self._qf(inputs)
selected_qs = torch.sum(qvals * actions, axis=1)
qval_loss = F.smooth_l1_loss(selected_qs, y_target)
qval_loss = F.smooth_l1_loss(selected_qs, y_target, reduction='none')

if self._prioritized_replay:
qval_loss *= np_to_torch(weights)
priorities = qval_loss + 1e-5 # offset to avoid 0 priorities
priorities = torch_to_np(priorities.data.cpu())
self.replay_buffer.update_priorities(indices, priorities)

Check warning on line 293 in src/garage/torch/algos/dqn.py

View check run for this annotation

Codecov / codecov/patch

src/garage/torch/algos/dqn.py#L290-L293

Added lines #L290 - L293 were not covered by tests

qval_loss = qval_loss.mean()

self._qf_optimizer.zero_grad()
qval_loss.backward()
Expand Down
Loading