Support sending lengths to TBE instead of just offsets (pytorch#2595)

Summary:

The TorchRec part of supporting sending lengths to TBE (instead of only offsets) for request coalescing in Ads inference.

Differential Revision: D66515313
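
For context, a minimal opt-in sketch (not part of this diff; it assumes the same public APIs that the new unit test below exercises):

# Hypothetical usage sketch, mirroring the new test_sharded_quantized_lengths_to_tbe test.
from typing import Any, cast, Dict, List

import torch
from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
from torchrec.distributed.types import ModuleSharder

fused_params: Dict[str, Any] = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}
sharders: List[ModuleSharder[torch.nn.Module]] = [
    cast(
        ModuleSharder[torch.nn.Module],
        QuantEmbeddingBagCollectionSharder(fused_params=fused_params),
    ),
]
# `sharders` is then passed to shard_quant_model(...), as in the test below.
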
PaulZhang12 authored and facebook-github-bot committed Dec 2, 2024
1 parent 99162b5 commit 24927e3
Showing 4 changed files with 174 additions and 48 deletions.
20 changes: 19 additions & 1 deletion torchrec/distributed/fused_params.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional

import torch

@@ -24,6 +24,10 @@
FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode"

# Force lengths to offsets conversion before TBE lookup. Helps with performance
# with certain ways to split models.
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup"


class TBEToRegisterMixIn:
def get_tbes_to_register(
@@ -68,6 +72,18 @@ def fused_param_bounds_check_mode(
return fused_params[FUSED_PARAM_BOUNDS_CHECK_MODE]


def fused_param_lengths_to_offsets_lookup(
fused_params: Optional[Dict[str, Any]]
) -> bool:
if (
fused_params is None
or FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP not in fused_params
):
return False
else:
return fused_params[FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP]


def is_fused_param_quant_state_dict_split_scale_bias(
fused_params: Optional[Dict[str, Any]]
) -> bool:
@@ -93,5 +109,7 @@ def tbe_fused_params(
fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP)

return fused_params_for_tbe
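
A small illustration of the helpers above (hypothetical values; "cache_load_factor" merely stands in for any kwarg TBE itself understands):

from torchrec.distributed.fused_params import (
    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
    fused_param_lengths_to_offsets_lookup,
    tbe_fused_params,
)

fused_params = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True, "cache_load_factor": 0.2}

# The flag defaults to False when the key is absent or fused_params is None.
assert fused_param_lengths_to_offsets_lookup(fused_params) is True
assert fused_param_lengths_to_offsets_lookup(None) is False
# tbe_fused_params strips the torchrec-only key before kwargs reach the TBE constructor.
assert tbe_fused_params(fused_params) == {"cache_load_factor": 0.2}
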
139 changes: 92 additions & 47 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -33,6 +33,7 @@
)
from torchrec.distributed.fused_params import (
fused_param_bounds_check_mode,
fused_param_lengths_to_offsets_lookup,
is_fused_param_quant_state_dict_split_scale_bias,
is_fused_param_register_tbe,
tbe_fused_params,
@@ -171,6 +172,19 @@ def _unwrap_kjt_for_cpu(
return indices, offsets, None


@torch.fx.wrap
def _unwrap_kjt_lengths(
features: KeyedJaggedTensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
indices = features.values()
lengths = features.lengths()
return (
indices.int(),
lengths.int(),
features.weights_or_none(),
)


@torch.fx.wrap
def _unwrap_optional_tensor(
tensor: Optional[torch.Tensor],
@@ -180,6 +194,26 @@ def _unwrap_optional_tensor(
return tensor


class IntNBitTableBatchedEmbeddingBagsCodegenWithLength(
IntNBitTableBatchedEmbeddingBagsCodegen
):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

# pyre-ignore Inconsistent override [14]
def forward(
self,
indices: torch.Tensor,
lengths: torch.Tensor,
per_sample_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self._forward_impl(
indices=indices,
offsets=(torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)),
per_sample_weights=per_sample_weights,
)


class QuantBatchedEmbeddingBag(
BaseBatchedEmbeddingBag[
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
@@ -237,22 +271,27 @@ def __init__(
)
)

self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=embedding_specs,
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
feature_names_per_table=[
table.feature_names for table in config.embedding_tables
],
**(tbe_fused_params(fused_params) or {}),
)
self.lengths_to_tbe: bool = fused_param_lengths_to_offsets_lookup(fused_params)

if self.lengths_to_tbe:
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
else:
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen

self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
embedding_specs=embedding_specs,
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
feature_names_per_table=[
table.feature_names for table in config.embedding_tables
],
**(tbe_fused_params(fused_params) or {}),
)
if device is not None:
self._emb_module.initialize_weights()
@@ -271,44 +310,50 @@ def get_tbes_to_register(
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
return {self._emb_module: self._config}

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
# Important: _unwrap_kjt regex for FX tracing TAGing
if self._runtime_device.type == "cpu":
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
def _emb_module_forward(
self,
indices: torch.Tensor,
lengths_or_offsets: torch.Tensor,
weights: Optional[torch.Tensor],
) -> torch.Tensor:
kwargs = {"indices": indices}

if self.lengths_to_tbe:
kwargs["lengths"] = lengths_or_offsets
else:
indices, offsets, per_sample_weights = _unwrap_kjt(features)
kwargs["offsets"] = lengths_or_offsets

if self._is_weighted:
weights = _unwrap_optional_tensor(per_sample_weights)
if self._emb_module_registered:
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
# For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged.
return self.emb_module(
indices=indices,
offsets=offsets,
per_sample_weights=weights,
)
kwargs["per_sample_weights"] = _unwrap_optional_tensor(weights)

if self._emb_module_registered:
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
# For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged.
return self._emb_module(**kwargs)
else:
return self._emb_module.forward(**kwargs)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
# Important: _unwrap_kjt regex for FX tracing TAGing
lengths, offsets = None, None
if self._runtime_device.type == "cpu":
if self.lengths_to_tbe:
indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
else:
return self.emb_module.forward(
indices=indices,
offsets=offsets,
per_sample_weights=weights,
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
else:
if self._emb_module_registered:
return self.emb_module(
indices=indices,
offsets=offsets,
)
if self.lengths_to_tbe:
indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
else:
return self.emb_module.forward(
indices=indices,
offsets=offsets,
)
indices, offsets, per_sample_weights = _unwrap_kjt(features)

return self._emb_module_forward(
indices, lengths if lengths is not None else offsets, per_sample_weights
)

def named_buffers(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
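
The WithLength subclass above defers the lengths-to-offsets conversion to torch.ops.fbgemm.asynchronous_complete_cumsum at lookup time. A rough eager-mode sketch of that relationship (illustrative values, not from this diff):

import torch

lengths = torch.tensor([2, 0, 3])  # number of ids per bag
# "Complete" cumsum: a leading zero plus the running sum, so len(offsets) == len(lengths) + 1.
offsets = torch.tensor([0] + torch.cumsum(lengths, dim=0).tolist())
# offsets == tensor([0, 2, 2, 5]); bag i reads indices[offsets[i]:offsets[i + 1]]
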
2 changes: 2 additions & 0 deletions torchrec/inference/modules.py
@@ -26,6 +26,7 @@
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fused_params import (
FUSED_PARAM_BOUNDS_CHECK_MODE,
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
FUSED_PARAM_REGISTER_TBE_BOOL,
)
@@ -82,6 +83,7 @@ def trim_torch_package_prefix_from_typename(typename: str) -> str:
FUSED_PARAM_REGISTER_TBE_BOOL: True,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: False,
}

DEFAULT_SHARDERS: List[ModuleSharder[torch.nn.Module]] = [
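
The new key ships disabled in the inference defaults above, so existing deployments keep the offsets path. A hypothetical override sketch (assuming the defaults dict edited above is the module-level DEFAULT_FUSED_PARAMS; the copy-and-flip pattern is an assumption, not shown in this diff):

from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
from torchrec.inference.modules import DEFAULT_FUSED_PARAMS

# Copy the defaults and enable the lengths-based lookup for sharders built from them.
fused_params = {**DEFAULT_FUSED_PARAMS, FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}
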
61 changes: 61 additions & 0 deletions torchrec/inference/tests/test_inference.py
@@ -10,17 +10,22 @@

import unittest
from argparse import Namespace
from typing import Any, cast, Dict, List

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from torch.fx import symbolic_trace
from torchrec import PoolingType
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
from torchrec.distributed.global_settings import set_propogate_device
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
from torchrec.distributed.test_utils.test_model import (
ModelInput,
TestOverArchRegroupModule,
TestSparseNN,
)
from torchrec.distributed.types import ModuleSharder

from torchrec.inference.dlrm_predict import (
create_training_batch,
@@ -300,6 +305,62 @@ def test_sharded_quantized_tbe_count(self) -> None:
expected_num_embeddings[spec[0]],
)

def test_sharded_quantized_lengths_to_tbe(self) -> None:
set_propogate_device(True)

fused_params: Dict[str, Any] = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}
sharders: List[ModuleSharder[torch.nn.Module]] = [
cast(
ModuleSharder[torch.nn.Module],
QuantEmbeddingBagCollectionSharder(fused_params=fused_params),
),
]

model = TestSparseNN(
tables=self.tables,
weighted_tables=self.weighted_tables,
num_float_features=10,
dense_device=torch.device("cpu"),
sparse_device=torch.device("cpu"),
over_arch_clazz=TestOverArchRegroupModule,
)

model.eval()
_, local_batch = ModelInput.generate(
batch_size=16,
world_size=1,
num_float_features=10,
tables=self.tables,
weighted_tables=self.weighted_tables,
)

# with torch.inference_mode(): # TODO: Why does inference mode fail when using different quant data types
output = model(local_batch[0])

# Quantize the model and collect quantized weights
quantized_model = quantize_inference_model(model)
quantized_output = quantized_model(local_batch[0])
table_to_weight = get_table_to_weights_from_tbe(quantized_model)

# Shard the model, all weights are initialized back to 0, so have to reassign weights
sharded_quant_model, _ = shard_quant_model(
quantized_model,
world_size=1,
compute_device="cpu",
sharding_device="cpu",
sharders=sharders,
)
assign_weights_to_tbe(quantized_model, table_to_weight)
sharded_quant_output = sharded_quant_model(local_batch[0])

# When world_size = 1, we should have 1 TBE per sharded, quantized ebc
self.assertTrue(len(sharded_quant_model.sparse.ebc.tbes) == 1)
self.assertTrue(len(sharded_quant_model.sparse.weighted_ebc.tbes) == 1)

# Check that the quantized and sharded outputs match the original model output
self.assertTrue(torch.allclose(output, quantized_output, atol=1e-3))
self.assertTrue(torch.allclose(output, sharded_quant_output, atol=1e-3))

def test_quantized_tbe_count_different_pooling(self) -> None:
set_propogate_device(True)

