migrate to FBGEMM version of Permute Pooled Embs module (#1819)
Summary:
Pull Request resolved: #1819

migrate to the fbgemm directory version of the module, as it is already in the package
removes the logging that was added for traceability purposes

Reviewed By: zainhuda

Differential Revision: D55154971

fbshipit-source-id: 5623c22ec60d30ac58f13994d5cb3f89424d6a1b
joshuadeng authored and facebook-github-bot committed Mar 21, 2024
1 parent 8a1b30f commit 3db7477
Showing 2 changed files with 1 addition and 60 deletions.
2 changes: 1 addition & 1 deletion torchrec/distributed/embeddingbag.py
@@ -25,6 +25,7 @@
)

import torch
+from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from torch import nn, Tensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
@@ -66,7 +67,6 @@
    convert_to_fbgemm_types,
    merge_fused_params,
    optimizer_type_to_emb_opt_type,
-    PermutePooledEmbeddings,
)
from torchrec.modules.embedding_configs import (
    EmbeddingBagConfig,
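The embeddingbag.py hunk above swaps the in-tree helper for the fbgemm_gpu import. A minimal usage sketch, assuming the fbgemm_gpu class keeps the same constructor and call signature as the removed in-tree copy shown in the utils.py diff below; the segment widths, permutation, and batch size are made up for illustration:

import torch

# New import path introduced by this diff; requires the fbgemm_gpu package.
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings

# Three pooled embedding segments of widths 2, 3 and 4, reordered so the
# output column layout becomes [segment 2, segment 0, segment 1].
permute_pooled_embs = PermutePooledEmbeddings(
    [2, 3, 4],            # embs_dims: width of each pooled segment
    [2, 0, 1],            # permute: new segment order
    torch.device("cpu"),  # device for the precomputed offset/permute tensors
)

pooled = torch.randn(8, 9)               # (batch_size, sum of segment widths)
permuted = permute_pooled_embs(pooled)   # same shape, segments reordered along dim 1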
59 changes: 0 additions & 59 deletions torchrec/distributed/utils.py
@@ -11,7 +11,6 @@
import logging

from collections import OrderedDict
-from itertools import accumulate
from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union

import torch
@@ -31,22 +30,6 @@
logger: logging.Logger = logging.getLogger(__name__)
_T = TypeVar("_T")

-try:
-    torch.ops.load_library(
-        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
-    )
-    torch.ops.load_library(
-        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
-    )
-except OSError:
-    pass
-
-# OSS
-try:
-    pass
-except ImportError:
-    pass

"""
torch.package safe functions from pyre_extensions. However, pyre_extensions is
not safe to use in code that will be torch.packaged, as it requires sys for
@@ -456,45 +439,3 @@ def maybe_reset_parameters(m: nn.Module) -> None:
m.reset_parameters()

module.apply(maybe_reset_parameters)


-# TODO remove and use FBGEMM version once changes are in the package
-class PermutePooledEmbeddings:
-    def __init__(
-        self,
-        embs_dims: List[int],
-        permute: List[int],
-        device: Optional[torch.device] = None,
-    ) -> None:
-        logging.info("Using Permute Pooled Embeddings")
-        self._offset_dim_list: torch.Tensor = torch.tensor(
-            [0] + list(accumulate(embs_dims)), device=device, dtype=torch.int64
-        )
-
-        self._permute: torch.Tensor = torch.tensor(
-            permute, device=device, dtype=torch.int64
-        )
-
-        inv_permute: List[int] = [0] * len(permute)
-        for i, p in enumerate(permute):
-            inv_permute[p] = i
-
-        self._inv_permute: torch.Tensor = torch.tensor(
-            inv_permute, device=device, dtype=torch.int64
-        )
-
-        inv_embs_dims = [embs_dims[i] for i in permute]
-
-        self._inv_offset_dim_list: torch.Tensor = torch.tensor(
-            [0] + list(accumulate(inv_embs_dims)), device=device, dtype=torch.int64
-        )
-
-    def __call__(self, pooled_embs: torch.Tensor) -> torch.Tensor:
-        result = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
-            pooled_embs,
-            self._offset_dim_list.to(device=pooled_embs.device),
-            self._permute.to(device=pooled_embs.device),
-            self._inv_offset_dim_list.to(device=pooled_embs.device),
-            self._inv_permute.to(device=pooled_embs.device),
-        )
-        return result
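For reference, a pure-PyTorch sketch of the bookkeeping the removed class above performs (and that the fused fbgemm permute_pooled_embs_auto_grad op implements): the forward offsets, the inverse permutation, and the inverse offsets. The widths and permutation are made up for illustration, and the torch.cat of slices stands in for the fused op.

from itertools import accumulate

import torch

embs_dims = [2, 3, 4]   # widths of the pooled segments
permute = [2, 0, 1]     # requested segment order

offsets = [0] + list(accumulate(embs_dims))  # [0, 2, 5, 9]

# inv_permute[p] = i, exactly as in the removed __init__ above.
inv_permute = [0] * len(permute)
for i, p in enumerate(permute):
    inv_permute[p] = i                       # [1, 2, 0]

inv_embs_dims = [embs_dims[p] for p in permute]       # [4, 2, 3]
inv_offsets = [0] + list(accumulate(inv_embs_dims))   # [0, 4, 6, 9]

pooled = torch.arange(2 * sum(embs_dims), dtype=torch.float32).reshape(2, -1)

# Forward permutation: slice out each segment and concatenate in the new order.
permuted = torch.cat(
    [pooled[:, offsets[p] : offsets[p + 1]] for p in permute], dim=1
)

# The inverse permutation with the inverse offsets restores the original layout.
restored = torch.cat(
    [permuted[:, inv_offsets[q] : inv_offsets[q + 1]] for q in inv_permute], dim=1
)
assert torch.equal(restored, pooled)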
