Skip to content

Commit

Permalink
Optimized backward pass for ROCm devices (pt 2) (pytorch#3511)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3511

X-link: facebookresearch/FBGEMM#594

- Break up D66310520 (pytorch#3367) into backend and frontend diffs. This is the frontend diff, and a follow-up to D66986498.

Reviewed By: leitian

Differential Revision: D67407935

fbshipit-source-id: 18e862d647e962456827fe9d0b9c22d715ebd2ca
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 8, 2025
1 parent 8b748c6 commit c5a7077
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
c10::SymInt /* max_B = -1 */,
c10::SymInt /* max_B_feature_rank = -1 */,
c10::SymInt /* vbe_output_size = -1 */,
bool /* mixed_D = true */) {
bool /* mixed_D = false */) {
return SplitLookupFunction_Dense_Op::apply(
host_weights,
weights_offsets,
Expand All @@ -191,15 +191,15 @@ Tensor split_embedding_codegen_lookup_dense_function(
// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def(
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor");
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
DISPATCH_TO_CPU(
"dense_embedding_codegen_lookup_function",
split_embedding_codegen_lookup_dense_function);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor");
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
DISPATCH_TO_CPU(
"dense_embedding_codegen_lookup_function",
split_embedding_codegen_lookup_dense_function);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
{%- else %}
const c10::SymInt vbe_output_size = -1,
{%- endif %}
const bool mixed_D = true
const bool mixed_D = false
) {
// TODO: refactor into macro
{%- if has_gpu_support %}
Expand Down Expand Up @@ -1200,7 +1200,7 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
" Tensor[]? ssd_tensors=None,"
{%- endif %}
" float gwd_lower_bound=0, "
" bool mixed_D=True"
" bool mixed_D=False"
") -> Tensor",
{PT2_COMPLIANT_TAG});

Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/training/python/lookup_args.template
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class CommonArgs(NamedTuple):
{%- if ssd %}
ssd_tensors: Dict[str, torch.Tensor]
{%- endif %}
mixed_D: bool


class OptimizerArgs(NamedTuple):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,5 +409,6 @@ def invoke(
use_homogeneous_placements=common_args.use_homogeneous_placements,
apply_global_weight_decay=apply_global_weight_decay,
gwd_lower_bound=gwd_lower_bound,
mixed_D=common_args.mixed_D,
)
{%- endif %}
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ def __init__( # noqa C901
not mixed_D
), "OptimType.NONE does not support mixed embedding dimension"

self.mixed_D: bool = mixed_D
if device is None:
self.current_device: torch.device = (
torch.device("cpu")
Expand Down Expand Up @@ -1808,6 +1809,7 @@ def forward( # noqa: C901
is_experimental=self.is_experimental,
use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd,
use_homogeneous_placements=self.use_homogeneous_placements,
mixed_D=self.mixed_D,
)

if self.optimizer == OptimType.NONE:
Expand Down Expand Up @@ -3583,6 +3585,14 @@ def __init__(
)
assert self.D_offsets.numel() == T + 1

mixed_D = False
D = dims[0]
for d in dims:
if d != D:
mixed_D = True
break
self.mixed_D: bool = mixed_D

# Required for VBE
self.register_buffer(
"feature_dims",
Expand Down Expand Up @@ -3696,6 +3706,7 @@ def forward(
max_B=vbe_metadata.max_B,
max_B_feature_rank=vbe_metadata.max_B_feature_rank,
vbe_output_size=vbe_metadata.output_size,
mixed_D=self.mixed_D,
)

@torch.jit.export
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,7 @@ def forward(
},
# pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
vbe_metadata=vbe_metadata,
mixed_D=False,
)

self.timesteps_prefetched.pop(0)
Expand Down

0 comments on commit c5a7077

Please sign in to comment.