Explicitly check TBE weights are initialized (#1868)
Summary:

As requested, explicitly check in the test that the TBE weights are initialized.
Also refactor the TBE utility functions in torchrec/distributed/infer_utils.py.

Differential Revision: D55995107
gnahzg authored and facebook-github-bot committed Apr 11, 2024
1 parent 7df845b commit 2cd2186
Showing 3 changed files with 132 additions and 24 deletions.
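Before the diffs, here is a minimal sketch (not part of the commit) of the check this change introduces, assuming a ShardedQuantEmbeddingBagCollection or ShardedQuantEmbeddingCollection has already been built and sharded as in the tests below; check_tbes_initialized is a hypothetical helper name, while get_tbes_from_sharded_module and weight_initialized come from the diff itself.

from typing import List

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torchrec.distributed.infer_utils import get_tbes_from_sharded_module


def check_tbes_initialized(sharded_module: torch.nn.Module) -> None:
    # Hypothetical helper sketching the new explicit check: collect the
    # underlying TBE modules from the sharded module and verify that each
    # one reports its weights as initialized.
    tbes: List[IntNBitTableBatchedEmbeddingBagsCodegen] = get_tbes_from_sharded_module(
        sharded_module
    )
    for tbe in tbes:
        assert tbe.weight_initialized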
55 changes: 40 additions & 15 deletions torchrec/distributed/infer_utils.py
@@ -7,26 +7,51 @@

 from typing import List, Tuple

+import torch
+
 from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
+from torchrec.distributed.quant_embedding import ShardedQuantEmbeddingCollection

 from torchrec.distributed.quant_embeddingbag import ShardedQuantEmbeddingBagCollection


-def get_tbe_specs_from_sqebc(
-    sqebc: ShardedQuantEmbeddingBagCollection,
+def get_tbes_from_sharded_module(
+    module: torch.nn.Module,
+) -> List[IntNBitTableBatchedEmbeddingBagsCodegen]:
+    assert type(module) in [
+        ShardedQuantEmbeddingBagCollection,
+        ShardedQuantEmbeddingCollection,
+    ], "Only support ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection for get TBEs"
+    tbes = []
+    for lookup in module._lookups:
+        for lookup_per_rank in lookup._embedding_lookups_per_rank:
+            for emb_module in lookup_per_rank._emb_modules:
+                tbes.append(emb_module._emb_module)
+    return tbes
+
+
+def get_tbe_specs_from_sharded_module(
+    module: torch.nn.Module,
 ) -> List[
     Tuple[str, int, int, str, str]
 ]:  # # tuple of (feature_names, rows, dims, str(SparseType), str(EmbeddingLocation/placement))
+    assert type(module) in [
+        ShardedQuantEmbeddingBagCollection,
+        ShardedQuantEmbeddingCollection,
+    ], "Only support ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection for get TBE specs"
     tbe_specs = []
-    for lookup in sqebc._lookups:
-        for lookup_per_rank in lookup._embedding_lookups_per_rank:
-            for emb_module in lookup_per_rank._emb_modules:
-                for spec in emb_module._emb_module.embedding_specs:
-                    tbe_specs.append(
-                        (
-                            spec[0],
-                            spec[1],
-                            spec[2],
-                            str(spec[3]),
-                            str(spec[4]),
-                        )
-                    )
+    tbes = get_tbes_from_sharded_module(module)
+    for tbe in tbes:
+        for spec in tbe.embedding_specs:
+            tbe_specs.append(
+                (
+                    spec[0],
+                    spec[1],
+                    spec[2],
+                    str(spec[3]),
+                    str(spec[4]),
+                )
+            )
     return tbe_specs
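For illustration, a short sketch (not part of the commit) of how the refactored spec helper is used; print_tbe_specs is a hypothetical name, and sharded_module stands for an already-sharded quant embedding module as in the tests below.

import torch

from torchrec.distributed.infer_utils import get_tbe_specs_from_sharded_module


def print_tbe_specs(sharded_module: torch.nn.Module) -> None:
    # Each returned spec is a tuple of
    # (feature_names, rows, dims, str(SparseType), str(EmbeddingLocation/placement)),
    # e.g. ("table_1", 40, 20, "int8", "EmbeddingLocation.DEVICE") in the tests below.
    for name, rows, dims, dtype, placement in get_tbe_specs_from_sharded_module(sharded_module):
        print(name, rows, dims, dtype, placement)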
5 changes: 5 additions & 0 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -24,6 +24,7 @@
     KeyedJaggedTensor,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
+from torchrec.distributed.infer_utils import get_tbes_from_sharded_module
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
 from torchrec.distributed.planner.shard_estimators import (

@@ -969,6 +970,10 @@ def test_rw_sequence_uneven(self, weight_dtype: torch.dtype, device: str) -> None:
         gm_script_output = gm_script(*inputs[0])
         assert_close(sharded_output, gm_script_output)

+        tbes = get_tbes_from_sharded_module(sharded_model._module_kjt_input[0])
+        for tbe in tbes:
+            self.assertTrue(tbe.weight_initialized)
+
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs available",
     )
96 changes: 87 additions & 9 deletions torchrec/distributed/tests/test_infer_utils.py
@@ -8,22 +8,27 @@
 #!/usr/bin/env python3

 import unittest
-from typing import cast

 import torch

 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.infer_utils import get_tbe_specs_from_sqebc
-from torchrec.distributed.quant_embeddingbag import ShardedQuantEmbeddingBagCollection
+from torchrec.distributed.infer_utils import get_tbe_specs_from_sharded_module
 from torchrec.distributed.shard import _shard_modules
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
     table_wise,
 )
-from torchrec.distributed.test_utils.infer_utils import quantize, TestQuantEBCSharder
+from torchrec.distributed.test_utils.infer_utils import (
+    quantize,
+    TestQuantEBCSharder,
+    TestQuantECSharder,
+)
 from torchrec.distributed.types import ShardingEnv, ShardingPlan, ShardingType
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
-from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)


 class UtilsTest(unittest.TestCase):

@@ -92,9 +97,7 @@ def test_get_tbe_specs_from_sqebc(self) -> None:
             env=ShardingEnv.from_local(world_size=2, rank=0),
         )

-        specs = get_tbe_specs_from_sqebc(
-            cast(ShardedQuantEmbeddingBagCollection, sharded_model)
-        )
+        specs = get_tbe_specs_from_sharded_module(sharded_model)

         expected_specs = [
             ("table_1", 40, 20, "int8", "EmbeddingLocation.DEVICE"),

@@ -103,3 +106,78 @@
         ]

         self.assertEqual(specs, expected_specs)
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    def test_get_tbe_specs_from_sqec(self) -> None:
+        device = torch.device("cuda:0")
+
+        num_features = 3
+
+        tables = [
+            EmbeddingConfig(
+                num_embeddings=(i + 1) * 20,
+                embedding_dim=10,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]
+
+        model = torch.nn.Sequential(
+            EmbeddingCollection(
+                tables=tables,
+                device=device,
+            )
+        )
+        model.training = False
+
+        quant_model = quantize(
+            model,
+            inplace=True,
+            output_type=torch.float,
+            quant_state_dict_split_scale_bias=True,
+        )
+
+        sharder = TestQuantECSharder(
+            sharding_type=ShardingType.TABLE_WISE.value,
+            kernel_type=EmbeddingComputeKernel.QUANT.value,
+            shardable_params=[f"table_{i}" for i in range(num_features)],
+        )
+
+        module_plan = construct_module_sharding_plan(
+            quant_model[0],
+            per_param_sharding={
+                "table_0": table_wise(rank=1),
+                "table_1": table_wise(rank=0),
+                "table_2": table_wise(rank=0),
+            },
+            # pyre-ignore
+            sharder=sharder,
+            local_size=2,
+            world_size=2,
+        )
+
+        plan = ShardingPlan(plan={"": module_plan})
+
+        sharded_model = _shard_modules(
+            module=quant_model[0],
+            # pyre-ignore
+            sharders=[sharder],
+            device=device,
+            plan=plan,
+            env=ShardingEnv.from_local(world_size=2, rank=0),
+        )
+
+        specs = get_tbe_specs_from_sharded_module(sharded_model)
+
+        expected_specs = [
+            ("table_1", 40, 10, "int8", "EmbeddingLocation.DEVICE"),
+            ("table_2", 60, 10, "int8", "EmbeddingLocation.DEVICE"),
+            ("table_0", 20, 10, "int8", "EmbeddingLocation.DEVICE"),
+        ]
+
+        self.assertEqual(specs, expected_specs)
