Use functional collectives explicitly (#1832)
Summary:
Pull Request resolved: #1832

Previously, dynamo tracing of dist_data relied on dynamo's functional collectives remapping.

This remapping adds a mutation of the tensor via `.copy_`.
The Inductor lowering of `.copy_` adds non-trivial data-dependent logic on strides and sizes, in particular a check for non-overlapping memory layout.

Potentially this could be solved by changing the dynamo remapping:
pytorch/pytorch#122788
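
For illustration, the remapping effectively rewrites the in-place collective into a functional collective plus a copy back into the caller's preallocated buffer (a rough sketch of the effect, not the actual dynamo source; names are illustrative):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def remapped_all_to_all(out_buf: torch.Tensor, inp: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    # The functional collective returns a fresh tensor ...
    tmp = funcol.all_to_all_single(inp, None, None, pg)
    # ... and the remapping copies it back into the preallocated output buffer.
    # This `.copy_` is what Inductor lowers with data-dependent stride/size and
    # memory-overlap checks.
    out_buf.copy_(tmp)
    return out_buf
```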

**Current Solution:**

Use functional collectives explicitly for the dynamo path, which avoids `copy_` in the graph.
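
A minimal sketch of this pattern (illustrative names; assumes `torch.distributed` is initialized and `pg` is a valid process group):

```python
from typing import List

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def splits_all_to_all_for_compile(
    input_tensors: List[torch.Tensor], pg: dist.ProcessGroup
) -> torch.Tensor:
    input_tensor = torch.stack(input_tensors, dim=1).flatten()
    # The functional collective allocates its own output, so the traced graph
    # contains no preallocated buffer and no `.copy_` mutation.
    output_tensor = funcol.all_to_all_single(
        input_tensor,
        output_split_sizes=None,  # None => equal splits across ranks
        input_split_sizes=None,
        group=pg,
    )
    # The returned tensor is synchronized implicitly when it is first used, so
    # there is no dist.Work handle to wait on.
    return output_tensor
```

The eager (non-compile) path keeps the preallocated-output `dist.all_to_all_single` call, as the `else` branches in the diff below show.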

Reviewed By: ezyang

Differential Revision: D55423735

fbshipit-source-id: c7d9a6b7f93fe97432836fd252742b972ba40ce1
Ivan Kobzarev authored and facebook-github-bot committed Mar 27, 2024
1 parent 26b6899 commit 8bde6f8
84 changes: 53 additions & 31 deletions torchrec/distributed/dist_data.py
@@ -168,23 +168,34 @@ def __init__(
         super().__init__()
         self.num_workers: int = pg.size()

-        with record_function("## all2all_data:kjt splits ##"):
-            self._output_tensor: torch.Tensor = torch.empty(
-                [self.num_workers * len(input_tensors)],
-                device=input_tensors[0].device,
-                dtype=input_tensors[0].dtype,
-            )
-            input_tensor = torch.stack(input_tensors, dim=1).flatten()
-            self._splits_awaitable: dist.Work = dist.all_to_all_single(
-                output=self._output_tensor,
-                input=input_tensor,
-                group=pg,
-                async_op=not is_torchdynamo_compiling(),
-            )
+        if is_torchdynamo_compiling():
+            # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_
+            # https://github.com/pytorch/pytorch/issues/122788
+            with record_function("## all2all_data:kjt splits ##"):
+                input_tensor = torch.stack(input_tensors, dim=1).flatten()
+                self._output_tensor = dist._functional_collectives.all_to_all_single(
+                    input_tensor,
+                    output_split_sizes=None,
+                    input_split_sizes=None,
+                    group=pg,
+                )
+        else:
+            with record_function("## all2all_data:kjt splits ##"):
+                self._output_tensor: torch.Tensor = torch.empty(
+                    [self.num_workers * len(input_tensors)],
+                    device=input_tensors[0].device,
+                    dtype=input_tensors[0].dtype,
+                )
+                input_tensor = torch.stack(input_tensors, dim=1).flatten()
+                self._splits_awaitable: dist.Work = dist.all_to_all_single(
+                    output=self._output_tensor,
+                    input=input_tensor,
+                    group=pg,
+                    async_op=not is_torchdynamo_compiling(),
+                )

     def _wait_impl(self) -> List[List[int]]:
-        # handling sync torch dynamo trace case, where awaitable will be a Tensor
-        if isinstance(self._splits_awaitable, dist.Work):
+        if not is_torchdynamo_compiling():
             self._splits_awaitable.wait()

         return self._output_tensor.view(self.num_workers, -1).T.tolist()
@@ -261,21 +272,33 @@ def __init__(
             input_tensors,
             labels,
         ):
-            output_tensor = torch.empty(
-                sum(output_split), device=self._device, dtype=input_tensor.dtype
-            )
-            with record_function(f"## all2all_data:kjt {label} ##"):
-                awaitable = dist.all_to_all_single(
-                    output=output_tensor,
-                    input=input_tensor,
-                    output_split_sizes=output_split,
-                    input_split_sizes=input_split,
-                    group=self._pg,
-                    async_op=not is_torchdynamo_compiling(),
-                )
+            if is_torchdynamo_compiling():
+                # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_
+                # https://github.com/pytorch/pytorch/issues/122788
+                with record_function(f"## all2all_data:kjt {label} ##"):
+                    output_tensor = dist._functional_collectives.all_to_all_single(
+                        input_tensor,
+                        output_split,
+                        input_split,
+                        pg,
+                    )
+                    self._output_tensors.append(output_tensor)
+            else:
+                output_tensor = torch.empty(
+                    sum(output_split), device=self._device, dtype=input_tensor.dtype
+                )
+                with record_function(f"## all2all_data:kjt {label} ##"):
+                    awaitable = dist.all_to_all_single(
+                        output=output_tensor,
+                        input=input_tensor,
+                        output_split_sizes=output_split,
+                        input_split_sizes=input_split,
+                        group=self._pg,
+                        async_op=not is_torchdynamo_compiling(),
+                    )

-            self._output_tensors.append(output_tensor)
-            self._awaitables.append(awaitable)
+                self._output_tensors.append(output_tensor)
+                self._awaitables.append(awaitable)

     def _wait_impl(self) -> KeyedJaggedTensor:
         """
@@ -289,9 +312,8 @@ def _wait_impl(self) -> KeyedJaggedTensor:
             self._input.sync()
             return self._input

-        for awaitable in self._awaitables:
-            # handling sync torch dynamo trace case where awaitable will be a Tensor
-            if isinstance(awaitable, dist.Work):
+        if not is_torchdynamo_compiling():
+            for awaitable in self._awaitables:
                 awaitable.wait()

         return type(self._input).dist_init(