diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 9245beb8f..c66bb74a4 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -48,6 +48,7 @@
 
 try:
     from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
+
 except Exception:
 
     def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
@@ -63,6 +64,7 @@ def _get_recat(
     stagger: int = 1,
     device: Optional[torch.device] = None,
     batch_size_per_rank: Optional[List[int]] = None,
+    use_tensor_compute: bool = False,
 ) -> Optional[torch.Tensor]:
     """
     Calculates relevant recat indices required to reorder AlltoAll collective.
@@ -88,6 +90,11 @@ def _get_recat(
 
         _recat(0, 4, 2)  # None
     """
+    if use_tensor_compute:
+        return _get_recat_tensor_compute(
+            local_split, num_splits, stagger, device, batch_size_per_rank
+        )
+
     with record_function("## all2all_data:recat_permute_gen ##"):
         if local_split == 0:
             return None
@@ -151,6 +158,77 @@ def _get_recat(
         return torch.tensor(recat, device=device, dtype=torch.int32)
 
 
+def _get_recat_tensor_compute(
+    local_split: int,
+    num_splits: int,
+    stagger: int = 1,
+    device: Optional[torch.device] = None,
+    batch_size_per_rank: Optional[List[int]] = None,
+) -> Optional[torch.Tensor]:
+    """
+    The list-based _get_recat produces many scalar-compute instructions in the graph.
+    This tensor-compute variant produces an identical result with a smaller op count.
+    """
+    with record_function("## all2all_data:recat_permute_gen ##"):
+        if local_split == 0:
+            return None
+
+        X: int = num_splits // stagger
+        Y: int = stagger
+        feature_order: torch.Tensor = (
+            torch.arange(X, dtype=torch.int32).view(X, 1).expand(X, Y)
+            + (X * torch.arange(Y, dtype=torch.int32)).expand(X, Y)
+        ).reshape(-1)
+
+        LS: int = local_split
+        FO_S0: int = feature_order.size(0)
+        recat: torch.Tensor = (
+            torch.arange(LS, dtype=torch.int32).view(LS, 1).expand(LS, FO_S0)
+            + (feature_order.expand(LS, FO_S0) * LS)
+        ).reshape(-1)
+
+        vb_condition = batch_size_per_rank is not None and any(
+            bs != batch_size_per_rank[0] for bs in batch_size_per_rank
+        )
+
+        if vb_condition:
+            batch_size_per_rank_tensor = torch._refs.tensor(
+                batch_size_per_rank, dtype=torch.int32
+            )
+            N: int = batch_size_per_rank_tensor.size(0)
+            batch_size_per_feature_tensor: torch.Tensor = (
+                batch_size_per_rank_tensor.view(N, 1).expand(N, LS).reshape(-1)
+            )
+
+            permuted_batch_size_per_feature_tensor: torch.Tensor = (
+                batch_size_per_feature_tensor.index_select(0, recat)
+            )
+
+            input_offset: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+                batch_size_per_feature_tensor
+            )
+            output_offset: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+                permuted_batch_size_per_feature_tensor
+            )
+
+            recat_tensor = torch.tensor(
+                recat,
+                device=device,
+                dtype=torch.int32,
+            )
+            input_offset_device = input_offset.to(device=device)
+            output_offset_device = output_offset.to(device=device)
+            recat = torch.ops.fbgemm.expand_into_jagged_permute(
+                recat_tensor,
+                input_offset_device,
+                output_offset_device,
+                output_offset[-1].item(),
+            )
+            return recat
+        else:
+            return torch.tensor(recat, device=device, dtype=torch.int32)
+
+
 class SplitsAllToAllAwaitable(Awaitable[List[List[int]]]):
     """
     Awaitable for splits AlltoAll.
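
The tensor-compute path above builds the same staggered permutation that the list-based loop produces, just with a constant number of tensor ops instead of one scalar instruction per element. For reference, a minimal standalone sketch (not part of the patch) of that index math for the constant-batch-size case, with arbitrary example values local_split=2, num_splits=4, stagger=2:

import torch

# Same math as the fixed-batch-size branch of _get_recat_tensor_compute,
# with small example values so the intermediate tensors are easy to inspect.
local_split, num_splits, stagger = 2, 4, 2
X, Y = num_splits // stagger, stagger

# Staggered feature order: features interleaved across the stagger groups.
feature_order = (
    torch.arange(X, dtype=torch.int32).view(X, 1).expand(X, Y)
    + (X * torch.arange(Y, dtype=torch.int32)).expand(X, Y)
).reshape(-1)
# -> tensor([0, 2, 1, 3], dtype=torch.int32)

# Final recat permutation over local_split * num_splits entries.
LS, FO = local_split, feature_order.size(0)
recat = (
    torch.arange(LS, dtype=torch.int32).view(LS, 1).expand(LS, FO)
    + feature_order.expand(LS, FO) * LS
).reshape(-1)
# -> tensor([0, 4, 2, 6, 1, 5, 3, 7], dtype=torch.int32)

Per the docstring, the list-based _get_recat is expected to return the same values for these arguments; the difference is only that the loop version emits per-scalar graph nodes when traced under dynamo, while this version stays at a handful of ops.
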
@@ -198,7 +276,14 @@ def _wait_impl(self) -> List[List[int]]:
         if not is_torchdynamo_compiling():
             self._splits_awaitable.wait()
-        return self._output_tensor.view(self.num_workers, -1).T.tolist()
+        ret = self._output_tensor.view(self.num_workers, -1).T.tolist()
+
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            for i in range(len(ret)):
+                for j in range(len(ret[i])):
+                    torch._check_is_size(ret[i][j])
+
+        return ret
 
 
 class KJTAllToAllTensorsAwaitable(Awaitable[KeyedJaggedTensor]):
@@ -258,6 +343,7 @@ def __init__(
             stagger=stagger,
             device=device,
             batch_size_per_rank=self._stride_per_rank,
+            use_tensor_compute=is_torchdynamo_compiling(),
         )
         if self._workers == 1:
             return
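
The torch._check_is_size calls added to _wait_impl follow the usual pattern for data-dependent values under torch.compile: when dynamo captures scalar outputs, .tolist() yields unbacked symbolic ints, and _check_is_size tells the compiler that each value is a valid (non-negative) size so later size-dependent ops can be traced without data-dependent guard failures. A minimal sketch of that pattern, assuming a recent PyTorch with capture_scalar_outputs enabled (the function name and values are illustrative, not from torchrec):

import torch
import torch._dynamo

# Without this, .item()/.tolist() inside a compiled region causes a graph break.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def buffer_from_split(splits: torch.Tensor) -> torch.Tensor:
    n = splits[0].item()      # an unbacked SymInt under dynamo
    torch._check_is_size(n)   # mark it as a valid tensor size
    return torch.zeros(n)     # size-dependent op is now traceable

print(buffer_from_split(torch.tensor([5])).shape)  # torch.Size([5])

The use_tensor_compute=is_torchdynamo_compiling() argument at the _get_recat call site means the tensor-compute path is only taken while compiling; eager execution keeps the existing list-based behavior.
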