diff --git a/torchrec/distributed/fused_params.py b/torchrec/distributed/fused_params.py
index 171f94cb2..26af33938 100644
--- a/torchrec/distributed/fused_params.py
+++ b/torchrec/distributed/fused_params.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 import torch
 
@@ -24,6 +24,10 @@
 FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
 FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode"
 
+# Force a lengths-to-offsets conversion before the TBE lookup. This can improve
+# performance for certain ways of splitting a model.
+FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup"
+
 
 class TBEToRegisterMixIn:
     def get_tbes_to_register(
@@ -68,6 +72,18 @@ def fused_param_bounds_check_mode(
         return fused_params[FUSED_PARAM_BOUNDS_CHECK_MODE]
 
 
+def fused_param_lengths_to_offsets_lookup(
+    fused_params: Optional[Dict[str, Any]]
+) -> bool:
+    if (
+        fused_params is None
+        or FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP not in fused_params
+    ):
+        return False
+    else:
+        return fused_params[FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP]
+
+
 def is_fused_param_quant_state_dict_split_scale_bias(
     fused_params: Optional[Dict[str, Any]]
 ) -> bool:
@@ -93,5 +109,7 @@ def tbe_fused_params(
         fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
     if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe:
         fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
+    if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe:
+        fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP)
     return fused_params_for_tbe
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 9b230103e..68f799652 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -33,6 +33,7 @@
 )
 from torchrec.distributed.fused_params import (
     fused_param_bounds_check_mode,
+    fused_param_lengths_to_offsets_lookup,
     is_fused_param_quant_state_dict_split_scale_bias,
     is_fused_param_register_tbe,
     tbe_fused_params,
@@ -171,6 +172,19 @@ def _unwrap_kjt_for_cpu(
     return indices, offsets, None
 
 
+@torch.fx.wrap
+def _unwrap_kjt_lengths(
+    features: KeyedJaggedTensor,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    indices = features.values()
+    lengths = features.lengths()
+    return (
+        indices.int(),
+        lengths.int(),
+        features.weights_or_none(),
+    )
+
+
 @torch.fx.wrap
 def _unwrap_optional_tensor(
     tensor: Optional[torch.Tensor],
@@ -180,6 +194,26 @@ def _unwrap_optional_tensor(
     return tensor
 
 
+class IntNBitTableBatchedEmbeddingBagsCodegenWithLength(
+    IntNBitTableBatchedEmbeddingBagsCodegen
+):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+    # pyre-ignore Inconsistent override [14]
+    def forward(
+        self,
+        indices: torch.Tensor,
+        lengths: torch.Tensor,
+        per_sample_weights: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self._forward_impl(
+            indices=indices,
+            offsets=(torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)),
+            per_sample_weights=per_sample_weights,
+        )
+
+
 class QuantBatchedEmbeddingBag(
     BaseBatchedEmbeddingBag[
         Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
@@ -237,22 +271,27 @@ def __init__(
                 )
             )
 
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
-            IntNBitTableBatchedEmbeddingBagsCodegen(
-                embedding_specs=embedding_specs,
-                device=device,
-                pooling_mode=self._pooling,
-                feature_table_map=self._feature_table_map,
-                row_alignment=self._tbe_row_alignment,
-                uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-                bounds_check_mode=(
-                    bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
-                ),
-                feature_names_per_table=[
-                    table.feature_names for table in config.embedding_tables
-                ],
-                **(tbe_fused_params(fused_params) or {}),
-            )
+        self.lengths_to_tbe: bool = fused_param_lengths_to_offsets_lookup(fused_params)
+
+        if self.lengths_to_tbe:
+            tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
+        else:
+            tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen
+
+        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
+            embedding_specs=embedding_specs,
+            device=device,
+            pooling_mode=self._pooling,
+            feature_table_map=self._feature_table_map,
+            row_alignment=self._tbe_row_alignment,
+            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            bounds_check_mode=(
+                bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
+            ),
+            feature_names_per_table=[
+                table.feature_names for table in config.embedding_tables
+            ],
+            **(tbe_fused_params(fused_params) or {}),
         )
         if device is not None:
             self._emb_module.initialize_weights()
@@ -271,44 +310,50 @@ def get_tbes_to_register(
     ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
         return {self._emb_module: self._config}
 
-    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
-        # Important: _unwrap_kjt regex for FX tracing TAGing
-        if self._runtime_device.type == "cpu":
-            indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
-                features, self._config.is_weighted
-            )
+    def _emb_module_forward(
+        self,
+        indices: torch.Tensor,
+        lengths_or_offsets: torch.Tensor,
+        weights: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        kwargs = {"indices": indices}
+
+        if self.lengths_to_tbe:
+            kwargs["lengths"] = lengths_or_offsets
         else:
-            indices, offsets, per_sample_weights = _unwrap_kjt(features)
+            kwargs["offsets"] = lengths_or_offsets
 
         if self._is_weighted:
-            weights = _unwrap_optional_tensor(per_sample_weights)
-            if self._emb_module_registered:
-                # Conditional call of .forward function for FX:
-                # emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
-                # emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
-                # For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged.
-                return self.emb_module(
-                    indices=indices,
-                    offsets=offsets,
-                    per_sample_weights=weights,
-                )
+            kwargs["per_sample_weights"] = _unwrap_optional_tensor(weights)
+
+        if self._emb_module_registered:
+            # Conditional call of .forward function for FX:
+            # emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
+            # emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
+            # For some post processing that requires the TBE emb_module copied in fx.GraphModule we need to be call_module, as it copies this module inside fx.GraphModule unchanged.
+            return self._emb_module(**kwargs)
+        else:
+            return self._emb_module.forward(**kwargs)
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        # Important: _unwrap_kjt regex for FX tracing TAGing
+        lengths, offsets = None, None
+        if self._runtime_device.type == "cpu":
+            if self.lengths_to_tbe:
+                indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
             else:
-                return self.emb_module.forward(
-                    indices=indices,
-                    offsets=offsets,
-                    per_sample_weights=weights,
+                indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
+                    features, self._config.is_weighted
                 )
         else:
-            if self._emb_module_registered:
-                return self.emb_module(
-                    indices=indices,
-                    offsets=offsets,
-                )
+            if self.lengths_to_tbe:
+                indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
             else:
-                return self.emb_module.forward(
-                    indices=indices,
-                    offsets=offsets,
-                )
+                indices, offsets, per_sample_weights = _unwrap_kjt(features)
+
+        return self._emb_module_forward(
+            indices, lengths if lengths is not None else offsets, per_sample_weights
+        )
 
     def named_buffers(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py
index f2d8903dc..4d136f488 100644
--- a/torchrec/inference/modules.py
+++ b/torchrec/inference/modules.py
@@ -26,6 +26,7 @@
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.fused_params import (
     FUSED_PARAM_BOUNDS_CHECK_MODE,
+    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
     FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
     FUSED_PARAM_REGISTER_TBE_BOOL,
 )
@@ -82,6 +83,7 @@ def trim_torch_package_prefix_from_typename(typename: str) -> str:
     FUSED_PARAM_REGISTER_TBE_BOOL: True,
     FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
     FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
+    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: False,
 }
 
 DEFAULT_SHARDERS: List[ModuleSharder[torch.nn.Module]] = [
diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py
index d0ad0469a..8dbec9145 100644
--- a/torchrec/inference/tests/test_inference.py
+++ b/torchrec/inference/tests/test_inference.py
@@ -10,17 +10,22 @@
 import unittest
 from argparse import Namespace
+from typing import Any, cast, Dict, List
 
 import torch
 from fbgemm_gpu.split_embedding_configs import SparseType
+from torch.fx import symbolic_trace
 from torchrec import PoolingType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
+from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
 from torchrec.distributed.global_settings import set_propogate_device
+from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestOverArchRegroupModule,
     TestSparseNN,
 )
+from torchrec.distributed.types import ModuleSharder
 
 from torchrec.inference.dlrm_predict import (
     create_training_batch,
@@ -300,6 +305,65 @@ def test_sharded_quantized_tbe_count(self) -> None:
                 expected_num_embeddings[spec[0]],
             )
 
+    def test_sharded_quantized_lengths_to_tbe(self) -> None:
+        set_propogate_device(True)
+
+        fused_params: Dict[str, Any] = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}
+        sharders: List[ModuleSharder[torch.nn.Module]] = [
+            cast(
+                ModuleSharder[torch.nn.Module],
+                QuantEmbeddingBagCollectionSharder(fused_params=fused_params),
+            ),
+        ]
+
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            num_float_features=10,
+            dense_device=torch.device("cpu"),
+            sparse_device=torch.device("cpu"),
+            over_arch_clazz=TestOverArchRegroupModule,
+        )
+
+        model.eval()
+        _, local_batch = ModelInput.generate(
+            batch_size=16,
+            world_size=1,
+            num_float_features=10,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+        )
+
+        # with torch.inference_mode(): # TODO: Why does inference mode fail when using different quant data types
+        output = model(local_batch[0])
+
+        # Quantize the model and collect quantized weights
+        quantized_model = quantize_inference_model(
+            model
+        )
+        quantized_output = quantized_model(local_batch[0])
+        table_to_weight = get_table_to_weights_from_tbe(quantized_model)
+
+        # Shard the model; all weights are initialized back to 0, so we have to reassign them
+        sharded_quant_model, _ = shard_quant_model(
+            quantized_model,
+            world_size=1,
+            compute_device="cpu",
+            sharding_device="cpu",
+            sharders=sharders,
+        )
+        assign_weights_to_tbe(quantized_model, table_to_weight)
+        sharded_quant_output = sharded_quant_model(local_batch[0])
+
+        # When world_size = 1, we should have 1 TBE per sharded, quantized ebc
+        self.assertTrue(len(sharded_quant_model.sparse.ebc.tbes) == 1)
+        self.assertTrue(len(sharded_quant_model.sparse.weighted_ebc.tbes) == 1)
+
+        # Check that the outputs are close
+        self.assertTrue(torch.allclose(output, quantized_output, atol=1e-3))
+        self.assertTrue(torch.allclose(output, sharded_quant_output, atol=1e-3))
+
+
     def test_quantized_tbe_count_different_pooling(self) -> None:
         set_propogate_device(True)
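
Usage note (not part of the patch): the lengths-based lookup is opt-in and stays off by default (FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: False in DEFAULT_FUSED_PARAMS above). Below is a minimal sketch of how a caller would enable it, mirroring test_sharded_quantized_lengths_to_tbe; it assumes quantize_inference_model and shard_quant_model are importable from torchrec.inference.modules, which this diff does not show, and "model" stands for an eager TorchRec inference model.

    from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
    from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
    # Assumed import path for the inference helpers exercised in the test above.
    from torchrec.inference.modules import quantize_inference_model, shard_quant_model

    # Opt in to the lengths-based TBE lookup. With this flag set, the sharded kernel
    # unwraps KJT lengths and the WithLength TBE wrapper builds offsets via
    # torch.ops.fbgemm.asynchronous_complete_cumsum(lengths),
    # e.g. lengths [2, 0, 3] -> offsets [0, 2, 2, 5].
    fused_params = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}
    sharders = [QuantEmbeddingBagCollectionSharder(fused_params=fused_params)]

    quantized_model = quantize_inference_model(model)
    sharded_model, _ = shard_quant_model(
        quantized_model,
        world_size=1,
        compute_device="cpu",
        sharding_device="cpu",
        sharders=sharders,
    )

Because tbe_fused_params pops FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP before the remaining fused params reach the TBE constructor, the flag only selects which lookup path QuantBatchedEmbeddingBag takes; it is never forwarded to IntNBitTableBatchedEmbeddingBagsCodegen itself.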