From 61344920302b5ea5dac13532a9d194a7a4fd3658 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Tue, 26 Nov 2024 15:38:46 -0800 Subject: [PATCH] add NJT/TD support for EC Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) 
global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351 --- torchrec/distributed/embedding.py | 43 ++- .../distributed/test_utils/test_sharding.py | 32 +- .../tests/test_sequence_model_parallel.py | 312 ++++++++++++++++++ torchrec/modules/embedding_modules.py | 8 +- 4 files changed, 381 insertions(+), 14 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 5f16efc1b..b44c34c23 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -26,6 +26,7 @@ ) import torch +from tensordict import TensorDict from torch import distributed as dist, nn from torch.autograd.profiler import record_function from torch.distributed._tensor import DTensor @@ -88,7 +89,12 @@ from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer -from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import ( + _to_offsets, + JaggedTensor, + KeyedJaggedTensor, + td_to_kjt, +) try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -1146,25 +1152,50 @@ def _compute_sequence_vbe_context( def input_dist( self, ctx: 
EmbeddingCollectionContext, - features: KeyedJaggedTensor, + features: TypeUnion[KeyedJaggedTensor, TensorDict], ) -> Awaitable[Awaitable[KJTList]]: + # torch.distributed.breakpoint() + feature_keys = list(features.keys()) # pyre-ignore[6] if self._has_uninitialized_input_dist: - self._create_input_dist(input_feature_names=features.keys()) + self._create_input_dist(input_feature_names=feature_keys) self._has_uninitialized_input_dist = False with torch.no_grad(): unpadded_features = None - if features.variable_stride_per_key(): + if ( + isinstance(features, KeyedJaggedTensor) + and features.variable_stride_per_key() + ): unpadded_features = features features = pad_vbe_kjt_lengths(unpadded_features) - if self._features_order: + if isinstance(features, KeyedJaggedTensor) and self._features_order: features = features.permute( self._features_order, # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` # but got `TypeUnion[Module, Tensor]`. self._features_order_tensor, ) - features_by_shards = features.split(self._feature_splits) + + if isinstance(features, KeyedJaggedTensor): + features_by_shards = features.split(self._feature_splits) + else: # TensorDict + feature_names = ( + [feature_keys[i] for i in self._features_order] + if self._features_order # empty features_order means no reordering + else feature_keys + ) + feature_names = [name.split("@")[0] for name in feature_names] + feature_name_by_sharding_types: List[List[str]] = [] + start = 0 + for length in self._feature_splits: + feature_name_by_sharding_types.append( + feature_names[start : start + length] + ) + start += length + features_by_shards = [ + td_to_kjt(features, names) + for names in feature_name_by_sharding_types + ] if self._use_index_dedup: features_by_shards = self._dedup_indices(ctx, features_by_shards) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index dbd8f1007..f74c597f5 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -148,6 +148,7 @@ def gen_model_and_input( long_indices: bool = True, global_constant_batch: bool = False, num_inputs: int = 1, + input_type: str = "kjt", # "kjt" or "td" ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]: torch.manual_seed(0) if dedup_feature_names: @@ -178,9 +179,9 @@ def gen_model_and_input( feature_processor_modules=feature_processor_modules, ) inputs = [] - for _ in range(num_inputs): - inputs.append( - ( + if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input: + for _ in range(num_inputs): + inputs.append( cast(VariableBatchModelInputCallable, generate)( average_batch_size=batch_size, world_size=world_size, @@ -189,8 +190,26 @@ def gen_model_and_input( weighted_tables=weighted_tables or [], global_constant_batch=global_constant_batch, ) - if generate == ModelInput.generate_variable_batch_input - else cast(ModelInputCallable, generate)( + ) + elif generate == ModelInput.generate: + for _ in range(num_inputs): + inputs.append( + ModelInput.generate( + world_size=world_size, + tables=tables, + dedup_tables=dedup_tables, + weighted_tables=weighted_tables or [], + num_float_features=num_float_features, + variable_batch_size=variable_batch_size, + batch_size=batch_size, + long_indices=long_indices, + input_type=input_type, + ) + ) + else: + for _ in range(num_inputs): + inputs.append( + cast(ModelInputCallable, generate)( world_size=world_size, tables=tables, dedup_tables=dedup_tables, @@ -201,7 +220,6 @@ 
def gen_model_and_input( long_indices=long_indices, ) ) - ) return (model, inputs) @@ -287,6 +305,7 @@ def sharding_single_rank_test( global_constant_batch: bool = False, world_size_2D: Optional[int] = None, node_group_size: Optional[int] = None, + input_type: str = "kjt", # "kjt" or "td" ) -> None: with MultiProcessContext(rank, world_size, backend, local_size) as ctx: @@ -310,6 +329,7 @@ def sharding_single_rank_test( batch_size=batch_size, feature_processor_modules=feature_processor_modules, global_constant_batch=global_constant_batch, + input_type=input_type, ) global_model = global_model.to(ctx.device) global_input = inputs[0][0].to(ctx.device) diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index aec092354..18c89c6d8 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -376,3 +376,315 @@ def _test_sharding( variable_batch_per_feature=variable_batch_per_feature, global_constant_batch=True, ) + + +@skip_if_asan_class +class TDSequenceModelParallelTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.just(ShardingType.ROW_WISE.value), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + ] + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) + def test_sharding_nccl_rw( + self, + sharding_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + ) -> None: + assume( + apply_optimizer_in_backward_config is None + or kernel_type != EmbeddingComputeKernel.DENSE.value + ) + self._test_sharding( + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + qcomms_config=qcomms_config, + ) + ], + backend="nccl", + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.just(ShardingType.DATA_PARALLEL.value), + kernel_type=st.just(EmbeddingComputeKernel.DENSE.value), + apply_optimizer_in_backward_config=st.just(None), + # TODO - need to enable optimizer overlapped behavior for data_parallel tables + # apply_optimizer_in_backward_config=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_nccl_dp( + self, + sharding_type: str, + kernel_type: str, + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + ) -> None: + self._test_sharding( + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + ) + ], + backend="nccl", + 
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.just(ShardingType.TABLE_WISE.value), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + ] + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) + def test_sharding_nccl_tw( + self, + sharding_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + ) -> None: + assume( + apply_optimizer_in_backward_config is None + or kernel_type != EmbeddingComputeKernel.DENSE.value + ) + self._test_sharding( + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + qcomms_config=qcomms_config, + ) + ], + backend="nccl", + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.just(ShardingType.COLUMN_WISE.value), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) + def test_sharding_nccl_cw( + self, + sharding_type: str, + kernel_type: str, + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + ) -> None: + assume( + apply_optimizer_in_backward_config is None + or kernel_type != EmbeddingComputeKernel.DENSE.value + ) + self._test_sharding( + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + ) + ], + backend="nccl", + constraints={ + table.name: ParameterConstraints(min_partition=8) + for table in self.tables + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + # pyre-fixme[56] + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_sharding_empty_rank(self) -> None: + table = self.tables[0] + embedding_groups = {"group_0": table.feature_names} + self._run_multi_process_test( + callable=sharding_single_rank_test, + world_size=2, + model_class=TestSequenceSparseNN, + tables=[table], + embedding_groups=embedding_groups, + sharders=[ + TestEmbeddingCollectionSharder( + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + ) + ], + optim=EmbOptimType.EXACT_SGD, + 
backend="nccl", + variable_batch_size=True, + ) + + @seed_and_log + def setUp(self) -> None: + super().setUp() + + num_features = 4 + shared_features = 2 + + initial_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + + shared_features_tables = [ + EmbeddingConfig( + num_embeddings=(i + 1) * 11, + embedding_dim=16, + name="table_" + str(i + num_features), + feature_names=["feature_" + str(i)], + ) + for i in range(shared_features) + ] + + self.tables = initial_tables + shared_features_tables + self.shared_features = [f"feature_{i}" for i in range(shared_features)] + + self.embedding_groups = { + "group_0": [ + ( + f"{feature}@{table.name}" + if feature in self.shared_features + else feature + ) + for table in self.tables + for feature in table.feature_names + ] + } + + def _test_sharding( + self, + sharders: List[TestEmbeddingCollectionSharder], + backend: str = "gloo", + world_size: int = 2, + local_size: Optional[int] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + model_class: Type[TestSparseNNBase] = TestSequenceSparseNN, + qcomms_config: Optional[QCommsConfig] = None, + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ] = None, + variable_batch_size: bool = False, + variable_batch_per_feature: bool = False, + ) -> None: + self._run_multi_process_test( + callable=sharding_single_rank_test, + world_size=world_size, + local_size=local_size, + model_class=model_class, + tables=self.tables, + embedding_groups=self.embedding_groups, + sharders=sharders, + optim=EmbOptimType.EXACT_SGD, + backend=backend, + constraints=constraints, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + variable_batch_per_feature=variable_batch_per_feature, + global_constant_batch=True, + input_type='td', + ) diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 979360fa8..a35b7a9cf 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -464,7 +464,7 @@ def __init__( # noqa C901 def forward( self, - features: KeyedJaggedTensor, + features: Union[KeyedJaggedTensor, TensorDict], ) -> Dict[str, JaggedTensor]: """ Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor` @@ -478,7 +478,10 @@ def forward( """ feature_embeddings: Dict[str, JaggedTensor] = {} - jt_dict: Dict[str, JaggedTensor] = features.to_dict() + if isinstance(features, KeyedJaggedTensor): + jt_dict: Dict[str, JaggedTensor] = features.to_dict() + else: + jt_dict = features for i, emb_module in enumerate(self.embeddings.values()): feature_names = self._feature_names[i] embedding_names = self._embedding_names_by_table[i] @@ -491,6 +494,7 @@ def forward( feature_embeddings[embedding_name] = JaggedTensor( values=lookup, lengths=f.lengths(), + offsets=f.offsets() if isinstance(features, TensorDict) else None, weights=f.values() if self._need_indices else None, ) return feature_embeddings