enable is_weighted for quant emb kernel (#1893)
Summary:
Pull Request resolved: #1893

Pass is_weighted to the quant emb kernel so that we can statically determine which TBE module to invoke.

Reviewed By: seanx92, gnahzg

Differential Revision: D55567522

fbshipit-source-id: cedf81409c656200be1fcaf8f8a5967b9d7eabb1
jiayisuse authored and facebook-github-bot committed Apr 18, 2024
1 parent b55fdb6 commit 59878e8
Showing 3 changed files with 19 additions and 0 deletions.
10 changes: 10 additions & 0 deletions torchrec/distributed/fused_params.py
@@ -21,6 +21,7 @@
"__register_quant_state_dict_split_scale_bias"
)
FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
FUSED_PARAM_IS_WEIGHTED: str = "__register_tbe_is_weighted"


class TBEToRegisterMixIn:
@@ -57,6 +58,13 @@ def get_fused_param_tbe_row_alignment(
return fused_params[FUSED_PARAM_TBE_ROW_ALIGNMENT]


def is_fused_param_weighted(fused_params: Optional[Dict[str, Any]]) -> Optional[bool]:
if fused_params is None or FUSED_PARAM_IS_WEIGHTED not in fused_params:
return None
else:
return fused_params[FUSED_PARAM_IS_WEIGHTED]


def is_fused_param_quant_state_dict_split_scale_bias(
fused_params: Optional[Dict[str, Any]]
) -> bool:
@@ -80,5 +88,7 @@ def tbe_fused_params(
fused_params_for_tbe.pop(FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS)
if FUSED_PARAM_TBE_ROW_ALIGNMENT in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
if FUSED_PARAM_IS_WEIGHTED in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_IS_WEIGHTED)

return fused_params_for_tbe
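
For illustration, here is a minimal sketch (not part of this commit) of how the new fused param is meant to flow through the helpers above; the "learning_rate" entry is only a hypothetical stand-in for unrelated fused params.

from torchrec.distributed.fused_params import (
    FUSED_PARAM_IS_WEIGHTED,
    is_fused_param_weighted,
    tbe_fused_params,
)

# The sharder stashes the flag next to whatever other fused params it carries.
fused_params = {FUSED_PARAM_IS_WEIGHTED: True, "learning_rate": 0.01}

# Tri-state read-back: True/False when the flag was set, None for callers
# that never set it (preserving the pre-existing behavior).
assert is_fused_param_weighted(fused_params) is True
assert is_fused_param_weighted({"learning_rate": 0.01}) is None
assert is_fused_param_weighted(None) is None

# TorchRec-internal keys are stripped before the dict is handed to the TBE,
# so the fbgemm module never sees "__register_tbe_is_weighted".
assert FUSED_PARAM_IS_WEIGHTED not in tbe_fused_params(fused_params)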
6 changes: 6 additions & 0 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -35,6 +35,7 @@
get_fused_param_tbe_row_alignment,
is_fused_param_quant_state_dict_split_scale_bias,
is_fused_param_register_tbe,
is_fused_param_weighted,
tbe_fused_params,
TBEToRegisterMixIn,
)
@@ -192,6 +193,7 @@ def __init__(
managed.append(EmbeddingLocation.HOST)
self._config: GroupedEmbeddingConfig = config
self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
self._is_weighted: Optional[bool] = is_fused_param_weighted(fused_params)
self._quant_state_dict_split_scale_bias: bool = (
is_fused_param_quant_state_dict_split_scale_bias(fused_params)
)
@@ -256,6 +258,10 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(features)
else:
indices, offsets, per_sample_weights = _unwrap_kjt(features)
if self._is_weighted:
assert per_sample_weights is not None
elif self._is_weighted is not None:
per_sample_weights = None
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
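
The forward() hunk above gives per_sample_weights three-way semantics keyed on the new flag. A standalone sketch of that branching (the helper name is hypothetical, not from the commit):

from typing import Optional

import torch


def resolve_per_sample_weights(
    is_weighted: Optional[bool],
    per_sample_weights: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # Mirrors the branching added to forward() above.
    if is_weighted:
        # Sharder declared the tables weighted: the unwrapped KJT must carry weights.
        assert per_sample_weights is not None
        return per_sample_weights
    if is_weighted is not None:
        # Explicitly unweighted: drop any stray weights before calling the TBE.
        return None
    # Flag absent (older callers that never set it): keep whatever the KJT unwrapping produced.
    return per_sample_weights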
3 changes: 3 additions & 0 deletions torchrec/distributed/quant_embeddingbag.py
@@ -32,6 +32,7 @@
create_sharding_infos_by_sharding,
)
from torchrec.distributed.fused_params import (
FUSED_PARAM_IS_WEIGHTED,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
FUSED_PARAM_REGISTER_TBE_BOOL,
get_tbes_to_register_from_iterable,
@@ -351,6 +352,8 @@ def shard(
fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr(
module, FUSED_PARAM_REGISTER_TBE_BOOL, False
)
if FUSED_PARAM_IS_WEIGHTED not in fused_params:
fused_params[FUSED_PARAM_IS_WEIGHTED] = module.is_weighted()

return ShardedQuantEmbeddingBagCollection(
module, params, env, fused_params, device=device
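
On the sharder side, shard() seeds the flag from the module unless the caller already supplied it, so modules sharded through this path hand the kernel a definitive True/False. A rough sketch of that call-site behavior (the seed_is_weighted helper is illustrative only; it assumes a quant EmbeddingBagCollection-style module exposing is_weighted()):

from typing import Any, Dict

from torchrec.distributed.fused_params import FUSED_PARAM_IS_WEIGHTED


def seed_is_weighted(fused_params: Dict[str, Any], module: Any) -> Dict[str, Any]:
    # A caller-provided value wins; otherwise derive the flag from the module itself.
    if FUSED_PARAM_IS_WEIGHTED not in fused_params:
        fused_params[FUSED_PARAM_IS_WEIGHTED] = module.is_weighted()
    return fused_params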
