Explicitly check TBE weights are initialized #1868

Closed
wants to merge 2 commits into from
55 changes: 40 additions & 15 deletions torchrec/distributed/infer_utils.py
@@ -7,26 +7,51 @@

from typing import List, Tuple

+import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)
+from torchrec.distributed.quant_embedding import ShardedQuantEmbeddingCollection

from torchrec.distributed.quant_embeddingbag import ShardedQuantEmbeddingBagCollection


-def get_tbe_specs_from_sqebc(
-    sqebc: ShardedQuantEmbeddingBagCollection,
+def get_tbes_from_sharded_module(
+    module: torch.nn.Module,
+) -> List[IntNBitTableBatchedEmbeddingBagsCodegen]:
+    assert type(module) in [
+        ShardedQuantEmbeddingBagCollection,
+        ShardedQuantEmbeddingCollection,
+    ], "Only support ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection for get TBEs"
+    tbes = []
+    for lookup in module._lookups:
+        for lookup_per_rank in lookup._embedding_lookups_per_rank:
+            for emb_module in lookup_per_rank._emb_modules:
+                tbes.append(emb_module._emb_module)
+    return tbes


+def get_tbe_specs_from_sharded_module(
+    module: torch.nn.Module,
) -> List[
    Tuple[str, int, int, str, str]
]:  # # tuple of (feature_names, rows, dims, str(SparseType), str(EmbeddingLocation/placement))
+    assert type(module) in [
+        ShardedQuantEmbeddingBagCollection,
+        ShardedQuantEmbeddingCollection,
+    ], "Only support ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection for get TBE specs"
    tbe_specs = []
-    for lookup in sqebc._lookups:
-        for lookup_per_rank in lookup._embedding_lookups_per_rank:
-            for emb_module in lookup_per_rank._emb_modules:
-                for spec in emb_module._emb_module.embedding_specs:
-                    tbe_specs.append(
-                        (
-                            spec[0],
-                            spec[1],
-                            spec[2],
-                            str(spec[3]),
-                            str(spec[4]),
-                        )
-                    )
+    tbes = get_tbes_from_sharded_module(module)
+    for tbe in tbes:
+        for spec in tbe.embedding_specs:
+            tbe_specs.append(
+                (
+                    spec[0],
+                    spec[1],
+                    spec[2],
+                    str(spec[3]),
+                    str(spec[4]),
+                )
+            )
    return tbe_specs
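
For context, a minimal usage sketch of the two helpers above (not part of this diff; `sharded_qebc` is a hypothetical name for an already-sharded ShardedQuantEmbeddingBagCollection):

from torchrec.distributed.infer_utils import (
    get_tbe_specs_from_sharded_module,
    get_tbes_from_sharded_module,
)

# Confirm every TBE owned by the sharded module has had its weights initialized.
for tbe in get_tbes_from_sharded_module(sharded_qebc):
    assert tbe.weight_initialized

# Each spec is (feature_names, rows, dims, str(SparseType), str(EmbeddingLocation)).
for spec in get_tbe_specs_from_sharded_module(sharded_qebc):
    print(spec)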
4 changes: 2 additions & 2 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -30,8 +30,8 @@
)
from torchrec.distributed.planner.types import ParameterConstraints, Perf, Topology
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
+from torchrec.distributed.test_utils.infer_utils import quantize
from torchrec.distributed.test_utils.test_model import TestEBCSharder, TestSparseNN
-from torchrec.distributed.tests.test_quant_model_parallel import _quantize
from torchrec.distributed.tests.test_sequence_model import TestSequenceSparseNN
from torchrec.distributed.types import (
CacheParams,
@@ -378,7 +378,7 @@ def test_inference_1_table_perf(self) -> None:
)
]
model = TestSparseNN(tables=tables, weighted_tables=[])
-quant_model = _quantize(model, inplace=True)
+quant_model = quantize(model, inplace=True)

inference_estimator = EmbeddingPerfEstimator(
topology=self.topology, is_inference=True
19 changes: 19 additions & 0 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -76,6 +76,7 @@
data_type_to_sparse_type,
dtype_to_data_type,
EmbeddingBagConfig,
+QuantConfig,
)
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
@@ -248,6 +249,7 @@ def quantize(
register_tbes: bool = False,
quant_state_dict_split_scale_bias: bool = False,
weight_dtype: torch.dtype = torch.qint8,
+per_table_weight_dtypes: Optional[Dict[str, torch.dtype]] = None,
) -> torch.nn.Module:
module_types: List[Type[torch.nn.Module]] = [
torchrec.modules.embedding_modules.EmbeddingBagCollection,
@@ -264,6 +266,14 @@
activation=quant.PlaceholderObserver.with_args(dtype=output_type),
weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
)

+    if per_table_weight_dtypes:
+        qconfig = QuantConfig(
+            activation=quant.PlaceholderObserver.with_args(dtype=output_type),
+            weight=quant.PlaceholderObserver.with_args(dtype=torch.quint8),
+            per_table_weight_dtype=per_table_weight_dtypes,
+        )

return quant.quantize_dynamic(
module,
qconfig_spec={
@@ -285,6 +295,7 @@ def quantize_fpebc(
register_tbes: bool = False,
quant_state_dict_split_scale_bias: bool = False,
weight_dtype: torch.dtype = torch.qint8,
+per_table_weight_dtypes: Optional[Dict[str, torch.dtype]] = None,
) -> torch.nn.Module:
module_types: List[Type[torch.nn.Module]] = [
torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection,
@@ -300,6 +311,14 @@
activation=quant.PlaceholderObserver.with_args(dtype=output_type),
weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
)

+    if per_table_weight_dtypes:
+        qconfig = QuantConfig(
+            activation=quant.PlaceholderObserver.with_args(dtype=output_type),
+            weight=quant.PlaceholderObserver.with_args(dtype=torch.quint8),
+            per_table_weight_dtype=per_table_weight_dtypes,
+        )

return quant.quantize_dynamic(
module,
qconfig_spec={
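
For context, a sketch of how the new per_table_weight_dtypes argument to the quantize test helper might be used (not part of this diff; `model` and the table names are hypothetical):

import torch
from torchrec.distributed.test_utils.infer_utils import quantize

# Quantize a model containing an EmbeddingBagCollection, giving each table its own weight dtype.
quant_model = quantize(
    model,
    inplace=True,
    output_type=torch.float,
    per_table_weight_dtypes={
        "table_0": torch.quint4x2,  # 4-bit rows for table_0
        "table_1": torch.qint8,     # 8-bit rows for table_1
    },
)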
5 changes: 5 additions & 0 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -24,6 +24,7 @@
KeyedJaggedTensor,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
+from torchrec.distributed.infer_utils import get_tbes_from_sharded_module
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.shard_estimators import (
@@ -969,6 +970,10 @@ def test_rw_sequence_uneven(self, weight_dtype: torch.dtype, device: str) -> None:
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

+        tbes = get_tbes_from_sharded_module(sharded_model._module_kjt_input[0])
+        for tbe in tbes:
+            self.assertTrue(tbe.weight_initialized)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
97 changes: 86 additions & 11 deletions torchrec/distributed/tests/test_infer_utils.py
@@ -8,25 +8,27 @@
#!/usr/bin/env python3

import unittest
-from typing import cast

import torch

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.infer_utils import get_tbe_specs_from_sqebc
-from torchrec.distributed.quant_embeddingbag import ShardedQuantEmbeddingBagCollection
+from torchrec.distributed.infer_utils import get_tbe_specs_from_sharded_module
from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.sharding_plan import (
construct_module_sharding_plan,
table_wise,
)
-from torchrec.distributed.tests.test_quant_model_parallel import (
-    _quantize,
+from torchrec.distributed.test_utils.infer_utils import (
+    quantize,
    TestQuantEBCSharder,
+    TestQuantECSharder,
)
from torchrec.distributed.types import ShardingEnv, ShardingPlan, ShardingType
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
-from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)


class UtilsTest(unittest.TestCase):
@@ -58,7 +60,7 @@ def test_get_tbe_specs_from_sqebc(self) -> None:
)
model.training = False

-quant_model = _quantize(
+quant_model = quantize(
model,
inplace=True,
output_type=torch.float,
@@ -95,9 +97,7 @@ def test_get_tbe_specs_from_sqebc(self) -> None:
env=ShardingEnv.from_local(world_size=2, rank=0),
)

-specs = get_tbe_specs_from_sqebc(
-    cast(ShardedQuantEmbeddingBagCollection, sharded_model)
-)
+specs = get_tbe_specs_from_sharded_module(sharded_model)

expected_specs = [
("table_1", 40, 20, "int8", "EmbeddingLocation.DEVICE"),
@@ -106,3 +106,78 @@ def test_get_tbe_specs_from_sqebc(self) -> None:
]

self.assertEqual(specs, expected_specs)

+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    def test_get_tbe_specs_from_sqec(self) -> None:
+        device = torch.device("cuda:0")

+        num_features = 3

+        tables = [
+            EmbeddingConfig(
+                num_embeddings=(i + 1) * 20,
+                embedding_dim=10,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]

+        model = torch.nn.Sequential(
+            EmbeddingCollection(
+                tables=tables,
+                device=device,
+            )
+        )
+        model.training = False

+        quant_model = quantize(
+            model,
+            inplace=True,
+            output_type=torch.float,
+            quant_state_dict_split_scale_bias=True,
+        )

+        sharder = TestQuantECSharder(
+            sharding_type=ShardingType.TABLE_WISE.value,
+            kernel_type=EmbeddingComputeKernel.QUANT.value,
+            shardable_params=[f"table_{i}" for i in range(num_features)],
+        )

+        module_plan = construct_module_sharding_plan(
+            quant_model[0],
+            per_param_sharding={
+                "table_0": table_wise(rank=1),
+                "table_1": table_wise(rank=0),
+                "table_2": table_wise(rank=0),
+            },
+            # pyre-ignore
+            sharder=sharder,
+            local_size=2,
+            world_size=2,
+        )

+        plan = ShardingPlan(plan={"": module_plan})

+        sharded_model = _shard_modules(
+            module=quant_model[0],
+            # pyre-ignore
+            sharders=[sharder],
+            device=device,
+            plan=plan,
+            env=ShardingEnv.from_local(world_size=2, rank=0),
+        )

+        specs = get_tbe_specs_from_sharded_module(sharded_model)

+        expected_specs = [
+            ("table_1", 40, 10, "int8", "EmbeddingLocation.DEVICE"),
+            ("table_2", 60, 10, "int8", "EmbeddingLocation.DEVICE"),
+            ("table_0", 20, 10, "int8", "EmbeddingLocation.DEVICE"),
+        ]

+        self.assertEqual(specs, expected_specs)