From c5a70770228452848d0e67baffad0199000ade41 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Wed, 8 Jan 2025 12:42:53 -0800
Subject: [PATCH] Optimized backward pass for ROCm devices (pt 2) (#3511)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3511

X-link: https://github.com/facebookresearch/FBGEMM/pull/594

- Break up D66310520 (https://github.com/pytorch/FBGEMM/pull/3367) into
  backend and frontend diffs. This is the frontend diff, and a follow-up
  to D66986498.

Reviewed By: leitian

Differential Revision: D67407935

fbshipit-source-id: 18e862d647e962456827fe9d0b9c22d715ebd2ca
---
 .../backward/embedding_backward_dense_host_cpu.cpp  |  6 +++---
 .../embedding_backward_split_host_template.cpp      |  4 ++--
 .../codegen/training/python/lookup_args.template    |  1 +
 .../split_embedding_codegen_lookup_invoker.template |  1 +
 .../split_table_batched_embeddings_ops_training.py  | 11 +++++++++++
 fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py           |  1 +
 6 files changed, 19 insertions(+), 5 deletions(-)
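Note: "mixed_D" flags a TBE lookup whose embedding tables do not all
share the same embedding dimension. This patch flips the backend-side
default of the flag from True to False and instead has the Python
frontend compute it and pass it down explicitly. A minimal sketch of
the check the frontend introduces below (assuming a non-empty dims
list; is_mixed_dim is a hypothetical helper name, not part of the
patch):

    def is_mixed_dim(dims: list) -> bool:
        # "Mixed" means at least one table dim differs from the first's.
        return any(d != dims[0] for d in dims)

    assert is_mixed_dim([64, 64, 128])     # one table differs -> mixed
    assert not is_mixed_dim([64, 64, 64])  # all dims equal -> not mixed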
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
index 626838e930..3c18a2b9bf 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
@@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
     c10::SymInt /* max_B = -1 */,
     c10::SymInt /* max_B_feature_rank = -1 */,
     c10::SymInt /* vbe_output_size = -1 */,
-    bool /* mixed_D = true */) {
+    bool /* mixed_D = false */) {
   return SplitLookupFunction_Dense_Op::apply(
       host_weights,
       weights_offsets,
@@ -191,7 +191,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
 // Deprecated for fb namespace! Please use fbgemm namespace instead!
 TORCH_LIBRARY_FRAGMENT(fb, m) {
   m.def(
-      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor");
+      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
   DISPATCH_TO_CPU(
       "dense_embedding_codegen_lookup_function",
       split_embedding_codegen_lookup_dense_function);
@@ -199,7 +199,7 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
-      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor");
+      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
   DISPATCH_TO_CPU(
       "dense_embedding_codegen_lookup_function",
       split_embedding_codegen_lookup_dense_function);
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
index 6efccefb8a..a6ccbd7ed1 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
@@ -1083,7 +1083,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
     {%- else %}
     const c10::SymInt vbe_output_size = -1,
     {%- endif %}
-    const bool mixed_D = true
+    const bool mixed_D = false
 ) {
   // TODO: refactor into macro
   {%- if has_gpu_support %}
@@ -1200,7 +1200,7 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
           " Tensor[]? ssd_tensors=None,"
           {%- endif %}
           " float gwd_lower_bound=0, "
-          " bool mixed_D=True"
+          " bool mixed_D=False"
           ") -> Tensor",
           {PT2_COMPLIANT_TAG});
diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.template b/fbgemm_gpu/codegen/training/python/lookup_args.template
index 357aad622a..f3fd7aa87a 100644
--- a/fbgemm_gpu/codegen/training/python/lookup_args.template
+++ b/fbgemm_gpu/codegen/training/python/lookup_args.template
@@ -49,6 +49,7 @@ class CommonArgs(NamedTuple):
     {%- if ssd %}
     ssd_tensors: Dict[str, torch.Tensor]
     {%- endif %}
+    mixed_D: bool
 
 
 class OptimizerArgs(NamedTuple):
diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
index c69837291b..b55b850c5d 100644
--- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
+++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
@@ -409,5 +409,6 @@ def invoke(
         use_homogeneous_placements=common_args.use_homogeneous_placements,
         apply_global_weight_decay=apply_global_weight_decay,
         gwd_lower_bound=gwd_lower_bound,
+        mixed_D=common_args.mixed_D,
     )
 {%- endif %}
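Note: the two template changes above plumb the flag through the Python
glue code: CommonArgs (lookup_args.template) gains a mixed_D field, and
the generated invoker forwards it to the lookup op. A hedged, trimmed
sketch of that flow (the NamedTuple is reduced to the two members
relevant here; the real CommonArgs carries many more):

    from typing import NamedTuple

    class CommonArgs(NamedTuple):
        use_homogeneous_placements: bool
        mixed_D: bool  # new field introduced by this patch

    def invoke(common_args: CommonArgs) -> dict:
        # Mirrors the template change: the invoker now passes
        # mixed_D=common_args.mixed_D through to the lookup function.
        return dict(
            use_homogeneous_placements=common_args.use_homogeneous_placements,
            mixed_D=common_args.mixed_D,
        )

    print(invoke(CommonArgs(use_homogeneous_placements=True, mixed_D=False)))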
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 6f13e5acd6..8f8d5779ea 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -744,6 +744,7 @@ def __init__(  # noqa C901
                 not mixed_D
             ), "OptimType.NONE does not support mixed embedding dimension"
 
+        self.mixed_D: bool = mixed_D
         if device is None:
             self.current_device: torch.device = (
                 torch.device("cpu")
@@ -1808,6 +1809,7 @@ def forward(  # noqa: C901
             is_experimental=self.is_experimental,
             use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
             use_homogeneous_placements=self.use_homogeneous_placements,
+            mixed_D=self.mixed_D,
         )
 
         if self.optimizer == OptimType.NONE:
@@ -3583,6 +3585,14 @@ def __init__(
         )
         assert self.D_offsets.numel() == T + 1
 
+        mixed_D = False
+        D = dims[0]
+        for d in dims:
+            if d != D:
+                mixed_D = True
+                break
+        self.mixed_D: bool = mixed_D
+
         # Required for VBE
         self.register_buffer(
             "feature_dims",
@@ -3696,6 +3706,7 @@ def forward(
             max_B=vbe_metadata.max_B,
             max_B_feature_rank=vbe_metadata.max_B_feature_rank,
             vbe_output_size=vbe_metadata.output_size,
+            mixed_D=self.mixed_D,
         )
 
     @torch.jit.export
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 8e167404a8..36dae3e11c 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -1664,6 +1664,7 @@ def forward(
             },
             # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
             vbe_metadata=vbe_metadata,
+            mixed_D=False,
         )
 
         self.timesteps_prefetched.pop(0)
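Note: the training-module frontends record a mixed_D derived from their
table dims (the dense module now computes it in __init__, as above),
while the SSD path hard-codes mixed_D=False. A hedged usage sketch of
the dense module after this patch (the constructor call is abbreviated
to the one relevant argument, assuming the standard (rows, dim) form of
embedding_specs and an fbgemm_gpu install):

    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        DenseTableBatchedEmbeddingBagsCodegen,
    )

    # (rows, dim) per table; 64 vs 128 differ, so the new __init__ logic
    # sets self.mixed_D = True and forwards it on every forward() call.
    tbe = DenseTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[(1000, 64), (1000, 128)],
    )
    assert tbe.mixed_D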