Add tablewise sharding test (#1878)
Summary:
Pull Request resolved: #1878

As titled: adds table-wise sharding inference tests, test_tw for the quantized embedding bag collection and test_tw_sequence for quantized sequence embeddings.

Reviewed By: IvanKobzarev

Differential Revision: D56123810

fbshipit-source-id: aea5e027476238b4aeaa00bb95f3e45787acbf00
gnahzg authored and facebook-github-bot committed Apr 17, 2024
1 parent b357837 commit 16ede46
Showing 1 changed file with 170 additions and 0 deletions.
170 changes: 170 additions & 0 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -131,6 +131,70 @@ def placement_helper(device_type: str, index: int = 0) -> str:


class InferShardingsTest(unittest.TestCase):
    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs available",
    )
    # pyre-ignore
    @given(
        weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
    )
    @settings(max_examples=4, deadline=None)
    def test_tw(self, weight_dtype: torch.dtype) -> None:
        num_embeddings = 256
        emb_dim = 16
        world_size = 2
        batch_size = 4
        local_device = torch.device("cuda:0")
        mi = create_test_model(
            num_embeddings,
            emb_dim,
            world_size,
            batch_size,
            dense_device=local_device,
            sparse_device=local_device,
            quant_state_dict_split_scale_bias=True,
            weight_dtype=weight_dtype,
        )

        non_sharded_model = mi.quant_model
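        # Table-wise (TW) sharding: the whole table forms a single shard, expected on rank 0.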
        expected_shards = [
            [
                ((0, 0, num_embeddings, emb_dim), "rank:0/cuda:0"),
            ]
        ]
        sharded_model = shard_qebc(
            mi,
            sharding_type=ShardingType.TABLE_WISE,
            device=local_device,
            expected_shards=expected_shards,
        )
        inputs = [
            model_input_to_forward_args(inp.to(local_device))
            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
        ]

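        # Copy the quantized weights from the reference model into the sharded model.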
        sharded_model.load_state_dict(non_sharded_model.state_dict())

        sharded_output = sharded_model(*inputs[0])
        non_sharded_output = non_sharded_model(*inputs[0])
        assert_close(sharded_output, non_sharded_output)

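        # The sharded module must also produce the same output after FX tracing and TorchScript scripting.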
        gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
        gm_script = torch.jit.script(gm)
        gm_script_output = gm_script(*inputs[0])
        assert_close(sharded_output, gm_script_output)

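        # Check the shard metadata recorded for the single TW-sharded table.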
        weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model)
        assert_weight_spec(
            weights_spec,
            expected_shards,
            "_module.sparse.ebc",
            "embedding_bags",
            ["table_0"],
            ShardingType.TABLE_WISE.value,
        )

    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs available",
@@ -678,6 +742,112 @@ def test_cw_sequence(self, weight_dtype: torch.dtype) -> None:
            ShardingType.COLUMN_WISE.value,
        )

    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs available",
    )
    # pyre-ignore
    @given(
        weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
    )
    @settings(max_examples=4, deadline=None)
    def test_tw_sequence(self, weight_dtype: torch.dtype) -> None:
        num_embeddings = 10
        emb_dim = 16
        world_size = 2
        batch_size = 2
        local_device = torch.device("cuda:0")

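        # Two-rank CUDA topology; the planner and the expected placements below both use it.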
        topology: Topology = Topology(world_size=world_size, compute_device="cuda")
        mi = TestModelInfo(
            dense_device=local_device,
            sparse_device=local_device,
            num_features=2,
            num_float_features=10,
            num_weighted_features=0,
            topology=topology,
        )

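        # Planner configured for inference (is_inference=True in the perf estimator).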
        mi.planner = EmbeddingShardingPlanner(
            topology=topology,
            batch_size=batch_size,
            enumerator=EmbeddingEnumerator(
                topology=topology,
                batch_size=batch_size,
                estimator=[
                    EmbeddingPerfEstimator(topology=topology, is_inference=True),
                    EmbeddingStorageEstimator(topology=topology),
                ],
            ),
        )

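        # Two single-feature sequence-embedding tables: table_0 and table_1.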
        mi.tables = [
            EmbeddingConfig(
                num_embeddings=num_embeddings,
                embedding_dim=emb_dim,
                name=f"table_{i}",
                feature_names=[f"feature_{i}"],
            )
            for i in range(mi.num_features)
        ]

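        # KJTInputWrapper adapts the module to take KeyedJaggedTensor-style forward args.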
        mi.model = KJTInputWrapper(
            module_kjt_input=torch.nn.Sequential(
                EmbeddingCollection(
                    tables=mi.tables,
                    device=mi.sparse_device,
                )
            )
        )

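        # Quantize a copy of the eval-mode model, splitting scale/bias out of the weights.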
        mi.model.training = False
        mi.quant_model = quantize(
            mi.model,
            inplace=False,
            quant_state_dict_split_scale_bias=True,
            weight_dtype=weight_dtype,
        )
        non_sharded_model = mi.quant_model
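        # TW sharding across ranks: table_0 placed on rank 0, table_1 on rank 1.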
        expected_shards = [
            [
                ((0, 0, num_embeddings, emb_dim), "rank:0/cuda:0"),
            ],
            [
                ((0, 0, num_embeddings, emb_dim), "rank:1/cuda:1"),
            ],
        ]
        sharded_model = shard_qec(
            mi,
            sharding_type=ShardingType.TABLE_WISE,
            device=local_device,
            expected_shards=expected_shards,
        )

        inputs = [
            model_input_to_forward_args_kjt(inp.to(local_device))
            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
        ]

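        # Load reference weights, then compare sharded and non-sharded forward outputs.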
        sharded_model.load_state_dict(non_sharded_model.state_dict())
        sharded_output = sharded_model(*inputs[0])
        non_sharded_output = non_sharded_model(*inputs[0])
        assert_close(non_sharded_output, sharded_output)

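        # Scripted (FX-traced + TorchScript) execution must match the eager output.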
        gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
        gm_script = torch.jit.script(gm)
        gm_script_output = gm_script(*inputs[0])
        assert_close(sharded_output, gm_script_output)

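        # Verify the recorded shard metadata for both TW-sharded tables.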
        weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model)
        assert_weight_spec(
            weights_spec,
            expected_shards,
            "_module_kjt_input.0",
            "embeddings",
            ["table_0", "table_1"],
            ShardingType.TABLE_WISE.value,
        )

    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs available",
