Skip to content

Commit

Permalink
Support GPUs in RaySampler
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner committed Apr 16, 2022
1 parent c56513f commit 9ceabe8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 21 deletions.
16 changes: 12 additions & 4 deletions src/garage/sampler/ray_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class RaySampler(Sampler):
the Worker interface.
worker_args (dict or None): Additional arguments that should be passed
to the worker.
n_gpus (int or float): Number of GPUs to to use in total for sampling.
If `n_workers` is not a power of two, this may need to be set
slightly below the true value, since `n_workers / n_gpus` gpus are
allocated to each worker.
"""

Expand All @@ -56,8 +60,8 @@ def __init__(
seed=get_seed(),
n_workers=psutil.cpu_count(logical=False),
worker_class=DefaultWorker,
worker_args=None):
# pylint: disable=super-init-not-called
worker_args=None,
n_gpus=0):
if not ray.is_initialized():
ray.init(log_to_driver=False, ignore_reinit_error=True)
if worker_factory is None and max_episode_length is None:
Expand All @@ -73,7 +77,8 @@ def __init__(
n_workers=n_workers,
worker_class=worker_class,
worker_args=worker_args)
self._sampler_worker = ray.remote(SamplerWorker)
remote_wrapper = ray.remote(num_gpus=n_gpus / n_workers)
self._sampler_worker = remote_wrapper(SamplerWorker)
self._agents = agents
self._envs = self._worker_factory.prepare_worker_messages(envs)
self._all_workers = defaultdict(None)
Expand Down Expand Up @@ -103,7 +108,10 @@ def from_worker_factory(cls, worker_factory, agents, envs):
Sampler: An instance of `cls`.
"""
return cls(agents, envs, worker_factory=worker_factory)
return cls(agents,
envs,
worker_factory=worker_factory,
n_workers=worker_factory.n_workers)

def start_worker(self):
"""Initialize a new ray worker."""
Expand Down
18 changes: 1 addition & 17 deletions src/garage/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,9 @@ class Sampler(abc.ABC):
`Sampler` needs. Specifically, it specifies how to construct `Worker`s,
which know how to collect episodes and update both agents and environments.
Currently, `__init__` is also part of the interface, but calling it is
deprecated. `start_worker` is also deprecated, and does not need to be
implemented.
`start_worker` is deprecated, and does not need to be implemented.
"""

def __init__(self, algo, env):
"""Construct a Sampler from an Algorithm.
Args:
algo (RLAlgorithm): The RL Algorithm controlling this
sampler.
env (Environment): The environment being sampled from.
Calling this method is deprecated.
"""
self.algo = algo
self.env = env

@classmethod
def from_worker_factory(cls, worker_factory, agents, envs):
"""Construct this sampler.
Expand Down

0 comments on commit 9ceabe8

Please sign in to comment.