avoid permuting inverse indices when not necessary (#1807)
Summary:
Pull Request resolved: #1807

If the inverse indices and embedding names are in the same order and contain no duplicates, we can avoid the index_select that permutes the inverse indices tensor.

Reviewed By: zainhuda

Differential Revision: D55039535

fbshipit-source-id: c00a193e5a8f34914cd67c627cb88d79e02cfdc9
joshuadeng authored and facebook-github-bot committed Mar 19, 2024
1 parent ce7b919 commit 6a415a7
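
To see why the permute can be skipped, here is a toy illustration (the tensors and shapes are made up, not torchrec's real layout): when features arrive in the same order as the embedding names and nothing is duplicated, the permutation is the identity, and index_select would only produce a redundant copy of its input.

import torch

# Hypothetical layout: one row of inverse indices per deduplicated feature.
inverse_indices = torch.tensor([[0, 0, 1], [1, 0, 1]])

# With the identity permutation, index_select returns an equal tensor,
# i.e. a copy we can avoid entirely by skipping the call.
identity = torch.tensor([0, 1])
assert torch.equal(
    torch.index_select(inverse_indices, 0, identity), inverse_indices
)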
Showing 1 changed file with 17 additions and 15 deletions.
torchrec/distributed/embeddingbag.py
@@ -311,7 +311,7 @@ def __init__(
         self,
         awaitables: List[Awaitable[torch.Tensor]],
         inverse_indices: Tuple[List[str], torch.Tensor],
-        inverse_indices_permute_indices: torch.Tensor,
+        inverse_indices_permute_indices: Optional[torch.Tensor],
         batch_size_per_feature_pre_a2a: List[int],
         uncombined_embedding_dims: List[int],
         embedding_names: List[str],
@@ -331,9 +331,11 @@ def __init__(
     def _wait_impl(self) -> KeyedTensor:
         embeddings = [w.wait() for w in self._awaitables]
         batch_size = self._inverse_indices[1].numel() // len(self._inverse_indices[0])
-        indices = torch.index_select(
-            self._inverse_indices[1], 0, self._inverse_indices_permute_indices
-        )
+        permute_indices = self._inverse_indices_permute_indices
+        if permute_indices is not None:
+            indices = torch.index_select(self._inverse_indices[1], 0, permute_indices)
+        else:
+            indices = self._inverse_indices[1]
         reindex_output = torch.ops.fbgemm.batch_index_select_dim0(
             inputs=embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings),
             indices=indices.view(-1),
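
The new branch in _wait_impl boils down to the following guard. A minimal standalone sketch (the function name is mine, not torchrec's): pay for index_select only when a reordering tensor was actually built, and treat None as the identity permutation.

from typing import Optional

import torch


def maybe_permute(
    inverse_indices: torch.Tensor,
    permute_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    # None signals the identity permutation: return the input unchanged
    # instead of materializing a same-order copy via index_select.
    if permute_indices is not None:
        return torch.index_select(inverse_indices, 0, permute_indices)
    return inverse_indices

Since _wait_impl presumably sits on the hot path of every forward pass while the embedding-name order is fixed after initialization, deciding this once up front (see the guard added below) saves a per-step copy.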
@@ -768,25 +770,25 @@ def _create_inverse_indices_permute_indices(
             index_per_name[name.split("@")[0]]
             for name in self._uncombined_embedding_names
         ]
-        self._inverse_indices_permute_indices = _pin_and_move(
-            torch.tensor(permute_indices),
-            inverse_indices[1].device,
-        )
+        if len(permute_indices) != len(index_per_name) or permute_indices != sorted(
+            permute_indices
+        ):
+            self._inverse_indices_permute_indices = _pin_and_move(
+                torch.tensor(permute_indices),
+                inverse_indices[1].device,
+            )
 
     # pyre-ignore [14]
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
+        ctx.variable_batch_per_feature = features.variable_stride_per_key()
+        ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
             self._create_input_dist(features.keys())
             self._has_uninitialized_input_dist = False
-        ctx.variable_batch_per_feature = features.variable_stride_per_key()
-        ctx.inverse_indices = features.inverse_indices_or_none()
-        if (
-            ctx.variable_batch_per_feature
-            and self._inverse_indices_permute_indices is None
-        ):
-            self._create_inverse_indices_permute_indices(ctx.inverse_indices)
+            if ctx.variable_batch_per_feature:
+                self._create_inverse_indices_permute_indices(ctx.inverse_indices)
         with torch.no_grad():
             if self._has_features_permute:
                 features = features.permute(
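
A simplified mirror of the new guard above (hypothetical inputs; _pin_and_move and pinned-memory handling omitted): the permute tensor is only materialized when the name-derived permutation is not the identity, so _inverse_indices_permute_indices stays None in the common aligned case and _wait_impl takes the cheap branch.

from typing import Dict, List, Optional

import torch


def build_permute_indices(
    uncombined_embedding_names: List[str],  # e.g. ["f0@shard_0", "f1@shard_0"]
    index_per_name: Dict[str, int],  # feature name -> row in inverse indices
    device: torch.device,
) -> Optional[torch.Tensor]:
    permute = [index_per_name[n.split("@")[0]] for n in uncombined_embedding_names]
    # Equal length to the unique names and already sorted means the
    # permutation is the identity: leave it as None so the wait path
    # can skip index_select altogether.
    if len(permute) == len(index_per_name) and permute == sorted(permute):
        return None
    return torch.tensor(permute, device=device)

The input_dist change complements this: the permute indices are now built once, when the input dist is first initialized, instead of being re-checked on every call.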
