Broadcast DP tables when init
Summary:
# Problem

The problem is that once we wrap DP tables with DDP, the parameters start out the same but are not kept in sync. So if we reset the parameters in the table without using the same manual seed on every rank, the DP table parameters can end up different across ranks.

In other words, the DP tables would be initialized differently on each rank.
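
A minimal single-process sketch of that failure mode (the table shape and init bounds are made up for illustration): when each rank runs the same `init_fn` from a different RNG state, the resulting DP replicas disagree.

```python
import torch


def init_fn(param: torch.Tensor) -> None:
    torch.nn.init.uniform_(param, a=-0.01, b=0.01)


# Simulate two ranks that each reset a DP table without a shared seed.
rank0_weight = torch.empty(4, 8)
rank1_weight = torch.empty(4, 8)

torch.manual_seed(0)  # rank 0's RNG state
init_fn(rank0_weight)

torch.manual_seed(1)  # rank 1's RNG state
init_fn(rank1_weight)

print(torch.allclose(rank0_weight, rank1_weight))  # False: the "replicas" already diverge
```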

# Fix

There are a few ways to fix this. We believe this one is the least invasive and best follows the spirit of the API.

What we do:
1. Broadcast the DP tables of rank 0 to all other ranks (see the sketch below).
2. This only happens during init, or when reset_parameters is called.
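
Below is a standalone sketch of the broadcast pattern (the process group setup, world size, and table shape are illustrative and not the TorchRec code path): every rank initializes its own copy, then rank 0's weights overwrite the others.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank initializes its own copy of a DP table (seeds intentionally differ).
    weight = torch.empty(4, 8)
    torch.manual_seed(rank)
    torch.nn.init.uniform_(weight, a=-0.01, b=0.01)

    # The fix: broadcast rank 0's parameters so every replica starts identical.
    with torch.no_grad():
        dist.broadcast(weight, src=0)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```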

Reviewed By: joshuadeng

Differential Revision: D55227979

fbshipit-source-id: 52c225501df86cf35b196afff48dc2fa88e7a2c3
henrylhtsang authored and facebook-github-bot committed Apr 10, 2024
1 parent 9b1f395 commit cc482f8
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 7 additions & 1 deletion torchrec/distributed/embedding.py
@@ -17,7 +17,7 @@
 from typing import Any, cast, Dict, List, MutableMapping, Optional, Type, Union
 
 import torch
-from torch import nn
+from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
@@ -610,6 +610,12 @@ def reset_parameters(self) -> None:
             # pyre-ignore
             table_config.init_fn(param)
 
+            sharding_type = self.module_sharding_plan[table_config.name].sharding_type
+            if sharding_type == ShardingType.DATA_PARALLEL.value:
+                pg = self._env.process_group
+                with torch.no_grad():
+                    dist.broadcast(param.data, src=0, group=pg)
+
     def _generate_permute_indices_per_feature(
         self,
         embedding_configs: List[EmbeddingConfig],
8 changes: 7 additions & 1 deletion torchrec/distributed/embeddingbag.py
@@ -28,7 +28,7 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
-from torch import nn, Tensor
+from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
@@ -728,6 +728,12 @@ def reset_parameters(self) -> None:
             # pyre-ignore
             table_config.init_fn(param)
 
+            sharding_type = self.module_sharding_plan[table_config.name].sharding_type
+            if sharding_type == ShardingType.DATA_PARALLEL.value:
+                pg = self._env.process_group
+                with torch.no_grad():
+                    dist.broadcast(param.data, src=0, group=pg)
+
     def _create_input_dist(
         self,
         input_feature_names: List[str],
