recat tensor compute for dynamo (#1892)

Summary:
Pull Request resolved: #1892

1/ An alternative recat computation using tensor operations, reducing the number of graph operations that the list-based compute generates.
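
A rough standalone illustration of the idea (toy values; not the TorchRec code itself): the same permutation indices can be built with a few arange/expand/reshape ops instead of nested Python loops, so dynamo sees a handful of tensor ops instead of per-element scalar compute.

import torch

local_split, num_splits, stagger = 2, 4, 2  # toy values for illustration

# List-based construction: scalar compute, element by element.
feature_order = [
    x + num_splits // stagger * y
    for x in range(num_splits // stagger)
    for y in range(stagger)
]
recat_list = [i + j * local_split for i in range(local_split) for j in feature_order]

# Tensor-based construction: a fixed, small number of ops.
X, Y, LS = num_splits // stagger, stagger, local_split
feature_order_t = (
    torch.arange(X, dtype=torch.int32).view(X, 1).expand(X, Y)
    + (X * torch.arange(Y, dtype=torch.int32)).expand(X, Y)
).reshape(-1)
recat_t = (
    torch.arange(LS, dtype=torch.int32).view(LS, 1).expand(LS, feature_order_t.size(0))
    + feature_order_t.expand(LS, feature_order_t.size(0)) * LS
).reshape(-1)

assert recat_t.tolist() == recat_list  # both yield [0, 4, 2, 6, 1, 5, 3, 7]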

2/ torch._check_is_size compiler hints for dynamo on the results of splits (all are size-like, as they will be used as split sizes).
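
A minimal sketch of this hinting pattern (hypothetical helper name; the actual change lands in SplitsAllToAllAwaitable._wait_impl below): integers pulled out of a tensor are data-dependent under dynamo, and torch._check_is_size marks each one as non-negative and size-like so it can later be used as a split length without a data-dependent guard error.

import torch
from typing import List
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling

def splits_from_tensor(output_tensor: torch.Tensor, num_workers: int) -> List[List[int]]:
    # Hypothetical helper mirroring the _wait_impl change, for illustration only.
    ret = output_tensor.view(num_workers, -1).T.tolist()
    if not torch.jit.is_scripting() and is_torchdynamo_compiling():
        for row in ret:
            for value in row:
                torch._check_is_size(value)  # hint: value is size-like (>= 0)
    return ret

# e.g. splits_from_tensor(torch.tensor([1, 2, 3, 4]), num_workers=2) -> [[1, 3], [2, 4]]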

Reviewed By: joshuadeng

Differential Revision: D56269480

fbshipit-source-id: f75ed95a5743bc8ecda513be350808b6a99481db
Ivan Kobzarev authored and facebook-github-bot committed Apr 19, 2024
1 parent d0f7729 commit f034281
Showing 1 changed file with 87 additions and 1 deletion.
88 changes: 87 additions & 1 deletion torchrec/distributed/dist_data.py
@@ -48,6 +48,7 @@

try:
    from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling

except Exception:

    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
@@ -63,6 +64,7 @@ def _get_recat(
    stagger: int = 1,
    device: Optional[torch.device] = None,
    batch_size_per_rank: Optional[List[int]] = None,
    use_tensor_compute: bool = False,
) -> Optional[torch.Tensor]:
"""
Calculates relevant recat indices required to reorder AlltoAll collective.
@@ -88,6 +90,11 @@
        _recat(0, 4, 2)
        # None
    """
    if use_tensor_compute:
        return _get_recat_tensor_compute(
            local_split, num_splits, stagger, device, batch_size_per_rank
        )

    with record_function("## all2all_data:recat_permute_gen ##"):
        if local_split == 0:
            return None
@@ -151,6 +158,77 @@ def _get_recat(
        return torch.tensor(recat, device=device, dtype=torch.int32)


def _get_recat_tensor_compute(
    local_split: int,
    num_splits: int,
    stagger: int = 1,
    device: Optional[torch.device] = None,
    batch_size_per_rank: Optional[List[int]] = None,
) -> Optional[torch.Tensor]:
"""
_get_recat list based will produce many instructions in the graph with scalar compute.
This is tensor compute with identical result with smaller ops count.
"""
    with record_function("## all2all_data:recat_permute_gen ##"):
        if local_split == 0:
            return None

        X: int = num_splits // stagger
        Y: int = stagger
        feature_order: torch.Tensor = (
            torch.arange(X, dtype=torch.int32).view(X, 1).expand(X, Y)
            + (X * torch.arange(Y, dtype=torch.int32)).expand(X, Y)
        ).reshape(-1)

        LS: int = local_split
        FO_S0: int = feature_order.size(0)
        recat: torch.Tensor = (
            torch.arange(LS, dtype=torch.int32).view(LS, 1).expand(LS, FO_S0)
            + (feature_order.expand(LS, FO_S0) * LS)
        ).reshape(-1)

        vb_condition = batch_size_per_rank is not None and any(
            bs != batch_size_per_rank[0] for bs in batch_size_per_rank
        )

        if vb_condition:
            batch_size_per_rank_tensor = torch._refs.tensor(
                batch_size_per_rank, dtype=torch.int32
            )
            N: int = batch_size_per_rank_tensor.size(0)
            batch_size_per_feature_tensor: torch.Tensor = (
                batch_size_per_rank_tensor.view(N, 1).expand(N, LS).reshape(-1)
            )

            permuted_batch_size_per_feature_tensor: torch.Tensor = (
                batch_size_per_feature_tensor.index_select(0, recat)
            )

            input_offset: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
                batch_size_per_feature_tensor
            )
            output_offset: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
                permuted_batch_size_per_feature_tensor
            )

            recat_tensor = torch.tensor(
                recat,
                device=device,
                dtype=torch.int32,
            )
            input_offset_device = input_offset.to(device=device)
            output_offset_device = output_offset.to(device=device)
            recat = torch.ops.fbgemm.expand_into_jagged_permute(
                recat_tensor,
                input_offset_device,
                output_offset_device,
                output_offset[-1].item(),
            )
            return recat
        else:
            return torch.tensor(recat, device=device, dtype=torch.int32)


class SplitsAllToAllAwaitable(Awaitable[List[List[int]]]):
"""
Awaitable for splits AlltoAll.
@@ -198,7 +276,14 @@ def _wait_impl(self) -> List[List[int]]:
        if not is_torchdynamo_compiling():
            self._splits_awaitable.wait()

-        return self._output_tensor.view(self.num_workers, -1).T.tolist()
+        ret = self._output_tensor.view(self.num_workers, -1).T.tolist()
+
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            for i in range(len(ret)):
+                for j in range(len(ret[i])):
+                    torch._check_is_size(ret[i][j])
+
+        return ret


class KJTAllToAllTensorsAwaitable(Awaitable[KeyedJaggedTensor]):
@@ -258,6 +343,7 @@ def __init__(
            stagger=stagger,
            device=device,
            batch_size_per_rank=self._stride_per_rank,
            use_tensor_compute=is_torchdynamo_compiling(),
        )
        if self._workers == 1:
            return