[rllib] Fix A3C PyTorch implementation (ray-project#2036)
* Use F.softmax instead of a pointless network layer

Stateless functions should not be network layers. (This change and the
`.squeeze()` caveat below are sketched in code just before the list of
changed files.)

* Use correct pytorch functions

* Rename argument to out_size

Matches in_size and makes more sense.

* Fix shapes of tensors

Advantages and rewards both should be scalars, and therefore a list of them
should be 1D.

* Fmt

* Replace deprecated function

* Remove unnecessary Variable wrapper

* Remove all use of torch Variables

Torch does this for us now.

* Ensure that values are a flat list

* Fix shape error in conv nets

* Fmt

* Fix shape errors

Reshaping the action before stepping in the env fixes a few errors.

* Add TODO

* Use correct filter size

Works when `self.config['model']['channel_major'] = True`.

* Add missing channel major

* Revert reshape of action

This should be handled by the agent or at least in a cleaner way that doesn't
break existing envs.

* Squeeze action

* Squeeze actions along first dimension

This should deal with some cases such as cartpole where actions are scalars
while leaving alone cases where actions are arrays (some robotics tasks).

* Try adding PyTorch tests

* Fix typo

* Fix up Docker messages

* Fix A3C for some envs

Pendulum doesn't work since it's an edge case (expects singleton arrays, which
`.squeeze()` collapses to scalars).

* Fmt

* Nit: flake8 fix

* Small lint fix
alok authored and richardliaw committed May 30, 2018
1 parent ac1e5a7 commit fd234e3
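
The first bullet above (use `F.softmax` rather than a softmax layer) and the removal of `Variable` wrappers boil down to the pattern below. This is a minimal sketch, not code from the diff: `PolicyHead` and its sizes are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyHead(nn.Module):
    """Hypothetical stand-in for an RLlib model; only the logits layer is stateful."""

    def __init__(self, in_size, num_actions):
        super(PolicyHead, self).__init__()
        self.logits = nn.Linear(in_size, num_actions)
        # Before this commit the model also kept `self.probs = nn.Softmax()`,
        # a stateless op stored as a layer.

    def forward(self, x):
        return self.logits(x)

head = PolicyHead(in_size=4, num_actions=2)

# After: call the stateless functional form at the point of use, and pass raw
# tensors -- modern PyTorch needs no Variable(...) wrapper.
ob = torch.randn(1, 4)
probs = F.softmax(head(ob), dim=1)
action = probs.multinomial(1).squeeze()
```

The `.squeeze()` caveat from the Pendulum bullet is easy to reproduce with plain NumPy (the numbers here are made up):

```python
import numpy as np

# A full .squeeze() collapses a singleton action array to a 0-d scalar, which
# breaks envs that expect shape-(1,) actions; squeezing only the batch axis
# leaves the action's own dimension intact.
batch_of_one = np.array([[0.37]])          # shape (1, 1): batch of one 1-D action
print(batch_of_one.squeeze().shape)        # () -- scalar; Pendulum-style envs break
print(batch_of_one.squeeze(axis=0).shape)  # (1,) -- only the batch axis removed
```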
Showing 10 changed files with 138 additions and 100 deletions.
8 changes: 5 additions & 3 deletions docker/examples/Dockerfile
@@ -1,8 +1,10 @@
# The examples Docker image adds dependencies needed to run the examples

FROM ray-project/deploy
RUN conda install -y -c conda-forge tensorflow

# This updates numpy to 1.14 and mutes errors from other libraries
RUN conda install -y numpy
RUN apt-get install -y zlib1g-dev
RUN pip install gym[atari] opencv-python==3.2.0.8
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
# RUN conda install -y -q pytorch torchvision -c soumith
RUN conda install pytorch-cpu torchvision-cpu -c pytorch
56 changes: 32 additions & 24 deletions python/ray/rllib/a3c/a3c.py
@@ -15,7 +15,6 @@
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources


DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 4,
@@ -52,7 +51,7 @@
# (Image statespace) - Converts image to (dim, dim, C)
"dim": 80,
# (Image statespace) - Converts image shape to (C, dim, dim)
"channel_major": False
"channel_major": False,
},
# Arguments to pass to the rllib optimizer
"optimizer": {
@@ -73,46 +72,53 @@ class A3CAgent(Agent):
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1, gpu=0,
cpu=1,
gpu=0,
extra_cpu=cf["num_workers"],
extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

def _init(self):
self.local_evaluator = A3CEvaluator(
self.registry, self.env_creator, self.config, self.logdir,
self.registry,
self.env_creator,
self.config,
self.logdir,
start_sampler=False)
if self.config["use_gpu_for_workers"]:
remote_cls = GPURemoteA3CEvaluator
else:
remote_cls = RemoteA3CEvaluator
self.remote_evaluators = [
remote_cls.remote(
self.registry, self.env_creator, self.config, self.logdir)
for i in range(self.config["num_workers"])]
self.optimizer = AsyncOptimizer(
self.config["optimizer"], self.local_evaluator,
self.remote_evaluators)
remote_cls.remote(self.registry, self.env_creator, self.config,
self.logdir)
for i in range(self.config["num_workers"])
]
self.optimizer = AsyncOptimizer(self.config["optimizer"],
self.local_evaluator,
self.remote_evaluators)

def _train(self):
self.optimizer.step()
FilterManager.synchronize(
self.local_evaluator.filters, self.remote_evaluators)
FilterManager.synchronize(self.local_evaluator.filters,
self.remote_evaluators)
res = self._fetch_metrics_from_remote_evaluators()
return res

def _fetch_metrics_from_remote_evaluators(self):
episode_rewards = []
episode_lengths = []
metric_lists = [a.get_completed_rollout_metrics.remote()
for a in self.remote_evaluators]
metric_lists = [
a.get_completed_rollout_metrics.remote()
for a in self.remote_evaluators
]
for metrics in metric_lists:
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
avg_reward = (
np.mean(episode_rewards) if episode_rewards else float('nan'))
avg_length = (
np.mean(episode_lengths) if episode_lengths else float('nan'))
avg_reward = (np.mean(episode_rewards)
if episode_rewards else float('nan'))
avg_length = (np.mean(episode_lengths)
if episode_lengths else float('nan'))
timesteps = np.sum(episode_lengths) if episode_lengths else 0

result = TrainingResult(
@@ -129,21 +135,23 @@ def _stop(self):
ev.__ray_terminate__.remote()

def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(
checkpoint_dir, "checkpoint-{}".format(self.iteration))
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = {
"remote_state": agent_state,
"local_state": self.local_evaluator.save()}
"local_state": self.local_evaluator.save()
}
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path

def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
ray.get(
[a.restore.remote(o) for a, o in zip(
self.remote_evaluators, extra_data["remote_state"])])
ray.get([
a.restore.remote(o)
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
])
self.local_evaluator.restore(extra_data["local_state"])

def compute_action(self, observation):
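
The `a3c.py` changes are mostly formatting (line re-wrapping). For reference, the NaN-safe metric averaging in `_fetch_metrics_from_remote_evaluators` behaves as in the sketch below, with made-up episode data standing in for what `ray.get()` returns from the remote evaluators.

```python
import numpy as np

# Made-up rollout metrics standing in for the remote evaluators' results.
episode_rewards = [10.0, 12.5, 9.0]
episode_lengths = [200, 250, 180]

# Empty lists fall back to NaN (for the means) or 0 (for the timestep count)
# instead of raising, matching the diff above.
avg_reward = np.mean(episode_rewards) if episode_rewards else float('nan')
avg_length = np.mean(episode_lengths) if episode_lengths else float('nan')
timesteps = np.sum(episode_lengths) if episode_lengths else 0
print(avg_reward, avg_length, timesteps)  # 10.5 210.0 630
```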
49 changes: 28 additions & 21 deletions python/ray/rllib/a3c/shared_torch_policy.py
@@ -3,7 +3,6 @@
from __future__ import print_function

import torch
from torch.autograd import Variable
import torch.nn.functional as F

from ray.rllib.a3c.torchpolicy import TorchPolicy
@@ -18,8 +17,8 @@ class SharedTorchPolicy(TorchPolicy):
is_recurrent = False

def __init__(self, registry, ob_space, ac_space, config, **kwargs):
super(SharedTorchPolicy, self).__init__(
registry, ob_space, ac_space, config, **kwargs)
super(SharedTorchPolicy, self).__init__(registry, ob_space, ac_space,
config, **kwargs)

def _setup_graph(self, ob_space, ac_space):
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
@@ -31,48 +30,56 @@ def _setup_graph(self, ob_space, ac_space):
def compute(self, ob, *args):
"""Should take in a SINGLE ob"""
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
logits, values = self._model(ob)
samples = self._model.probs(logits).multinomial().squeeze()
values = values.squeeze(0)
return var_to_np(samples), {"vf_preds": var_to_np(values)}
# TODO(alok): Support non-categorical distributions. Multinomial
# is only for categorical.
sampled_actions = F.softmax(logits, dim=1).multinomial(1).squeeze()
values = values.squeeze()
return var_to_np(sampled_actions), {"vf_preds": var_to_np(values)}

def compute_logits(self, ob, *args):
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
res = self._model.hidden_layers(ob)
return var_to_np(self._model.logits(res))

def value(self, ob, *args):
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
ob = torch.from_numpy(ob).float().unsqueeze(0)
res = self._model.hidden_layers(ob)
res = self._model.value_branch(res)
res = res.squeeze(0)
res = res.squeeze()
return var_to_np(res)

def _evaluate(self, obs, actions):
"""Passes in multiple obs."""
logits, values = self._model(obs)
log_probs = F.log_softmax(logits)
probs = self._model.probs(logits)
log_probs = F.log_softmax(logits, dim=1)
probs = F.softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))
# TODO(alok): set distribution based on action space and use its
# `.entropy()` method to calculate automatically
entropy = -(log_probs * probs).sum(-1).sum()
return values, action_log_probs, entropy

def _backward(self, batch):
"""Loss is encoded in here. Defining a new loss function
would start by rewriting this function"""

states, acs, advs, rs, _ = convert_batch(batch)
values, ac_logprobs, entropy = self._evaluate(states, acs)
pi_err = -(advs * ac_logprobs).sum()
value_err = 0.5 * (values - rs).pow(2).sum()
states, actions, advs, rs, _ = convert_batch(batch)
values, action_log_probs, entropy = self._evaluate(states, actions)
pi_err = -advs.dot(action_log_probs.reshape(-1))
value_err = F.mse_loss(values.reshape(-1), rs)

self.optimizer.zero_grad()
overall_err = (pi_err +
value_err * self.config["vf_loss_coeff"] +
entropy * self.config["entropy_coeff"])

overall_err = sum([
pi_err,
self.config["vf_loss_coeff"] * value_err,
self.config["entropy_coeff"] * entropy,
])

overall_err.backward()
torch.nn.utils.clip_grad_norm(
self._model.parameters(), self.config["grad_clip"])
torch.nn.utils.clip_grad_norm_(self._model.parameters(),
self.config["grad_clip"])
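
The reworked `_backward` expresses the policy term as a dot product over flat 1-D advantages and uses `F.mse_loss` for the value term. Below is a minimal, self-contained sketch of that loss wiring, with random tensors standing in for a real rollout batch; the coefficient values are assumed stand-ins for `self.config[...]`, not values taken from the diff.

```python
import torch
import torch.nn.functional as F

batch_size, num_actions = 5, 3
logits = torch.randn(batch_size, num_actions, requires_grad=True)
values = torch.randn(batch_size, requires_grad=True)
actions = torch.randint(num_actions, (batch_size,))
advs = torch.randn(batch_size)   # 1-D, per the "Fix shapes of tensors" bullet
rs = torch.randn(batch_size)     # 1-D value targets

log_probs = F.log_softmax(logits, dim=1)
probs = F.softmax(logits, dim=1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1))

pi_err = -advs.dot(action_log_probs.reshape(-1))   # policy-gradient term
value_err = F.mse_loss(values.reshape(-1), rs)     # value-function term
entropy = -(log_probs * probs).sum(-1).sum()       # exploration bonus

vf_loss_coeff, entropy_coeff = 0.5, -0.01          # assumed config stand-ins
overall_err = pi_err + vf_loss_coeff * value_err + entropy_coeff * entropy
overall_err.backward()
print(logits.grad.shape, values.grad.shape)        # gradients reach both heads
```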
14 changes: 9 additions & 5 deletions python/ray/rllib/a3c/torchpolicy.py
@@ -3,7 +3,6 @@
from __future__ import print_function

import torch
from torch.autograd import Variable

from ray.rllib.a3c.policy import Policy
from threading import Lock
@@ -15,8 +14,13 @@ class TorchPolicy(Policy):
The model is a separate object than the policy. This could be changed
in the future."""

def __init__(self, registry, ob_space, action_space, config,
name="local", summarize=True):
def __init__(self,
registry,
ob_space,
action_space,
config,
name="local",
summarize=True):
self.registry = registry
self.local_steps = 0
self.config = config
@@ -28,7 +32,7 @@ def __init__(self, registry, ob_space, action_space, config,
def apply_gradients(self, grads):
self.optimizer.zero_grad()
for g, p in zip(grads, self._model.parameters()):
p.grad = Variable(torch.from_numpy(g))
p.grad = torch.from_numpy(g)
self.optimizer.step()

def get_weights(self):
@@ -69,7 +73,7 @@ def _setup_graph(ob_space, action_space):

def _backward(self, batch):
"""Implements the loss function and calculates the gradient.
Pytorch automatically generates a backward trace for each variable.
Pytorch automatically generates a backward trace for each tensor.
Assumption right now is that variables are moved, so the backward
trace is lost.
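
In `apply_gradients`, gradients computed elsewhere arrive as NumPy arrays and are written straight into `.grad` (no `Variable` wrapper) before the optimizer steps. A minimal sketch under assumed model and optimizer settings; only the grad-assignment pattern mirrors the diff above.

```python
import numpy as np
import torch
import torch.nn as nn

# Assumed two-layer model and SGD settings, not RLlib's.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Pretend these numpy gradients came from a remote worker.
grads = [np.ones(p.shape, dtype=np.float32) for p in model.parameters()]

optimizer.zero_grad()
for g, p in zip(grads, model.parameters()):
    p.grad = torch.from_numpy(g)
optimizer.step()
```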
3 changes: 3 additions & 0 deletions python/ray/rllib/models/catalog.py
@@ -180,11 +180,14 @@ def get_torch_model(registry, input_shape, num_outputs, options={}):
return registry.get(RLLIB_MODEL, model)(
input_shape, num_outputs, options)

# TODO(alok): fix to handle Discrete(n) state spaces
obs_rank = len(input_shape) - 1

if obs_rank > 1:
return PyTorchVisionNet(input_shape, num_outputs, options)

# TODO(alok): overhaul PyTorchFCNet so it can just
# take input shape directly
return PyTorchFCNet(input_shape[0], num_outputs, options)

@staticmethod
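
The `catalog.py` change only adds TODOs, but the surrounding dispatch logic is worth a glance: observation rank decides between the vision net and the fully connected net. A small sketch of that rule; the example shapes are assumptions about the catalog's convention, not taken from the repository.

```python
# Mirror of the obs_rank test in get_torch_model above.
def pick_net(input_shape):
    obs_rank = len(input_shape) - 1
    return "vision" if obs_rank > 1 else "fc"

print(pick_net((80, 80, 3)))  # "vision" -- image-like observations
print(pick_net((4,)))         # "fc"     -- flat vector observations
```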
22 changes: 13 additions & 9 deletions python/ray/rllib/models/pytorch/fcnet.py
@@ -9,6 +9,7 @@

class FullyConnectedNetwork(Model):
"""TODO(rliaw): Logits, Value should both be contained here"""

def _init(self, inputs, num_outputs, options):
assert type(inputs) is int
hiddens = options.get("fcnet_hiddens", [256, 256])
@@ -23,26 +24,29 @@ def _init(self, inputs, num_outputs, options):
layers = []
last_layer_size = inputs
for size in hiddens:
layers.append(SlimFC(
last_layer_size, size,
initializer=normc_initializer(1.0),
activation_fn=activation))
layers.append(
SlimFC(
in_size=last_layer_size,
out_size=size,
initializer=normc_initializer(1.0),
activation_fn=activation))
last_layer_size = size

self.hidden_layers = nn.Sequential(*layers)

self.logits = SlimFC(
last_layer_size, num_outputs,
in_size=last_layer_size,
out_size=num_outputs,
initializer=normc_initializer(0.01),
activation_fn=None)
self.probs = nn.Softmax()
self.value_branch = SlimFC(
last_layer_size, 1,
in_size=last_layer_size,
out_size=1,
initializer=normc_initializer(1.0),
activation_fn=None)

def forward(self, obs):
""" Internal method - pass in Variables, not numpy arrays
""" Internal method - pass in torch tensors, not numpy arrays
Args:
obs: observations and features
@@ -52,5 +56,5 @@ def forward(self, obs):
value: value function for each state"""
res = self.hidden_layers(obs)
logits = self.logits(res)
value = self.value_branch(res)
value = self.value_branch(res).reshape(-1)
return logits, value
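
For readers unfamiliar with `SlimFC`, the layout that `FullyConnectedNetwork._init` builds above is roughly the following, written with plain `nn.Linear`. This is a hedged sketch: the sizes are assumptions, the normc initializer is omitted, and tanh stands in for the configured activation.

```python
import torch
import torch.nn as nn

hiddens, in_size, num_outputs = [256, 256], 4, 2

# Hidden trunk: alternating Linear + activation, as the SlimFC loop builds.
layers, last = [], in_size
for size in hiddens:
    layers += [nn.Linear(last, size), nn.Tanh()]
    last = size
hidden_layers = nn.Sequential(*layers)

logits_branch = nn.Linear(last, num_outputs)  # policy logits head
value_branch = nn.Linear(last, 1)             # value head

obs = torch.randn(3, in_size)                 # batch of 3 made-up observations
res = hidden_layers(obs)
logits = logits_branch(res)
value = value_branch(res).reshape(-1)         # flat (batch,) values, as in forward()
print(logits.shape, value.shape)              # torch.Size([3, 2]) torch.Size([3])
```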
22 changes: 8 additions & 14 deletions python/ray/rllib/models/pytorch/misc.py
@@ -5,38 +5,32 @@

import numpy as np
import torch
from torch.autograd import Variable


def convert_batch(trajectory, has_features=False):
"""Convert trajectory from numpy to PT variable"""
states = Variable(torch.from_numpy(
trajectory["observations"]).float())
acs = Variable(torch.from_numpy(
trajectory["actions"]))
advs = Variable(torch.from_numpy(
trajectory["advantages"].copy()).float())
advs = advs.view(-1, 1)
rs = Variable(torch.from_numpy(
trajectory["value_targets"]).float())
rs = rs.view(-1, 1)
states = torch.from_numpy(trajectory["obs"]).float()
acs = torch.from_numpy(trajectory["actions"])
advs = torch.from_numpy(
trajectory["advantages"].copy()).float().reshape(-1)
rs = torch.from_numpy(trajectory["rewards"]).float().reshape(-1)
if has_features:
features = [Variable(torch.from_numpy(f))
for f in trajectory["features"]]
features = [torch.from_numpy(f) for f in trajectory["features"]]
else:
features = trajectory["features"]
return states, acs, advs, rs, features


def var_to_np(var):
return var.data.numpy()[0]
return var.detach().numpy()


def normc_initializer(std=1.0):
def initializer(tensor):
tensor.data.normal_(0, 1)
tensor.data *= std / torch.sqrt(
tensor.data.pow(2).sum(1, keepdim=True))

return initializer


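
To see the shapes the reworked `convert_batch` and `var_to_np` produce, here is a minimal sketch run on a tiny hand-made trajectory dict. The keys and dtypes mirror the diff; the values themselves are made up, and a real RLlib sample batch carries more fields.

```python
import numpy as np
import torch

trajectory = {
    "obs": np.random.rand(5, 4).astype(np.float32),
    "actions": np.array([0, 1, 1, 0, 1]),
    "advantages": np.random.randn(5).astype(np.float32),
    "rewards": np.random.randn(5).astype(np.float32),
}

states = torch.from_numpy(trajectory["obs"]).float()
acs = torch.from_numpy(trajectory["actions"])
advs = torch.from_numpy(trajectory["advantages"].copy()).float().reshape(-1)
rs = torch.from_numpy(trajectory["rewards"]).float().reshape(-1)
print(states.shape, acs.shape, advs.shape, rs.shape)
# torch.Size([5, 4]) torch.Size([5]) torch.Size([5]) torch.Size([5])

def var_to_np(var):
    # detach() replaces the old `.data.numpy()[0]` and no longer drops a dimension.
    return var.detach().numpy()

print(var_to_np(states).shape)  # (5, 4)
```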