Skip to content

Commit

Permalink
[DTensor][3/N] add DTensor constructor function: full (pytorch#101436)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#101436
Approved by: https://github.com/wanchaol
  • Loading branch information
XilunWu authored and pytorchmergebot committed May 23, 2023
1 parent 5c3cf76 commit 2ca75d4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
10 changes: 10 additions & 0 deletions test/distributed/_tensor/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def test_empty(self):
requires_grad=True,
)

@with_comms
def test_full(self):
    """Check that ``torch.distributed._tensor.full`` agrees with ``torch.full``.

    Delegates to the shared ``_run_init_op`` harness, comparing the DTensor
    constructor against its eager counterpart for fill value 123.4 with
    ``requires_grad=True``.
    """
    eager_op = torch.full
    dtensor_op = torch.distributed._tensor.full
    self._run_init_op(
        eager_op,
        dtensor_op,
        self.assertEqual,
        123.4,
        requires_grad=True,
    )

@with_comms
def test_zeros(self):
self._run_init_op(
Expand Down
58 changes: 54 additions & 4 deletions torch/distributed/_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def _dtensor_init_helper(
# initialize the local tensor
if len(local_shape) == 0:
local_tensor = torch.empty(0, **kwargs)
elif init_op == torch.full:
fill_value = kwargs.pop("fill_value", 0)
local_tensor = init_op(local_shape, fill_value, **kwargs)
else:
local_tensor = init_op(local_shape, **kwargs)

Expand Down Expand Up @@ -85,8 +88,8 @@ def ones(
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Expand Down Expand Up @@ -128,8 +131,8 @@ def empty(
Args:
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
Can be a variable number of arguments or a collection like a list or tuple.
E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
Expand Down Expand Up @@ -157,6 +160,53 @@ def empty(
)


def full(
    size,
    fill_value,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match
    ``device_mesh.device_type``.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: full(1,2,3..) or full([1,2,3..]) or full((1,2,3..))
        fill_value (Scalar): the value to fill the output tensor with.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``, ``_Partial``.

    Returns:
        A :class:`DTensor` object on each rank
    """
    # Accept both full(1, 2, 3) and full((1, 2, 3)) / full([1, 2, 3]).
    torch_size = _normalize_to_torch_size(size)

    # The helper pops ``fill_value`` from kwargs and forwards it positionally
    # to torch.full when constructing the local shard.
    return _dtensor_init_helper(
        torch.full,
        torch_size,
        fill_value=fill_value,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )


def zeros(
*size,
requires_grad: bool = False,
Expand Down

0 comments on commit 2ca75d4

Please sign in to comment.