Tests for torch.export sharded inference module (#1885)
Summary:
Pull Request resolved: #1885

The previous stack of diffs finally enables non-strict torch.export of TorchRec sharded inference modules (QEBC, QFPEBC). This diff adds torch.export tests to ensure compatibility; a minimal sketch of the flow the tests exercise is included below.

Unflattening the exported program (torch.export.unflatten) does not work yet and remains a work in progress.
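For reference, the non-strict export flow exercised by the new tests looks roughly like the following. This is a sketch, not part of the commit: it mirrors test_sharded_quant_ebc_non_strict_export in the diff below and assumes the _sharded_quant_ebc_model helper in torchrec/distributed/tests/test_pt2.py is importable from your checkout.

import torch

# Assumed import path; the helper is defined in the test file changed below.
from torchrec.distributed.tests.test_pt2 import _sharded_quant_ebc_model

# Build a CPU-sharded, quantized EBC model plus sample KJT inputs.
sharded_model, input_kjts = _sharded_quant_ebc_model(
    local_device="cpu", compute_device="cpu"
)
kjt = input_kjts[0].to("meta")

# Non-strict export of the sharded inference module.
ep = torch.export.export(
    sharded_model,
    (kjt.values(), kjt.lengths()),
    {},
    strict=False,
)

# The exported program can be re-run as a module.
ep.module()(kjt.values(), kjt.lengths())

# torch.export.unflatten(ep) is expected to fail for now (WIP).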

Reviewed By: IvanKobzarev

Differential Revision: D56207585

fbshipit-source-id: 1024b3300a38d87f94c3cd6c11e304e33ac56c5d
PaulZhang12 authored and facebook-github-bot committed Apr 17, 2024
1 parent 9f88932 commit b55fdb6
Showing 2 changed files with 166 additions and 22 deletions.
100 changes: 90 additions & 10 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -60,7 +60,9 @@
from torchrec.distributed.quant_embeddingbag import (
QuantEmbeddingBagCollection,
QuantEmbeddingBagCollectionSharder,
QuantFeatureProcessedEmbeddingBagCollectionSharder,
ShardedQuantEmbeddingBagCollection,
ShardedQuantFeatureProcessedEmbeddingBagCollection,
)
from torchrec.distributed.quant_state import WeightSpec
from torchrec.distributed.shard import _shard_modules
@@ -79,6 +81,7 @@
QuantConfig,
)
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.quant.embedding_modules import (
EmbeddingCollection as QuantEmbeddingCollection,
@@ -331,6 +334,53 @@ def quantize_fpebc(
)


class TestQuantFPEBCSharder(QuantFeatureProcessedEmbeddingBagCollectionSharder):
def __init__(
self,
sharding_type: str,
kernel_type: str,
fused_params: Optional[Dict[str, Any]] = None,
shardable_params: Optional[List[str]] = None,
) -> None:
super().__init__(fused_params=fused_params, shardable_params=shardable_params)
self._sharding_type = sharding_type
self._kernel_type = kernel_type

def sharding_types(self, compute_device_type: str) -> List[str]:
return [self._sharding_type]

def compute_kernels(
self, sharding_type: str, compute_device_type: str
) -> List[str]:
return [self._kernel_type]

def shard(
self,
module: QuantFeatureProcessedEmbeddingBagCollection,
params: Dict[str, ParameterSharding],
env: ShardingEnv,
device: Optional[torch.device] = None,
) -> ShardedQuantFeatureProcessedEmbeddingBagCollection:
fused_params = self.fused_params if self.fused_params else {}
fused_params["output_dtype"] = data_type_to_sparse_type(
dtype_to_data_type(module.output_dtype())
)
fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr(
module, MODULE_ATTR_REGISTER_TBES_BOOL, False
)
fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr(
module, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False
)
return ShardedQuantFeatureProcessedEmbeddingBagCollection(
module=module,
table_name_to_parameter_sharding=params,
env=env,
fused_params=fused_params,
device=device,
feature_processor=module.feature_processor,
)


class TestQuantEBCSharder(QuantEmbeddingBagCollectionSharder):
def __init__(
self,
@@ -576,8 +626,10 @@ def create_test_model_ebc_only_no_quantize(
num_features: int = 1,
num_float_features: int = 8,
num_weighted_features: int = 1,
compute_device: str = "cuda",
feature_processor: bool = False,
) -> TestModelInfo:
topology: Topology = Topology(world_size=world_size, compute_device="cuda")
topology: Topology = Topology(world_size=world_size, compute_device=compute_device)
mi = TestModelInfo(
dense_device=dense_device,
sparse_device=sparse_device,
@@ -620,12 +672,27 @@ def create_test_model_ebc_only_no_quantize(
for i in range(mi.num_weighted_features)
]

mi.model = torch.nn.Sequential(
EmbeddingBagCollection(
if feature_processor:
max_feature_lengths = {config.feature_names[0]: 100 for config in mi.tables}
fp = PositionWeightedModuleCollection(
max_feature_lengths=max_feature_lengths, device=mi.sparse_device
)
ebc = FeatureProcessedEmbeddingBagCollection(
embedding_bag_collection=EmbeddingBagCollection(
# pyre-ignore [6]
tables=mi.tables,
device=mi.sparse_device,
is_weighted=True,
),
feature_processors=fp,
)
else:
ebc = EmbeddingBagCollection(
tables=mi.tables,
device=mi.sparse_device,
)
)

mi.model = torch.nn.Sequential(ebc)
mi.model.training = False
return mi

@@ -641,6 +708,8 @@ def create_test_model_ebc_only(
num_float_features: int = 8,
num_weighted_features: int = 1,
quant_state_dict_split_scale_bias: bool = False,
compute_device: str = "cuda",
feature_processor: bool = False,
) -> TestModelInfo:
mi = create_test_model_ebc_only_no_quantize(
num_embeddings=num_embeddings,
@@ -652,13 +721,24 @@
num_features=num_features,
num_float_features=num_float_features,
num_weighted_features=num_weighted_features,
compute_device=compute_device,
feature_processor=feature_processor,
)
mi.quant_model = quantize(
module=mi.model,
inplace=False,
register_tbes=True,
quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias,
)

if feature_processor:
mi.quant_model = quantize_fpebc(
module=mi.model,
inplace=True,
register_tbes=True,
quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias,
)
else:
mi.quant_model = quantize(
module=mi.model,
inplace=False,
register_tbes=True,
quant_state_dict_split_scale_bias=quant_state_dict_split_scale_bias,
)
return mi


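The new test cases in test_pt2.py (next file) wire these helpers together. A condensed sketch of that wiring for the feature-processor path is shown here; it mirrors the new _sharded_quant_ebc_model helper, the numeric values come from that helper, and the import locations of KJTInputExportWrapper and prep_inputs are assumptions rather than taken from this diff.

import torch
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.types import ShardingEnv, ShardingType
# TestQuantFPEBCSharder and create_test_model_ebc_only are added to
# infer_utils.py in this diff; the other two helpers are assumed to live there.
from torchrec.distributed.test_utils.infer_utils import (
    KJTInputExportWrapper,
    TestQuantFPEBCSharder,
    create_test_model_ebc_only,
    prep_inputs,
)

local_device = torch.device("cpu")
world_size, batch_size = 2, 4

# Quantized FP-EBC (PositionWeightedModuleCollection feature processor), CPU compute.
mi = create_test_model_ebc_only(
    256,  # num_embeddings
    12,   # emb_dim
    world_size,
    batch_size,
    dense_device=local_device,
    sparse_device=local_device,
    quant_state_dict_split_scale_bias=True,
    compute_device="cpu",
    feature_processor=True,
)

sharder = TestQuantFPEBCSharder(
    sharding_type=ShardingType.TABLE_WISE.value,
    kernel_type=EmbeddingComputeKernel.QUANT.value,
    shardable_params=[table.name for table in mi.tables],
)
plan = mi.planner.plan(mi.quant_model, [sharder])

# Sharding happens on meta, matching the test helper.
sharded_model = _shard_modules(
    module=mi.quant_model,
    sharders=[sharder],
    device=torch.device("meta"),
    plan=plan,
    env=ShardingEnv.from_local(world_size=mi.topology.world_size, rank=0),
)

input_kjts = [
    inp.to(local_device).idlist_features
    for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
]
model = KJTInputExportWrapper(sharded_model, input_kjts[0].keys())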
88 changes: 76 additions & 12 deletions torchrec/distributed/tests/test_pt2.py
@@ -12,6 +12,7 @@
from typing import List, Tuple

import torch
from torchrec.distributed.test_utils.infer_utils import TestQuantFPEBCSharder

try:
# pyre-ignore
@@ -50,12 +51,17 @@ def make_kjt(values: List[int], lengths: List[int]) -> KeyedJaggedTensor:
return kjt


def _sharded_quant_ebc_model() -> Tuple[torch.nn.Module, List[KeyedJaggedTensor]]:
def _sharded_quant_ebc_model(
local_device: str = "cuda",
compute_device: str = "cuda",
feature_processor: bool = False,
) -> Tuple[torch.nn.Module, List[KeyedJaggedTensor]]:
num_embeddings = 256
emb_dim = 12
world_size = 2
batch_size = 4
local_device = torch.device("cuda:0")

local_device = torch.device(local_device)
mi = create_test_model_ebc_only(
num_embeddings,
emb_dim,
@@ -66,34 +72,47 @@ def _sharded_quant_ebc_model() -> Tuple[torch.nn.Module, List[KeyedJaggedTensor]
dense_device=local_device,
sparse_device=local_device,
quant_state_dict_split_scale_bias=True,
compute_device=compute_device,
feature_processor=feature_processor,
)
input_kjts = [
inp.to(local_device).idlist_features
for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
]
model: torch.nn.Module = KJTInputExportWrapper(mi.quant_model, input_kjts[0].keys())

sharding_type: ShardingType = ShardingType.TABLE_WISE
sharder = TestQuantEBCSharder(
sharding_type=sharding_type.value,
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in mi.tables],
)

if feature_processor:
sharder = TestQuantFPEBCSharder(
sharding_type=sharding_type.value,
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in mi.tables],
)
else:
sharder = TestQuantEBCSharder(
sharding_type=sharding_type.value,
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in mi.tables],
)
# pyre-ignore
plan = mi.planner.plan(
model,
mi.quant_model,
[sharder],
)

sharded_model = _shard_modules(
module=model,
module=mi.quant_model,
# pyre-ignore
sharders=[sharder],
device=local_device,
# Always shard on meta
device=torch.device("meta"),
plan=plan,
# pyre-ignore
env=ShardingEnv.from_local(world_size=mi.topology.world_size, rank=0),
)
return sharded_model, input_kjts

model: torch.nn.Module = KJTInputExportWrapper(sharded_model, input_kjts[0].keys())
return model, input_kjts


class TestPt2(unittest.TestCase):
@@ -269,6 +288,51 @@ def kjt_to_inputs(kjt):
]
assert_close(expected_outputs, aot_actual_outputs)

def test_sharded_quant_ebc_non_strict_export(self) -> None:
sharded_model, input_kjts = _sharded_quant_ebc_model(
local_device="cpu", compute_device="cpu"
)
kjt = input_kjts[0]
kjt = kjt.to("meta")
sharded_model(kjt.values(), kjt.lengths())

ep = torch.export.export(
sharded_model,
(
kjt.values(),
kjt.lengths(),
),
{},
strict=False,
)

ep.module()(kjt.values(), kjt.lengths())

# TODO: Fix Unflatten
# torch.export.unflatten(ep)

def test_sharded_quant_fpebc_non_strict_export(self) -> None:
sharded_model, input_kjts = _sharded_quant_ebc_model(
local_device="cpu", compute_device="cpu", feature_processor=True
)
kjt = input_kjts[0]
kjt = kjt.to("meta")
sharded_model(kjt.values(), kjt.lengths())

ep = torch.export.export(
sharded_model,
(
kjt.values(),
kjt.lengths(),
),
{},
strict=False,
)
ep.module()(kjt.values(), kjt.lengths())

# TODO: Fix Unflatten
# torch.export.unflatten(ep)

def test_maybe_compute_kjt_to_jt_dict(self) -> None:
kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
self._test_kjt_input_module(
