diff --git a/torchrec/distributed/fused_params.py b/torchrec/distributed/fused_params.py
index 01d3a7708..8b7decdba 100644
--- a/torchrec/distributed/fused_params.py
+++ b/torchrec/distributed/fused_params.py
@@ -21,6 +21,7 @@
     "__register_quant_state_dict_split_scale_bias"
 )
 FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
+FUSED_PARAM_IS_WEIGHTED: str = "__register_tbe_is_weighted"
 
 
 class TBEToRegisterMixIn:
@@ -57,6 +58,13 @@ def get_fused_param_tbe_row_alignment(
     return fused_params[FUSED_PARAM_TBE_ROW_ALIGNMENT]
 
 
+def is_fused_param_weighted(fused_params: Optional[Dict[str, Any]]) -> Optional[bool]:
+    if fused_params is None or FUSED_PARAM_IS_WEIGHTED not in fused_params:
+        return None
+    else:
+        return fused_params[FUSED_PARAM_IS_WEIGHTED]
+
+
 def is_fused_param_quant_state_dict_split_scale_bias(
     fused_params: Optional[Dict[str, Any]]
 ) -> bool:
@@ -80,5 +88,7 @@ def tbe_fused_params(
         fused_params_for_tbe.pop(FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS)
     if FUSED_PARAM_TBE_ROW_ALIGNMENT in fused_params_for_tbe:
         fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
+    if FUSED_PARAM_IS_WEIGHTED in fused_params_for_tbe:
+        fused_params_for_tbe.pop(FUSED_PARAM_IS_WEIGHTED)
 
     return fused_params_for_tbe
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index a9ba784c0..f54ac8b93 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -35,6 +35,7 @@
     get_fused_param_tbe_row_alignment,
     is_fused_param_quant_state_dict_split_scale_bias,
     is_fused_param_register_tbe,
+    is_fused_param_weighted,
     tbe_fused_params,
     TBEToRegisterMixIn,
 )
@@ -192,6 +193,7 @@ def __init__(
             managed.append(EmbeddingLocation.HOST)
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
+        self._is_weighted: Optional[bool] = is_fused_param_weighted(fused_params)
         self._quant_state_dict_split_scale_bias: bool = (
             is_fused_param_quant_state_dict_split_scale_bias(fused_params)
         )
@@ -254,8 +256,16 @@ def get_tbes_to_register(
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         if self._runtime_device.type == "cpu":
             indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(features)
+            if self._is_weighted:
+                assert per_sample_weights is not None
+            else:
+                per_sample_weights = None
         else:
             indices, offsets, per_sample_weights = _unwrap_kjt(features)
+            if self._is_weighted:
+                assert per_sample_weights is not None
+            else:
+                per_sample_weights = None
         # Conditional call of .forward function for FX:
         # emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
         # emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py
index 8fcfa5928..ed836bbc1 100644
--- a/torchrec/distributed/quant_embeddingbag.py
+++ b/torchrec/distributed/quant_embeddingbag.py
@@ -32,6 +32,7 @@
     create_sharding_infos_by_sharding,
 )
 from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_WEIGHTED,
     FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
     FUSED_PARAM_REGISTER_TBE_BOOL,
     get_tbes_to_register_from_iterable,
@@ -351,6 +352,8 @@ def shard(
             fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr(
                 module, FUSED_PARAM_REGISTER_TBE_BOOL, False
             )
+        if FUSED_PARAM_IS_WEIGHTED not in fused_params:
+            fused_params[FUSED_PARAM_IS_WEIGHTED] = module.is_weighted()
         return ShardedQuantEmbeddingBagCollection(
             module, params, env, fused_params, device=device
         )
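Usage note (not part of the patch): a minimal sketch of how the new fused param round-trips through the helpers added in fused_params.py. The key and helper names come from the diff above; the extra fused param and its value are illustrative only. The flag is populated by the shard() override in quant_embeddingbag.py (falling back to module.is_weighted() when the caller did not set it), read by the quant kernel, and stripped before the remaining kwargs reach the fbgemm TBE constructor.

from torchrec.distributed.fused_params import (
    FUSED_PARAM_IS_WEIGHTED,
    is_fused_param_weighted,
    tbe_fused_params,
)

# Fused params as a sharder might see them; FUSED_PARAM_IS_WEIGHTED is the new key.
fused_params = {"cache_load_factor": 0.2, FUSED_PARAM_IS_WEIGHTED: True}

# The quant kernel reads the flag; None means "not specified".
assert is_fused_param_weighted(fused_params) is True
assert is_fused_param_weighted(None) is None

# tbe_fused_params copies the dict and pops all TorchRec-only keys,
# so the TBE never sees the new flag.
tbe_kwargs = tbe_fused_params(fused_params)
assert FUSED_PARAM_IS_WEIGHTED not in tbe_kwargs
assert tbe_kwargs == {"cache_load_factor": 0.2}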