From 860d5740f4a0f3c39b09457e6e3f83c71d3589d6 Mon Sep 17 00:00:00 2001
From: Leon Gao
Date: Wed, 30 Mar 2022 10:30:41 -0700
Subject: [PATCH] gpu kernel for 1d sparse recat gen (#179)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/179

* add the `expand_into_jagged_permute` GPU kernel call site for generating the
  1D sparse data permute

Reviewed By: youyou6093

Differential Revision: D34778094

fbshipit-source-id: d14174cea809f3e33b1d860d297c7d318a930e34
---
 torchrec/distributed/dist_data.py | 123 +++++++++++++++++-------------
 1 file changed, 68 insertions(+), 55 deletions(-)

diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 4980a7d77..907f504db 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -41,8 +41,9 @@ def _get_recat(
     local_split: int,
     num_splits: int,
     stagger: int = 1,
+    device: Optional[torch.device] = None,
     batch_size_per_rank: Optional[List[int]] = None,
-) -> List[int]:
+) -> torch.Tensor:
     """
     Calculates relevant recat indices required to reorder AlltoAll collective.

@@ -63,42 +64,58 @@
         _recat(2, 4, 2) # [0, 4, 2, 6, 1, 5, 3, 7]

     """
+    with record_function("## all2all_data:recat_permute_gen ##"):
+        recat: List[int] = []

-    recat: List[int] = []
+        if local_split == 0:
+            return torch.tensor(recat, device=device, dtype=torch.int32)

-    feature_order: List[int] = [
-        x + num_splits // stagger * y
-        for x in range(num_splits // stagger)
-        for y in range(stagger)
-    ]
+        feature_order: List[int] = [
+            x + num_splits // stagger * y
+            for x in range(num_splits // stagger)
+            for y in range(stagger)
+        ]

-    for i in range(local_split):
-        for j in feature_order:  # range(num_splits):
-            recat.append(i + j * local_split)
+        for i in range(local_split):
+            for j in feature_order:  # range(num_splits):
+                recat.append(i + j * local_split)

-    # variable batch size
-    if batch_size_per_rank is not None:
-        batch_size_per_feature = list(
-            itertools.chain.from_iterable(
-                itertools.repeat(x, local_split) for x in batch_size_per_rank
-            )
-        )
-        batch_size_per_feature_cumsum = [0] + list(
-            itertools.accumulate(batch_size_per_feature)
-        )
-        recat_per_feature = recat
-        recat = []
-        for r in recat_per_feature:
-            recat.extend(
-                list(
-                    range(
-                        batch_size_per_feature_cumsum[r],
-                        batch_size_per_feature_cumsum[r + 1],
-                    )
+        # variable batch size
+        if batch_size_per_rank is not None:
+            batch_size_per_feature = list(
+                itertools.chain.from_iterable(
+                    itertools.repeat(x, local_split) for x in batch_size_per_rank
                 )
             )
-
-    return recat
+            permuted_batch_size_per_feature = [batch_size_per_feature[r] for r in recat]
+            input_offset = [0] + list(itertools.accumulate(batch_size_per_feature))
+            output_offset = [0] + list(
+                itertools.accumulate(permuted_batch_size_per_feature)
+            )
+            recat_tensor = torch.tensor(
+                recat,
+                device=device,
+                dtype=torch.int32,
+            )
+            input_offset_tensor = torch.tensor(
+                input_offset,
+                device=device,
+                dtype=torch.int32,
+            )
+            output_offset_tensor = torch.tensor(
+                output_offset,
+                device=device,
+                dtype=torch.int32,
+            )
+            recat = torch.ops.fbgemm.expand_into_jagged_permute(
+                recat_tensor,
+                input_offset_tensor,
+                output_offset_tensor,
+                output_offset[-1],
+            )
+            return recat
+        else:
+            return torch.tensor(recat, device=device, dtype=torch.int32)


 def _split_lengths(
@@ -321,15 +338,12 @@ def __init__(
             )
             self._batch_size_per_rank_tensor = batch_size_per_rank_tensor
             self._batch_size_per_rank = batch_size_per_rank_tensor.cpu().tolist()
-            self._recat = torch.tensor(
-                _get_recat(
-                    local_split=dim_0,
-                    num_splits=len(splits),
-                    stagger=stagger,
-                    batch_size_per_rank=self._batch_size_per_rank,
-                ),
+            self._recat = _get_recat(
+                local_split=dim_0,
+                num_splits=len(splits),
+                stagger=stagger,
                 device=self._device,
-                dtype=torch.int32,
+                batch_size_per_rank=self._batch_size_per_rank,
             )
         else:
             assert self._recat is not None
@@ -341,9 +355,10 @@ def __init__(
             dtype=in_lengths.dtype,
         )
         self._lengths = out_lengths
-        self._in_lengths_per_worker = _split_lengths(
-            splits, input.keys(), input.offset_per_key()
-        )
+        with record_function("## all2all_data:split length ##"):
+            self._in_lengths_per_worker = _split_lengths(
+                splits, input.keys(), input.offset_per_key()
+            )

         self._output_split_sizes: List[int] = [
             dim_0 * B_rank for B_rank in self._batch_size_per_rank
@@ -370,12 +385,13 @@ def _wait_impl(self) -> KJTAllToAllIndicesAwaitable:
         if self._workers > 1:
             self._lengths_awaitable.wait()
             if self._variable_batch_size:
-                lengths_per_rank: List[torch.Tensor] = list(
-                    self._lengths.split(self._output_split_sizes)
-                )
-                out_lengths_per_worker = [
-                    int(length.sum().item()) for length in lengths_per_rank
-                ]
+                with record_function("## all2all_data:split length for a2a ##"):
+                    lengths_per_rank: List[torch.Tensor] = list(
+                        self._lengths.split(self._output_split_sizes)
+                    )
+                    out_lengths_per_worker = [
+                        int(length.sum().item()) for length in lengths_per_rank
+                    ]
             else:
                 out_lengths_per_worker = (
                     self._lengths.view(self._workers, -1).sum(dim=1).cpu().tolist()
@@ -467,14 +483,11 @@ def __init__(
         self._variable_batch_size = variable_batch_size
         self.register_buffer(
             "_recat",
-            torch.tensor(
-                _get_recat(
-                    local_split=splits[pg.rank()],
-                    num_splits=len(splits),
-                    stagger=stagger,
-                ),
+            _get_recat(
+                local_split=splits[pg.rank()],
+                num_splits=len(splits),
+                stagger=stagger,
                 device=device,
-                dtype=torch.int,
             ),
         )
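
Reviewer note: below is a minimal pure-Python sketch of the contract this patch
assumes from the `torch.ops.fbgemm.expand_into_jagged_permute` kernel, useful
for checking the offset bookkeeping in `_get_recat` above. The helper name
`expand_into_jagged_permute_ref` and the concrete sizes are illustrative, not
part of torchrec or FBGEMM.

# Reference (pure-Python) sketch of the behavior assumed from
# torch.ops.fbgemm.expand_into_jagged_permute; illustrative only.

from itertools import accumulate
from typing import List


def expand_into_jagged_permute_ref(
    permute: List[int],
    input_offsets: List[int],
    output_offsets: List[int],
) -> List[int]:
    # Expands a feature-level permute into a value-level (1D) permute:
    # the output slot range [output_offsets[i], output_offsets[i + 1]) is
    # filled with the input positions of feature permute[i], i.e. the
    # contiguous range starting at input_offsets[permute[i]].
    expanded = [0] * output_offsets[-1]
    for i, p in enumerate(permute):
        out_start = output_offsets[i]
        in_start = input_offsets[p]
        for k in range(input_offsets[p + 1] - in_start):
            expanded[out_start + k] = in_start + k
    return expanded


# Worked example matching _get_recat(local_split=2, num_splits=2) with
# variable batch sizes [2, 3] per rank (illustrative numbers):
recat = [0, 2, 1, 3]                    # feature-level permute
lengths = [2, 2, 3, 3]                  # batch size of each input feature
permuted_lengths = [lengths[r] for r in recat]             # [2, 3, 2, 3]
input_offsets = [0] + list(accumulate(lengths))            # [0, 2, 4, 7, 10]
output_offsets = [0] + list(accumulate(permuted_lengths))  # [0, 2, 5, 7, 10]
print(expand_into_jagged_permute_ref(recat, input_offsets, output_offsets))
# -> [0, 1, 4, 5, 6, 2, 3, 7, 8, 9]

Because each output position is determined solely by the precomputed offset
tables, every element of the expanded permute can be written independently;
that is what makes the expansion a good fit for a GPU kernel and lets this
patch drop the old serial per-element Python loop.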
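
Reviewer note: the `record_function` blocks added throughout the diff label
these regions as named ranges in PyTorch profiler traces. A minimal,
self-contained sketch of that annotation pattern follows; the profiled op is a
stand-in, not torchrec code.

import torch
from torch.profiler import profile, record_function

with profile() as prof:
    # Named range, mirroring the labels used in this diff.
    with record_function("## all2all_data:recat_permute_gen ##"):
        recat = torch.arange(8).flip(0)  # stand-in for the real recat work

# The labeled range appears as its own row in the profiler summary, which
# makes it easy to attribute host-side permute-generation cost when tuning
# the AlltoAll path.
print(prof.key_averages().table(sort_by="cpu_time_total"))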