Add doc and fixes for policy kwargs (#165)
* Update tests + changelog

* Add example + kwargs check + fix for TRPO

* Fix DQN examples + fix for codacy

* [ci skip] Add comments

* Add a note about custom class vs policy kwargs

* The activation function can now be customized for DQN, DDPG and SAC
araffin authored Jan 17, 2019
1 parent 88a5c5d commit 396e0fa
Showing 14 changed files with 133 additions and 51 deletions.
36 changes: 35 additions & 1 deletion docs/guide/custom_policy.rst
@@ -5,7 +5,41 @@ Custom Policy Network

Stable baselines provides default policy networks (see :ref:`Policies <policies>`) for images (CNNPolicies)
and other types of input features (MlpPolicies).
However, you can also easily define a custom architecture for the policy (or value) network:

One way of customising the policy network architecture is to pass arguments when creating the model,
using the ``policy_kwargs`` parameter:

.. code-block:: python
import gym
import tensorflow as tf
from stable_baselines import PPO2
# Custom MLP policy of two layers of size 32 each with tanh activation function
policy_kwargs = dict(act_fun=tf.nn.tanh, net_arch=[32, 32])
# Create the agent
model = PPO2("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
# Retrieve the environment
env = model.get_env()
# Train the agent
model.learn(total_timesteps=100000)
# Save the agent
model.save("ppo2-cartpole")
del model
# the policy_kwargs are automatically loaded
model = PPO2.load("ppo2-cartpole")
You can also easily define a custom architecture for the policy (or value) network:

.. note::

Defining a custom policy class is equivalent to passing ``policy_kwargs``. However,
it lets you name the policy, which usually makes the code clearer. ``policy_kwargs``
is better suited to hyperparameter search.


.. code-block:: python
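A minimal sketch of such a custom policy class, equivalent to the ``policy_kwargs`` example above
(the name ``CustomPolicy`` is illustrative; it follows the ``FeedForwardPolicy`` pattern from
``stable_baselines/common/policies.py``):

.. code-block:: python

    import tensorflow as tf

    from stable_baselines import PPO2
    from stable_baselines.common.policies import FeedForwardPolicy

    # Custom MLP policy of two layers of size 32 each with tanh activation function
    class CustomPolicy(FeedForwardPolicy):
        def __init__(self, *args, **kwargs):
            super(CustomPolicy, self).__init__(*args, **kwargs,
                                               act_fun=tf.nn.tanh,
                                               net_arch=[32, 32],
                                               feature_extraction="mlp")

    model = PPO2(CustomPolicy, "CartPole-v1", verbose=1)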
3 changes: 3 additions & 0 deletions docs/misc/changelog.rst
@@ -19,6 +19,9 @@ Pre-Release 2.4.0a (WIP)
- added more flexible custom LSTM policies
- added auto entropy coefficient optimization for SAC
- clip continuous actions at test time too for all algorithms (except SAC/DDPG where it is not needed)
- added a means to pass kwargs to the policy when creating a model (and save those kwargs)
- fixed DQN examples in the DQN folder
- added the possibility to pass the activation function for DDPG, DQN and SAC


Release 2.3.0 (2018-12-05)
6 changes: 3 additions & 3 deletions docs/modules/policies.rst
@@ -6,9 +6,9 @@ Policy Networks
===============

Stable-baselines provides a set of default policies, which can be used with most action spaces.
To customize the default policies, you can specify the `policy_kwargs` parameter to the model class you use.
Those kwargs are then passed to the policy on instantiation.
If you need more control on the policy architecture, You can also create a custom policy (see :ref:`custom_policy`).
To customize the default policies, you can specify the ``policy_kwargs`` parameter to the model class you use.
Those kwargs are then passed to the policy on instantiation (see :ref:`custom_policy` for an example).
If you need more control on the policy architecture, you can also create a custom policy (see :ref:`custom_policy`).

.. note::

27 changes: 25 additions & 2 deletions stable_baselines/common/policies.py
@@ -123,6 +123,24 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals
self.ob_space = ob_space
self.ac_space = ac_space

@staticmethod
def _kwargs_check(feature_extraction, kwargs):
"""
Ensure that the user is not passing wrong keywords
when using policy_kwargs.
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
:param kwargs: (dict) Extra keyword arguments passed to the policy
"""
# When using the policy_kwargs parameter on model creation,
# all keyword arguments must be consumed by the policy constructor, except
# the ones for the cnn_extractor network (cf. nature_cnn()), where the keyword arguments
# are not passed explicitly (**kwargs forwards them).
# That's why there should be no kwargs left when using the mlp_extractor
# (in that case the keyword arguments are passed explicitly).
if feature_extraction == 'mlp' and len(kwargs) > 0:
raise ValueError("Unknown keywords for policy: {}".format(kwargs))

def step(self, obs, state=None, mask=None):
"""
Returns the policy for a single step
@@ -243,6 +261,7 @@ class LstmPolicy(ActorCriticPolicy):
:param layers: ([int]) The size of the Neural network before the LSTM layer (if None, default to [64, 64])
:param net_arch: (list) Specification of the actor-critic policy network architecture. Notation similar to the
format described in mlp_extractor but with additional support for a 'lstm' entry in the shared network part.
:param act_fun: (tf.func) the activation function to use in the neural network.
:param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
:param layer_norm: (bool) Whether or not to use layer normalizing LSTMs
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
@@ -255,6 +274,8 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256
super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
scale=(feature_extraction == "cnn"))

self._kwargs_check(feature_extraction, kwargs)

with tf.variable_scope("input", reuse=True):
self.masks_ph = tf.placeholder(tf.float32, [n_batch], name="masks_ph") # mask (done t-1)
# n_lstm * 2 dim because of the cell and hidden states of the LSTM
@@ -338,7 +359,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256
for idx, vf_layer_size in enumerate(value_only_layers):
if vf_layer_size == "lstm":
raise NotImplementedError("LSTMs are only supported in the shared part of the value function "
"network.")
"network.")
assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
latent_value = act_fun(
linear(latent_value, "vf_fc{}".format(idx), vf_layer_size, init_scale=np.sqrt(2)))
@@ -383,7 +404,7 @@ class FeedForwardPolicy(ActorCriticPolicy):
(if None, default to [64, 64])
:param net_arch: (list) Specification of the actor-critic policy network architecture (see mlp_extractor
documentation for details).
:param act_fun: the activation function to use in the neural network.
:param act_fun: (tf.func) the activation function to use in the neural network.
:param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
@@ -394,6 +415,8 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse,
scale=(feature_extraction == "cnn"))

self._kwargs_check(feature_extraction, kwargs)

if layers is not None:
warnings.warn("Usage of the `layers` parameter is deprecated! Use net_arch instead "
"(it has a different semantics though).", DeprecationWarning)
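For reference, a minimal sketch of what this check catches from the user's side (``not_a_real_kwarg``
is a deliberately made-up keyword):

.. code-block:: python

    import tensorflow as tf

    from stable_baselines import PPO2

    # Valid: both keywords are consumed by the MLP policy constructor
    model = PPO2("MlpPolicy", "CartPole-v1",
                 policy_kwargs=dict(act_fun=tf.nn.tanh, net_arch=[32, 32]))

    # Invalid: the keyword is not consumed by the MLP policy constructor,
    # so _kwargs_check() raises ValueError("Unknown keywords for policy: ...")
    try:
        model = PPO2("MlpPolicy", "CartPole-v1",
                     policy_kwargs=dict(not_a_real_kwarg=42))
    except ValueError as error:
        print(error)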
8 changes: 6 additions & 2 deletions stable_baselines/ddpg/policies.py
@@ -100,13 +100,17 @@ class FeedForwardPolicy(DDPGPolicy):
:param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
:param layer_norm: (bool) enable layer normalisation
:param act_fun: (tf.func) the activation function to use in the neural network.
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None,
cnn_extractor=nature_cnn, feature_extraction="cnn", layer_norm=False, **kwargs):
cnn_extractor=nature_cnn, feature_extraction="cnn",
layer_norm=False, act_fun=tf.nn.relu, **kwargs):
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse,
scale=(feature_extraction == "cnn"))

self._kwargs_check(feature_extraction, kwargs)
self.layer_norm = layer_norm
self.feature_extraction = feature_extraction
self.cnn_kwargs = kwargs
@@ -119,7 +123,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals

assert len(layers) >= 1, "Error: must have at least one hidden layer for the policy."

self.activ = tf.nn.relu
self.activ = act_fun

def make_actor(self, obs=None, reuse=False, scope="pi"):
if obs is None:
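With ``act_fun`` exposed here, the DDPG activation function can now be chosen through ``policy_kwargs``
(a minimal sketch; ``Pendulum-v0`` is just an illustrative continuous-action environment):

.. code-block:: python

    import tensorflow as tf

    from stable_baselines import DDPG

    # DDPG MLP policy with ELU activations and two hidden layers of 64 units
    model = DDPG("MlpPolicy", "Pendulum-v0", verbose=1,
                 policy_kwargs=dict(act_fun=tf.nn.elu, layers=[64, 64]))
    model.learn(total_timesteps=10000)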
2 changes: 1 addition & 1 deletion stable_baselines/deepq/experiments/custom_cartpole.py
@@ -15,7 +15,7 @@
class CustomPolicy(FeedForwardPolicy):
def __init__(self, *args, **kwargs):
super(CustomPolicy, self).__init__(*args, **kwargs,
net_arch=[dict(vf=[64], pi=[64])],
layers=[64],
feature_extraction="mlp")


14 changes: 3 additions & 11 deletions stable_baselines/deepq/experiments/train_mountaincar.py
@@ -3,15 +3,6 @@
import gym

from stable_baselines.deepq import DQN
from stable_baselines.deepq.policies import FeedForwardPolicy


class CustomPolicy(FeedForwardPolicy):
def __init__(self, *args, **kwargs):
super(CustomPolicy, self).__init__(*args, **kwargs,
net_arch=[dict(pi=[64], vf=[64])],
layer_norm=True,
feature_extraction="mlp")


def main(args):
@@ -24,13 +15,14 @@ def main(args):

# using layer norm policy here is important for parameter space noise!
model = DQN(
policy=CustomPolicy,
policy="LnMlpPolicy",
env=env,
learning_rate=1e-3,
buffer_size=50000,
exploration_fraction=0.1,
exploration_final_eps=0.1,
param_noise=True
param_noise=True,
policy_kwargs=dict(layers=[64])
)
model.learn(total_timesteps=args.max_timesteps)

11 changes: 7 additions & 4 deletions stable_baselines/deepq/policies.py
@@ -84,15 +84,19 @@ class FeedForwardPolicy(DQNPolicy):
and the processed observation placeholder respectively
:param layer_norm: (bool) enable layer normalisation
:param dueling: (bool) if true double the output MLP to compute a baseline for action scores
:param act_fun: (tf.func) the activation function to use in the neural network.
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""

def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None,
cnn_extractor=nature_cnn, feature_extraction="cnn",
obs_phs=None, layer_norm=False, dueling=True, **kwargs):
obs_phs=None, layer_norm=False, dueling=True, act_fun=tf.nn.relu, **kwargs):
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps,
n_batch, dueling=dueling, reuse=reuse,
scale=(feature_extraction == "cnn"), obs_phs=obs_phs)

self._kwargs_check(feature_extraction, kwargs)

if layers is None:
layers = [64, 64]

@@ -102,14 +106,13 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals
extracted_features = cnn_extractor(self.processed_obs, **kwargs)
action_out = extracted_features
else:
activ = tf.nn.relu
extracted_features = tf.layers.flatten(self.processed_obs)
action_out = extracted_features
for layer_size in layers:
action_out = tf_layers.fully_connected(action_out, num_outputs=layer_size, activation_fn=None)
if layer_norm:
action_out = tf_layers.layer_norm(action_out, center=True, scale=True)
action_out = activ(action_out)
action_out = act_fun(action_out)

action_scores = tf_layers.fully_connected(action_out, num_outputs=self.n_actions, activation_fn=None)

@@ -120,7 +123,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals
state_out = tf_layers.fully_connected(state_out, num_outputs=layer_size, activation_fn=None)
if layer_norm:
state_out = tf_layers.layer_norm(state_out, center=True, scale=True)
state_out = tf.nn.relu(state_out)
state_out = act_fun(state_out)
state_score = tf_layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
action_scores_mean = tf.reduce_mean(action_scores, axis=1)
action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, axis=1)
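Since ``act_fun`` is now threaded through both the action and the state-value streams, the DQN
activation function can also be set from ``policy_kwargs`` (a minimal sketch):

.. code-block:: python

    import tensorflow as tf

    from stable_baselines import DQN

    # Layer-normalized DQN policy with leaky ReLU activations and a single hidden layer
    model = DQN("LnMlpPolicy", "CartPole-v1", verbose=1,
                policy_kwargs=dict(act_fun=tf.nn.leaky_relu, layers=[64]))
    model.learn(total_timesteps=10000)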
8 changes: 6 additions & 2 deletions stable_baselines/sac/policies.py
@@ -175,14 +175,18 @@ class FeedForwardPolicy(SACPolicy):
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
:param layer_norm: (bool) enable layer normalisation
:param reg_weight: (float) Regularization loss weight for the policy parameters
:param act_fun: (tf.func) the activation function to use in the neural network.
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""

def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, layers=None,
cnn_extractor=nature_cnn, feature_extraction="cnn", reg_weight=0.0,
layer_norm=False, **kwargs):
layer_norm=False, act_fun=tf.nn.relu, **kwargs):
super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
reuse=reuse, scale=(feature_extraction == "cnn"))

self._kwargs_check(feature_extraction, kwargs)
self.layer_norm = layer_norm
self.feature_extraction = feature_extraction
self.cnn_kwargs = kwargs
@@ -197,7 +201,7 @@ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, r

assert len(layers) >= 1, "Error: must have at least one hidden layer for the policy."

self.activ_fn = tf.nn.relu
self.activ_fn = act_fun

def make_actor(self, obs=None, reuse=False, scope="pi"):
if obs is None:
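The same applies to SAC (a minimal sketch; the layer sizes are illustrative):

.. code-block:: python

    import tensorflow as tf

    from stable_baselines import SAC

    # SAC MLP policy with tanh activations instead of the default ReLU
    model = SAC("MlpPolicy", "Pendulum-v0", verbose=1,
                policy_kwargs=dict(act_fun=tf.nn.tanh, layers=[64, 64]))
    model.learn(total_timesteps=10000)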
4 changes: 2 additions & 2 deletions stable_baselines/trpo_mpi/trpo_mpi.py
@@ -119,12 +119,12 @@ def setup_model(self):

# Construct network for new policy
self.policy_pi = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
None, reuse=False)
None, reuse=False, **self.policy_kwargs)

# Network for old policy
with tf.variable_scope("oldpi", reuse=False):
old_policy = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs, 1,
None, reuse=False)
None, reuse=False, **self.policy_kwargs)

with tf.variable_scope("loss", reuse=False):
atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
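Since the policy keyword arguments are now forwarded to both the new and the old policy networks,
TRPO honours ``policy_kwargs`` like the other models (a minimal sketch, assuming the mpi4py
dependency required by TRPO is installed):

.. code-block:: python

    import tensorflow as tf

    from stable_baselines import TRPO

    # TRPO with a smaller shared network and tanh activations
    model = TRPO("MlpPolicy", "CartPole-v1", verbose=1,
                 policy_kwargs=dict(act_fun=tf.nn.tanh, net_arch=[32, 32]))
    model.learn(total_timesteps=10000)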