
Commit ea9b83c

register a custom op to keep the PEA module unflattened when torch.export (#1900)

Summary:


Reference:
* D54009459

Differential Revision: D56282744
TroyGarden authored and facebook-github-bot committed Apr 19, 2024
1 parent f120e42 commit ea9b83c
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions torchrec/modules/embedding_modules.py
@@ -8,10 +8,12 @@
# pyre-strict

import abc
import threading
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.library import Library
from torchrec.modules.embedding_configs import (
    DataType,
    EmbeddingBagConfig,
@@ -20,6 +22,24 @@
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

lib = Library("custom", "FRAGMENT")


class OpRegistryState:
    """
    State of the operator registry.

    We can only register the op schema once, so if we're registering multiple
    times we need a lock and a check that the schemas are the same.
    """

    op_registry_lock = threading.Lock()
    # operator schema, keyed by op_name -> schema
    op_registry_schema: Dict[str, str] = {}


operator_registry_state = OpRegistryState()


@torch.fx.wrap
def reorder_inverse_indices(
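The hunk above only adds the registry scaffolding (the "custom" library fragment, the lock, and the schema map); the op registration itself sits further down in the file and is not shown in this diff. As a rough, self-contained sketch of how such a lock-protected registry is typically used to define a custom op schema exactly once, the helper name, op name, and schema below are hypothetical and not the commit's code:

    import threading
    from typing import Dict

    from torch.library import Library

    lib = Library("custom", "FRAGMENT")


    class OpRegistryState:
        """Mirror of the registry state introduced by the commit."""

        op_registry_lock = threading.Lock()
        # operator schema, keyed by op_name -> schema
        op_registry_schema: Dict[str, str] = {}


    operator_registry_state = OpRegistryState()


    def register_custom_op_once(op_name: str, schema: str) -> None:
        """Define `schema` on the library fragment only if this op is not registered yet."""
        with operator_registry_state.op_registry_lock:
            existing = operator_registry_state.op_registry_schema.get(op_name)
            if existing is not None:
                # Registering the same schema again is a no-op; a different
                # schema for the same name would be ambiguous, so reject it.
                assert existing == schema, f"conflicting schema for {op_name!r}"
                return
            lib.define(schema)
            operator_registry_state.op_registry_schema[op_name] = schema


    # Hypothetical usage: a placeholder op standing in for the PEA forward so
    # torch.export can keep the module as a single call instead of flattening it.
    register_custom_op_once(
        "pea_placeholder",
        "pea_placeholder(Tensor values, Tensor lengths) -> Tensor",
    )

The "FRAGMENT" kind lets this module add operators to the shared "custom" library namespace without taking ownership of it, which is why a process-wide lock and schema check are needed to tolerate repeated registration attempts.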
