Decouple DataParallel/DistributedDataParallel from CUDA (pytorch#38454)
Summary:
Decouple DataParallel/DistributedDataParallel from CUDA to support more device types.
- Move torch/cuda/comm.py to torch/nn/parallel/comm.py, with minor changes to support arbitrary device types. torch.cuda.comm is kept as-is for backward compatibility.
- Provide device-generic APIs for arbitrary device types without changing the existing CUDA APIs in the torch.cuda namespace.
- Replace the torch.cuda calls in DataParallel/DistributedDataParallel with the new APIs.

Related RFC: https://github.com/pytorch/pytorch/issues/36160
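For orientation, a minimal sketch of how the parallel wrappers are meant to consume the new device-generic helpers (the helper names come from the torch/_utils.py hunk below; the surrounding fragment is illustrative, not an exact call site from this patch):

```python
import torch
from torch._utils import _get_available_device_type, _get_all_device_indices

# Instead of hard-coding torch.cuda, the parallel wrappers can ask which
# device type is available and enumerate its devices generically.
device_type = _get_available_device_type()   # "cuda" on a CUDA build, else None
if device_type is not None:
    device_ids = _get_all_device_indices()   # e.g. [0, 1] on a two-GPU machine
    src_device = torch.device(device_type, device_ids[0])
    print(src_device)                        # cuda:0
```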

Pull Request resolved: pytorch#38454

Differential Revision: D22051557

Pulled By: mrshenli

fbshipit-source-id: 7842dad0e5d3ca0f6fb760bda49182dcf6653af8
chengjunlu authored and facebook-github-bot committed Jul 7, 2020
1 parent 75155df commit 8d570bc
Showing 10 changed files with 360 additions and 287 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
@@ -203,6 +203,9 @@ ignore_errors = True
 [mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
+[mypy-torch.nn.parallel.comm]
+ignore_errors = True
+
 [mypy-torch.nn.quantized.functional]
 ignore_errors = True
 
8 changes: 4 additions & 4 deletions test/distributed/test_c10d.py
@@ -2113,20 +2113,20 @@ def test_ddp_multi_device_module_config(self):
         gpus = gpus[:2]
         model = DoubleGpuNet(gpus)
 
-        with self.assertRaisesRegex(AssertionError, "output_device .* single-device CUDA"):
+        with self.assertRaisesRegex(AssertionError, "output_device .* single-device GPU"):
             ddp_model = DistributedDataParallel(
                 model, output_device=gpus[1], process_group=process_group)
 
-        with self.assertRaisesRegex(AssertionError, "device_ids .* single-device CUDA"):
+        with self.assertRaisesRegex(AssertionError, "device_ids .* single-device GPU"):
             ddp_model = DistributedDataParallel(
                 model, device_ids=gpus, process_group=process_group)
 
-        with self.assertRaisesRegex(AssertionError, "only works with CUDA devices"):
+        with self.assertRaisesRegex(AssertionError, "input module must be on the same type of devices"):
             model.fc1 = model.fc1.cpu()
             ddp_model = DistributedDataParallel(model, process_group=process_group)
 
         model = model.cpu()
-        with self.assertRaisesRegex(AssertionError, "device_ids .* single-device CUDA"):
+        with self.assertRaisesRegex(AssertionError, "device_ids .* single-device GPU"):
             ddp_model = DistributedDataParallel(
                 model, device_ids=gpus, process_group=process_group)
 
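The regex updates above track reworded assertions in DistributedDataParallel. A standalone sketch of the condition being exercised, under stated assumptions: a CUDA build with two GPUs, a single-process gloo group, and a hypothetical TwoGpuNet standing in for the test's DoubleGpuNet:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

class TwoGpuNet(nn.Module):
    # Hypothetical module spanning two devices, like the test's DoubleGpuNet.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8).to("cuda:0")
        self.fc2 = nn.Linear(8, 8).to("cuda:1")

    def forward(self, x):
        return self.fc2(self.fc1(x.to("cuda:0")).to("cuda:1"))

try:
    # A multi-device module must not pass device_ids; after this patch the
    # raised AssertionError matches "device_ids .* single-device GPU".
    DistributedDataParallel(TwoGpuNet(), device_ids=[0, 1])
except AssertionError as e:
    print(e)
```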
67 changes: 67 additions & 0 deletions torch/_utils.py
@@ -1,4 +1,6 @@
 import torch
+import torch._six
+from typing import Optional
 import warnings
 from collections import defaultdict
 import sys
@@ -416,3 +418,68 @@ def reraise(self):
             # (https://bugs.python.org/issue2651), so we work around it.
             msg = KeyErrorMessage(msg)
         raise self.exc_type(msg)
+
+
+def _get_available_device_type():
+    if torch.cuda.is_available():
+        return "cuda"
+    # add more available device types here
+    return None
+
+
+def _get_device_attr(get_member):
+    device_type = _get_available_device_type()
+    if device_type.lower() == "cuda":
+        return get_member(torch.cuda)
+    # add more available device types here
+    return None
+
+
+def _get_current_device_index():
+    # current device index
+    return _get_device_attr(lambda m: m.current_device())
+
+
+def _get_all_device_indices():
+    # all device indices
+    return _get_device_attr(lambda m: list(range(m.device_count())))
+
+
+def _get_devices_properties(device_ids):
+    # all device properties
+    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]
+
+
+def _get_device_index(device, optional=False, allow_cpu=False) -> int:
+    r"""Gets the device index from :attr:`device`, which can be a torch.device
+    object, a Python integer, or ``None``.
+    If :attr:`device` is a torch.device object, returns the device index if it
+    has an index. Note that for a device without a specified index,
+    i.e., ``torch.device('xxx')``, this will return the current default
+    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
+    CPU devices will be accepted and ``-1`` will be returned in this case.
+    If :attr:`device` is a Python integer, it is returned as is.
+    If :attr:`device` is ``None``, this will return the current default
+    device of the supported runtime platform if :attr:`optional` is ``True``,
+    i.e., the current default CUDA device will be returned if the CUDA runtime is supported.
+    """
+    if isinstance(device, str):
+        device = torch.device(device)
+    device_idx: Optional[int]
+    device_idx = None
+    if isinstance(device, torch.device):
+        if not allow_cpu and device.type == 'cpu':
+            raise ValueError('Expected a non cpu device, but got: {}'.format(device))
+        device_idx = -1 if device.type == 'cpu' else device.index
+    if isinstance(device, int):
+        device_idx = device
+    if device_idx is None:
+        if optional:
+            device_idx = _get_current_device_index()
+        else:
+            raise ValueError('Expected a torch.device with a specified index '
+                             'or an integer, but got:{}'.format(device))
+    return device_idx
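A quick usage sketch of the _get_device_index contract documented above (return values assume a CUDA build with at least two devices; this is a private helper, so the import path is an internal detail):

```python
import torch
from torch._utils import _get_device_index

print(_get_device_index(torch.device("cuda", 1)))   # 1: explicit index passes through
print(_get_device_index(1))                         # 1: ints are returned as-is
print(_get_device_index("cpu", allow_cpu=True))     # -1: CPU maps to -1 when allowed

# An unindexed device resolves to the current default device with optional=True,
print(_get_device_index(torch.device("cuda"), optional=True))  # torch.cuda.current_device()

# and raises ValueError without it.
try:
    _get_device_index(torch.device("cuda"))
except ValueError as e:
    print(e)
```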
20 changes: 4 additions & 16 deletions torch/cuda/_utils.py
@@ -1,6 +1,8 @@
 import torch
-from typing import Optional, Union
+from typing import Union
 from torch.types import Device
+# The _get_device_index has been moved to torch._utils._get_device_index
+from torch._utils import _get_device_index as _torch_get_device_index
 
 
 def _get_device_index(device: Union[Device, str, int], optional: bool = False,
@@ -21,27 +23,13 @@ def _get_device_index(device: Union[Device, str, int], optional: bool = False,
"""
if isinstance(device, str):
device = torch.device(device)
device_idx: Optional[int]
if isinstance(device, torch.device):
dev_type = device.type
if allow_cpu:
if device.type not in {'cuda', 'cpu'}:
raise ValueError('Expected a cuda or cpu device, but got: {}'.format(device))
elif device.type != 'cuda':
raise ValueError('Expected a cuda device, but got: {}'.format(device))
device_idx = -1 if device.type == 'cpu' else device.index
else:
if device is not None and not isinstance(device, torch._six.int_classes):
raise ValueError('Cannot recognize device {}'.format(device))
device_idx = device
if device_idx is None:
if optional:
# default cuda device index
return torch.cuda.current_device()
else:
raise ValueError('Expected a cuda device with a specified index '
'or an integer, but got: {}'.format(device))
return device_idx
return _torch_get_device_index(device, optional, allow_cpu)


def _dummy_type(name: str) -> type:
Expand Down
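The wrapper now simply delegates, keeping the torch.cuda._utils import path stable for existing callers. A short compatibility sketch, assuming a CUDA build with at least two devices:

```python
import torch
from torch.cuda._utils import _get_device_index  # old import path still works

# Delegates to torch._utils._get_device_index under the hood.
assert _get_device_index("cuda:1") == 1
assert _get_device_index(None, optional=True) == torch.cuda.current_device()
```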
236 changes: 4 additions & 232 deletions torch/cuda/comm.py
@@ -1,233 +1,5 @@
 import warnings
+# The functions here have been moved to torch.nn.parallel.comm
+from torch.nn.parallel.comm import broadcast, broadcast_coalesced, reduce_add, \
+    reduce_add_coalesced, scatter, gather
-
-import torch
-
-from . import nccl
-from torch._utils import _take_tensors, _flatten_dense_tensors, \
-    _unflatten_dense_tensors, _reorder_tensors_as
-
-
-def broadcast(tensor, devices=None, *, out=None):
-    r"""Broadcasts a tensor to specified CUDA devices.
-    Arguments:
-        tensor (Tensor): tensor to broadcast. Can be on CPU or CUDA.
-        devices (Iterable[torch.device, str or int], optional): an iterable of
-            CUDA devices, among which to broadcast.
-        out (Sequence[Tensor], optional, keyword-only): the CUDA tensors to
-            store output results.
-    .. note::
-        Exactly one of :attr:`devices` and :attr:`out` must be specified.
-    Returns:
-        - If :attr:`devices` is specified,
-          a tuple containing copies of :attr:`tensor`, placed on
-          :attr:`devices`.
-        - If :attr:`out` is specified,
-          a tuple containing :attr:`out` tensors, each containing a copy of
-          :attr:`tensor`.
-    """
-    if not ((devices is None) ^ (out is None)):
-        raise RuntimeError(
-            "Exactly one of 'devices' and 'out' must be specified, but got "
-            "devices={} and out={}".format(devices, out))
-    if devices is not None:
-        devices = [torch.cuda._utils._get_device_index(d) for d in devices]
-        return torch._C._broadcast(tensor, devices)
-    else:
-        return torch._C._broadcast_out(tensor, out)
-
-
-def broadcast_coalesced(tensors, devices, buffer_size=10485760):
-    r"""Broadcasts a sequence of tensors to the specified CUDA devices.
-    Small tensors are first coalesced into a buffer to reduce the number
-    of synchronizations.
-    Arguments:
-        tensors (sequence): tensors to broadcast. Must be on the same device,
-            either CPU or CUDA.
-        devices (Iterable[torch.device, str or int]): an iterable of CUDA
-            devices, among which to broadcast.
-        buffer_size (int): maximum size of the buffer used for coalescing
-    Returns:
-        A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`.
-    """
-    devices = [torch.cuda._utils._get_device_index(d) for d in devices]
-    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
-
-
-def reduce_add(inputs, destination=None):
-    """Sums tensors from multiple GPUs.
-    All inputs should have matching shapes, dtype, and layout. The output tensor
-    will be of the same shape, dtype, and layout.
-    Arguments:
-        inputs (Iterable[Tensor]): an iterable of tensors to add.
-        destination (int, optional): a device on which the output will be
-            placed (default: current device).
-    Returns:
-        A tensor containing an elementwise sum of all inputs, placed on the
-        :attr:`destination` device.
-    """
-    destination = torch.cuda._utils._get_device_index(destination, optional=True)
-    input_size = inputs[0].size()
-    root_index = None  # index of input tensor that already is on the correct device
-    for i, inp in enumerate(inputs):
-        assert inp.is_cuda, "reduce_add expects all inputs to be on GPUs"
-        if inp.get_device() == destination:
-            root_index = i
-        if inp.size() != input_size:
-            got = 'x'.join(str(x) for x in inp.size())
-            expected = 'x'.join(str(x) for x in input_size)
-            raise ValueError("input {} has invalid size: got {}, but expected "
-                             "{}".format(i, got, expected))
-    if root_index is None:
-        raise RuntimeError("reduce_add expects destination to be on the same GPU with one of the tensors")
-
-    if len(inputs) == 1:
-        return inputs[0]
-
-    if nccl.is_available(inputs):
-        result = torch.empty_like(inputs[root_index])
-        nccl.reduce(inputs, output=result, root=root_index)
-    else:
-        nonroot = [t for i, t in enumerate(inputs) if i != root_index]
-        result = inputs[root_index] + nonroot[0].cuda(destination, non_blocking=True)  # make a new tensor w/o clone
-        for other in nonroot[1:]:
-            result.add_(other.cuda(destination, non_blocking=True))
-    return result
-
-
-def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
-    """Sums tensors from multiple GPUs.
-    Small tensors are first coalesced into a buffer to reduce the number
-    of synchronizations.
-    Arguments:
-        inputs (Iterable[Iterable[Tensor]]): iterable of iterables that
-            contain tensors from a single device.
-        destination (int, optional): a device on which the output will be
-            placed (default: current device).
-        buffer_size (int): maximum size of the buffer used for coalescing
-    Returns:
-        A tuple of tensors containing an elementwise sum of each group of
-        inputs, placed on the ``destination`` device.
-    """
-    # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
-    #       return `inputs`.
-    dense_tensors = [[] for _ in inputs]  # shape (num_gpus, num_tensors)
-    output = []
-    ref_order = []
-    # process sparse ones first since they may have different sizes on different gpus
-    for tensor_at_gpus in zip(*inputs):
-        if all(t.is_sparse for t in tensor_at_gpus):
-            result = reduce_add(tensor_at_gpus, destination)  # this will be sparse too
-            output.append(result)
-            ref_order.append(tensor_at_gpus[0])
-        else:
-            for coll, t in zip(dense_tensors, tensor_at_gpus):
-                coll.append(t.to_dense() if t.is_sparse else t)
-            ref_order.append(dense_tensors[0][-1])
-    itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
-    # now the dense ones, which have consistent sizes
-    for chunks in zip(*itrs):
-        flat_tensors = [_flatten_dense_tensors(chunk) for chunk in chunks]  # (num_gpus,)
-        flat_result = reduce_add(flat_tensors, destination)
-        for t in _unflatten_dense_tensors(flat_result, chunks[0]):
-            # The unflattened tensors do not share storage, and we don't expose
-            # base flat tensor anyways, so give them different version counters.
-            # See NOTE [ Version Counter in comm.*_coalesced ]
-            output.append(t.data)
-    return tuple(_reorder_tensors_as(output, ref_order))
-
-
-def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None):
-    r"""Scatters a tensor across multiple CUDA devices.
-    Arguments:
-        tensor (Tensor): tensor to scatter. Can be on CPU or CUDA.
-        devices (Iterable[torch.device, str or int], optional): an iterable of
-            CUDA devices, among which to scatter.
-        chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on
-            each device. It should match :attr:`devices` in length and sum to
-            ``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided
-            into equal chunks.
-        dim (int, optional): A dimension along which to chunk :attr:`tensor`.
-            Default: ``0``.
-        out (Sequence[Tensor], optional, keyword-only): the CUDA tensors to
-            store output results. Sizes of these tensors must match that of
-            :attr:`tensor`, except for :attr:`dim`, where the total size must
-            sum to ``tensor.size(dim)``.
-    .. note::
-        Exactly one of :attr:`devices` and :attr:`out` must be specified. When
-        :attr:`out` is specified, :attr:`chunk_sizes` must not be specified and
-        will be inferred from sizes of :attr:`out`.
-    Returns:
-        - If :attr:`devices` is specified,
-          a tuple containing chunks of :attr:`tensor`, placed on
-          :attr:`devices`.
-        - If :attr:`out` is specified,
-          a tuple containing :attr:`out` tensors, each containing a chunk of
-          :attr:`tensor`.
-    """
-    if out is None:
-        devices = [torch.cuda._utils._get_device_index(d) for d in devices]
-        return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
-    else:
-        if devices is not None:
-            raise RuntimeError(
-                "'devices' must not be specified when 'out' is specified, but "
-                "got devices={}".format(devices))
-        if chunk_sizes is not None:
-            raise RuntimeError(
-                "'chunk_sizes' must not be specified when 'out' is specified, "
-                "but got chunk_sizes={}".format(chunk_sizes))
-        return tuple(torch._C._scatter_out(tensor, out, dim, streams))
-
-def gather(tensors, dim=0, destination=None, *, out=None):
-    r"""Gathers tensors from multiple CUDA devices.
-    Arguments:
-        tensors (Iterable[Tensor]): an iterable of tensors to gather.
-            Tensor sizes in all dimensions other than :attr:`dim` have to match.
-        dim (int, optional): a dimension along which the tensors will be
-            concatenated. Default: ``0``.
-        destination (torch.device, str, or int, optional): the output device.
-            Can be CPU or CUDA. Default: the current CUDA device.
-        out (Tensor, optional, keyword-only): the tensor to store gather result.
-            Its sizes must match those of :attr:`tensors`, except for :attr:`dim`,
-            where the size must equal ``sum(tensor.size(dim) for tensor in tensors)``.
-            Can be on CPU or CUDA.
-    .. note::
-        :attr:`destination` must not be specified when :attr:`out` is specified.
-    Returns:
-        - If :attr:`destination` is specified,
-          a tensor located on :attr:`destination` device, that is a result of
-          concatenating :attr:`tensors` along :attr:`dim`.
-        - If :attr:`out` is specified,
-          the :attr:`out` tensor, now containing results of concatenating
-          :attr:`tensors` along :attr:`dim`.
-    """
-    if out is None:
-        if destination == -1:
-            warnings.warn(
-                'Using -1 to represent CPU tensor is deprecated. Please use a '
-                'device object or string instead, e.g., "cpu".')
-        destination = torch.cuda._utils._get_device_index(destination, allow_cpu=True, optional=True)
-        return torch._C._gather(tensors, dim, destination)
-    else:
-        if destination is not None:
-            raise RuntimeError(
-                "'destination' must not be specified when 'out' is specified, but "
-                "got destination={}".format(destination))
-        return torch._C._gather_out(tensors, out, dim)
+__all__ = [broadcast, broadcast_coalesced, reduce_add, reduce_add_coalesced, scatter, gather]
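Because torch.cuda.comm re-exports the moved functions, existing call sites are unaffected. A small sanity sketch, assuming at least two visible CUDA devices:

```python
import torch
import torch.cuda.comm as cuda_comm
import torch.nn.parallel.comm as parallel_comm

# Old and new import paths resolve to the same function objects.
assert cuda_comm.broadcast is parallel_comm.broadcast

t = torch.arange(4.0)
copies = cuda_comm.broadcast(t, devices=[0, 1])  # tuple of per-device copies
print([str(c.device) for c in copies])           # ['cuda:0', 'cuda:1']
```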