use the appropriate wait op for native funcol (#1811)
Summary:
Pull Request resolved: #1811

Tensors produced by native funcol must be waited on with `_c10d_functional.wait_tensor`; the legacy `c10d_functional.wait_tensor` is a no-op for native funcol.
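For illustration only (not part of this commit), a minimal single-process sketch of the distinction, assuming a gloo backend is available and that the default process group exposes a registered `group_name`:

```python
import os

import torch
import torch.distributed as dist

# Single-process setup so the sketch is runnable without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.ones(4)
# Assumption: the default group exposes its registered name via group_name.
group_name = dist.group.WORLD.group_name
# Native funcol collectives are asynchronous and return an unwaited tensor.
y = torch.ops._c10d_functional.all_reduce(x, "sum", group_name)
# Only the matching native wait op synchronizes it; the legacy
# torch.ops.c10d_functional.wait_tensor would be a no-op here.
y = torch.ops._c10d_functional.wait_tensor(y)
print(y)  # tensor([1., 1., 1., 1.]) with world_size=1
dist.destroy_process_group()
```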

Reviewed By: IvanKobzarev, wanchaol

Differential Revision: D55116102

fbshipit-source-id: 8a3e7c8181d3bd7fb70811bd5a173df8f9ad2b60
yifuwang authored and facebook-github-bot committed Mar 21, 2024
1 parent d9fbac9 commit 04ca364
Showing 1 changed file with 21 additions and 1 deletion.
torchrec/distributed/comm_ops.py

@@ -2663,6 +2663,26 @@ def _wait_autograd(input: torch.Tensor) -> torch.Tensor:
     return _Wait.apply(input)
 
 
+class _Wait_native_funcol(torch.autograd.Function):
+    @staticmethod
+    # pyre-ignore
+    def forward(
+        ctx,  # pyre-ignore
+        input: torch.Tensor,
+    ) -> torch.Tensor:
+        with torch._C._AutoDispatchBelowAutograd():
+            ret = torch.ops._c10d_functional.wait_tensor(input)
+        return ret
+
+    @staticmethod
+    def backward(ctx, grad_output):  # pyre-ignore
+        return (grad_output,)
+
+
+def _wait_autograd_native_funcol(input: torch.Tensor) -> torch.Tensor:
+    return _Wait_native_funcol.apply(input)
+
+
 # pyre-ignore
 c10d_functional_autograd_ops = [
     ("all_to_all_single", _all_to_all_single_autograd),
@@ -2676,7 +2696,7 @@ def _wait_autograd(input: torch.Tensor) -> torch.Tensor:
     ("all_to_all_single", _all_to_all_single_autograd_native_funcol),
     ("reduce_scatter_tensor", _reduce_scatter_tensor_autograd_native_funcol),
     ("all_gather_into_tensor", _all_gather_into_tensor_autograd_native_funcol),
-    ("wait_tensor", _wait_autograd),
+    ("wait_tensor", _wait_autograd_native_funcol),
 ]
 
 
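Note that `_Wait_native_funcol.backward` returns `grad_output` unchanged: waiting on a collective is a value-wise identity, so gradients pass straight through. A minimal standalone sketch of the same autograd-function pattern (illustrative only; `_IdentityLikeWait` is a hypothetical name, with the collective wait stubbed out as an identity):

```python
import torch

class _IdentityLikeWait(torch.autograd.Function):
    """Hypothetical stand-in for _Wait_native_funcol; the real forward would
    call torch.ops._c10d_functional.wait_tensor(input) instead."""

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        # view_as is a value-wise identity (and avoids returning the input
        # tensor object itself from a custom autograd.Function).
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Waiting does not transform values, so the gradient flows through.
        return (grad_output,)

x = torch.randn(3, requires_grad=True)
_IdentityLikeWait.apply(x).sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))  # gradient passed unchanged
```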