Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GPUs in RaySampler #2323

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions src/garage/sampler/ray_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,17 @@ class RaySampler(Sampler):
The maximum length episodes which will be sampled.
is_tf_worker (bool): Whether it is workers for TFTrainer.
seed(int): The seed to use to initialize random number generators.
n_workers(int): The number of workers to use.
n_workers(int or None): The number of workers to use. Defaults to
number of physical cpus, if worker_factory is also None.
worker_class(type): Class of the workers. Instances should implement
the Worker interface.
worker_args (dict or None): Additional arguments that should be passed
to the worker.
n_gpus (int or float): Number of GPUs to to use in total for sampling.
If `n_workers` is not a power of two, this may need to be set
slightly below the true value, since `n_workers / n_gpus` gpus are
allocated to each worker. Defaults to zero, because otherwise
nothing would run if no gpus were available.

"""

Expand All @@ -54,26 +60,32 @@ def __init__(
max_episode_length=None,
is_tf_worker=False,
seed=get_seed(),
n_workers=psutil.cpu_count(logical=False),
n_workers=None,
worker_class=DefaultWorker,
worker_args=None):
# pylint: disable=super-init-not-called
worker_args=None,
n_gpus=0):
if not ray.is_initialized():
ray.init(log_to_driver=False, ignore_reinit_error=True)
if worker_factory is None and max_episode_length is None:
raise TypeError('Must construct a sampler from WorkerFactory or'
'parameters (at least max_episode_length)')
if isinstance(worker_factory, WorkerFactory):
if worker_factory is not None:
if n_workers is None:
n_workers = worker_factory.n_workers
self._worker_factory = worker_factory
else:
if n_workers is None:
n_workers = psutil.cpu_count(logical=False)
self._worker_factory = WorkerFactory(
max_episode_length=max_episode_length,
is_tf_worker=is_tf_worker,
seed=seed,
n_workers=n_workers,
worker_class=worker_class,
worker_args=worker_args)
self._sampler_worker = ray.remote(SamplerWorker)
remote_wrapper = ray.remote(num_gpus=n_gpus / n_workers)
self._n_gpus = n_gpus
self._sampler_worker = remote_wrapper(SamplerWorker)
self._agents = agents
self._envs = self._worker_factory.prepare_worker_messages(envs)
self._all_workers = defaultdict(None)
Expand Down Expand Up @@ -103,7 +115,10 @@ def from_worker_factory(cls, worker_factory, agents, envs):
Sampler: An instance of `cls`.

"""
return cls(agents, envs, worker_factory=worker_factory)
return cls(agents,
envs,
worker_factory=worker_factory,
n_workers=worker_factory.n_workers)

def start_worker(self):
"""Initialize a new ray worker."""
Expand Down Expand Up @@ -308,7 +323,8 @@ def __getstate__(self):
"""
return dict(factory=self._worker_factory,
agents=self._agents,
envs=self._envs)
envs=self._envs,
n_gpus=self._n_gpus)

def __setstate__(self, state):
"""Unpickle the state.
Expand All @@ -319,7 +335,8 @@ def __setstate__(self, state):
"""
self.__init__(state['agents'],
state['envs'],
worker_factory=state['factory'])
worker_factory=state['factory'],
n_gpus=state['n_gpus'])


class SamplerWorker:
Expand Down
18 changes: 1 addition & 17 deletions src/garage/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,9 @@ class Sampler(abc.ABC):
`Sampler` needs. Specifically, it specifies how to construct `Worker`s,
which know how to collect episodes and update both agents and environments.

Currently, `__init__` is also part of the interface, but calling it is
deprecated. `start_worker` is also deprecated, and does not need to be
implemented.
`start_worker` is deprecated, and does not need to be implemented.
"""

def __init__(self, algo, env):
"""Construct a Sampler from an Algorithm.

Args:
algo (RLAlgorithm): The RL Algorithm controlling this
sampler.
env (Environment): The environment being sampled from.

Calling this method is deprecated.

"""
self.algo = algo
self.env = env

@classmethod
def from_worker_factory(cls, worker_factory, agents, envs):
"""Construct this sampler.
Expand Down
35 changes: 35 additions & 0 deletions tests/garage/sampler/test_ray_batched_sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for ray_batched_sampler."""
import pickle
from unittest.mock import Mock

import numpy as np
Expand Down Expand Up @@ -138,6 +139,40 @@ def test_init_with_env_updates(ray_local_session_fixture):
assert sum(episodes.lengths) >= 160


def test_pickle(ray_local_session_fixture):
del ray_local_session_fixture
assert ray.is_initialized()
max_episode_length = 16
env = PointEnv()
policy = FixedPolicy(env.spec,
scripted_actions=[
env.action_space.sample()
for _ in range(max_episode_length)
])
tasks = SetTaskSampler(PointEnv)
n_workers = 4
workers = WorkerFactory(seed=100,
max_episode_length=max_episode_length,
n_workers=n_workers)
sampler = RaySampler.from_worker_factory(workers, policy, env)
sampler_pickled = pickle.dumps(sampler)
sampler.shutdown_worker()
sampler2 = pickle.loads(sampler_pickled)
episodes = sampler2.obtain_samples(0,
500,
np.asarray(policy.get_param_values()),
env_update=tasks.sample(n_workers))
mean_rewards = []
goals = []
for eps in episodes.split():
mean_rewards.append(eps.rewards.mean())
goals.append(eps.env_infos['task'][0]['goal'])
assert np.var(mean_rewards) > 0
assert np.var(goals) > 0
sampler2.shutdown_worker()
env.close()


def test_init_without_worker_factory(ray_local_session_fixture):
del ray_local_session_fixture
assert ray.is_initialized()
Expand Down