From 16ede46ebfcd78fb1daf2e0e33ee3012b2fbc122 Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Wed, 17 Apr 2024 07:31:19 -0700
Subject: [PATCH] Add tablewise sharding test (#1878)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1878

As titled

Reviewed By: IvanKobzarev

Differential Revision: D56123810

fbshipit-source-id: aea5e027476238b4aeaa00bb95f3e45787acbf00
---
 .../distributed/tests/test_infer_shardings.py | 170 ++++++++++++++++++
 1 file changed, 170 insertions(+)

diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index ab5a5af24..239cdcd9e 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -131,6 +131,70 @@ def placement_helper(device_type: str, index: int = 0) -> str:
 
 class InferShardingsTest(unittest.TestCase):
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    # pyre-ignore
+    @given(
+        weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
+    )
+    @settings(max_examples=4, deadline=None)
+    def test_tw(self, weight_dtype: torch.dtype) -> None:
+        num_embeddings = 256
+        emb_dim = 16
+        world_size = 2
+        batch_size = 4
+        local_device = torch.device("cuda:0")
+        mi = create_test_model(
+            num_embeddings,
+            emb_dim,
+            world_size,
+            batch_size,
+            dense_device=local_device,
+            sparse_device=local_device,
+            quant_state_dict_split_scale_bias=True,
+            weight_dtype=weight_dtype,
+        )
+
+        non_sharded_model = mi.quant_model
+        expected_shards = [
+            [
+                ((0, 0, num_embeddings, emb_dim), "rank:0/cuda:0"),
+            ]
+        ]
+        sharded_model = shard_qebc(
+            mi,
+            sharding_type=ShardingType.TABLE_WISE,
+            device=local_device,
+            expected_shards=expected_shards,
+        )
+        inputs = [
+            model_input_to_forward_args(inp.to(local_device))
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
+        ]
+
+        sharded_model.load_state_dict(non_sharded_model.state_dict())
+
+        sharded_output = sharded_model(*inputs[0])
+        non_sharded_output = non_sharded_model(*inputs[0])
+        assert_close(sharded_output, non_sharded_output)
+
+        gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
+        gm_script = torch.jit.script(gm)
+        gm_script_output = gm_script(*inputs[0])
+        assert_close(sharded_output, gm_script_output)
+
+        weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model)
+        assert_weight_spec(
+            weights_spec,
+            expected_shards,
+            "_module.sparse.ebc",
+            "embedding_bags",
+            ["table_0"],
+            ShardingType.TABLE_WISE.value,
+        )
+
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs available",
     )
@@ -678,6 +742,112 @@ def test_cw_sequence(self, weight_dtype: torch.dtype) -> None:
             ShardingType.COLUMN_WISE.value,
         )
 
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    # pyre-ignore
+    @given(
+        weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
+    )
+    @settings(max_examples=4, deadline=None)
+    def test_tw_sequence(self, weight_dtype: torch.dtype) -> None:
+        num_embeddings = 10
+        emb_dim = 16
+        world_size = 2
+        batch_size = 2
+        local_device = torch.device("cuda:0")
+
+        topology: Topology = Topology(world_size=world_size, compute_device="cuda")
+        mi = TestModelInfo(
+            dense_device=local_device,
+            sparse_device=local_device,
+            num_features=2,
+            num_float_features=10,
+            num_weighted_features=0,
+            topology=topology,
+        )
+
+        mi.planner = EmbeddingShardingPlanner(
+            topology=topology,
+            batch_size=batch_size,
+            enumerator=EmbeddingEnumerator(
+                topology=topology,
+                batch_size=batch_size,
+                estimator=[
+                    EmbeddingPerfEstimator(topology=topology, is_inference=True),
+                    EmbeddingStorageEstimator(topology=topology),
+                ],
+            ),
+        )
+
+        mi.tables = [
+            EmbeddingConfig(
+                num_embeddings=num_embeddings,
+                embedding_dim=emb_dim,
+                name=f"table_{i}",
+                feature_names=[f"feature_{i}"],
+            )
+            for i in range(mi.num_features)
+        ]
+
+        mi.model = KJTInputWrapper(
+            module_kjt_input=torch.nn.Sequential(
+                EmbeddingCollection(
+                    tables=mi.tables,
+                    device=mi.sparse_device,
+                )
+            )
+        )
+
+        mi.model.training = False
+        mi.quant_model = quantize(
+            mi.model,
+            inplace=False,
+            quant_state_dict_split_scale_bias=True,
+            weight_dtype=weight_dtype,
+        )
+        non_sharded_model = mi.quant_model
+        expected_shards = [
+            [
+                ((0, 0, num_embeddings, emb_dim), "rank:0/cuda:0"),
+            ],
+            [
+                ((0, 0, num_embeddings, emb_dim), "rank:1/cuda:1"),
+            ],
+        ]
+        sharded_model = shard_qec(
+            mi,
+            sharding_type=ShardingType.TABLE_WISE,
+            device=local_device,
+            expected_shards=expected_shards,
+        )
+
+        inputs = [
+            model_input_to_forward_args_kjt(inp.to(local_device))
+            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
+        ]
+
+        sharded_model.load_state_dict(non_sharded_model.state_dict())
+        sharded_output = sharded_model(*inputs[0])
+        non_sharded_output = non_sharded_model(*inputs[0])
+        assert_close(non_sharded_output, sharded_output)
+
+        gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
+        gm_script = torch.jit.script(gm)
+        gm_script_output = gm_script(*inputs[0])
+        assert_close(sharded_output, gm_script_output)
+
+        weights_spec: Dict[str, WeightSpec] = sharded_tbes_weights_spec(sharded_model)
+        assert_weight_spec(
+            weights_spec,
+            expected_shards,
+            "_module_kjt_input.0",
+            "embeddings",
+            ["table_0", "table_1"],
+            ShardingType.TABLE_WISE.value,
+        )
+
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs available",