enable FSDP (pytorch#64)
Summary:
Pull Request resolved: pytorch/torchrec#64

Pull Request resolved: pytorch/torchrec#15

As title: let DistributedModelParallel and the train pipeline unwrap FullyShardedDataParallel the same way they already unwrap DistributedDataParallel, so FSDP can be used for the data-parallel part of the model.

Reviewed By: dstaay-fb, bigrabithong, liangluofb

Differential Revision: D33712372

fbshipit-source-id: bbcd7b9adb0580f30afc765dde72df2755b09ed7
xing-liu authored and facebook-github-bot committed Feb 12, 2022
1 parent 5961d65 commit 6ef3827
Showing 2 changed files with 24 additions and 3 deletions.
20 changes: 19 additions & 1 deletion torchrec/distributed/model_parallel.py
@@ -12,6 +12,7 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embeddingbag import (
@@ -97,6 +98,14 @@ def wrap(
)


def _strip_DDP(module: nn.Module) -> nn.Module:
if isinstance(module, FullyShardedDataParallel) or isinstance(
module, DistributedDataParallel
):
module = module.module
return module


class DistributedModelParallel(nn.Module, FusedOptimizerModule):
"""
Entry point to model parallelism.
@@ -205,6 +214,7 @@ def dmp_module(self) -> nn.Module:
return (
self.module.module
if isinstance(self.module, DistributedDataParallel)
or isinstance(self.module, FullyShardedDataParallel)
else self.module
)

@@ -217,7 +227,9 @@ def init_data_parallel(self) -> None:
See init_data_parallel c-tor argument for usage.
It's safe to call this method multiple times.
"""
if not isinstance(self.module, DistributedDataParallel):
if not isinstance(self.module, DistributedDataParallel) and not isinstance(
self.module, FullyShardedDataParallel
):
# Allocate any 'meta' tensors
if self.init_parameters:
self._init_parameters(self.module)
@@ -313,6 +325,7 @@ def sparse_grad_parameter_names(
def _sparse_grad_parameter_names(
self, module: nn.Module, destination: List[str], prefix: str = ""
) -> List[str]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
module.sparse_grad_parameter_names(destination, prefix)
elif isinstance(module, nn.Embedding):
@@ -348,6 +361,7 @@ def _state_dict(
prefix: str,
keep_vars: bool,
) -> Dict[str, Any]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
module.state_dict(destination, prefix, keep_vars)
else:
@@ -373,6 +387,7 @@ def _load_state_dict(
) -> _IncompatibleKeys:
missing_keys = []
unexpected_keys = []
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
return module.load_state_dict(state_dict, strict=strict)
else:
@@ -395,6 +410,7 @@ def _load_state_dict(
def _named_parameters(
self, module: nn.Module, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
yield from module.named_parameters(prefix, recurse)
else:
@@ -411,6 +427,7 @@ def named_parameters(

@staticmethod
def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
yield from module.sharded_parameter_names(prefix)
else:
@@ -422,6 +439,7 @@ def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
def _named_buffers(
self, module: nn.Module, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
module = _strip_DDP(module)
if isinstance(module, ShardedModule):
yield from module.named_buffers(prefix, recurse)
else:
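For context, the state_dict / load_state_dict / named_parameters / named_buffers helpers above all funnel through the new _strip_DDP so that a data-parallel wrapper (DDP or FSDP) no longer hides the sharded and dense submodules underneath it. The snippet below is an illustrative, standalone sketch of that unwrapping, not TorchRec code: strip_data_parallel and named_parameters_unwrapped are made-up names, and it assumes a single-process gloo process group so it stays CPU-only (it only instantiates DDP; FSDP exposes the same .module attribute but needs a heavier setup).

import os
from typing import Iterator, Tuple

import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel


def strip_data_parallel(module: nn.Module) -> nn.Module:
    # Same idea as _strip_DDP: DDP and FSDP both keep the wrapped module
    # on a `.module` attribute, so one isinstance check covers both.
    if isinstance(module, (DistributedDataParallel, FullyShardedDataParallel)):
        module = module.module
    return module


def named_parameters_unwrapped(
    module: nn.Module, prefix: str = ""
) -> Iterator[Tuple[str, nn.Parameter]]:
    # Strip the wrapper first so parameter names come out without the
    # extra "module." prefix the wrapper would otherwise add.
    yield from strip_data_parallel(module).named_parameters(prefix)


if __name__ == "__main__":
    # Single-process process group: enough to wrap a CPU module in DDP.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    wrapped = DistributedDataParallel(nn.Linear(4, 4))
    print([n for n, _ in wrapped.named_parameters()])           # ['module.weight', 'module.bias']
    print([n for n, _ in named_parameters_unwrapped(wrapped)])  # ['weight', 'bias']

    dist.destroy_process_group()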
7 changes: 5 additions & 2 deletions torchrec/distributed/train_pipeline.py
@@ -23,6 +23,7 @@

import torch
from torch.autograd.profiler import record_function
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.fx.node import Node
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
@@ -364,8 +365,10 @@ def _rewrite_model( # noqa C901
) -> List[ShardedModule]:

# Get underlying nn.Module
while isinstance(model, DistributedModelParallel) or isinstance(
model, DistributedDataParallel
while (
isinstance(model, DistributedModelParallel)
or isinstance(model, DistributedDataParallel)
or isinstance(model, FullyShardedDataParallel)
):
model = model.module

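The pipeline-side change is the same unwrap applied in a loop, because by the time _rewrite_model sees the model the wrappers can be nested: DistributedModelParallel whose inner module may itself be wrapped in DDP or FSDP. A hypothetical helper that mirrors that loop (unwrap_model is not a TorchRec API; it relies only on the .module attribute all three wrappers expose):

from torch import nn
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.model_parallel import DistributedModelParallel


def unwrap_model(model: nn.Module) -> nn.Module:
    # Peel wrappers until a plain nn.Module remains, matching the
    # while-loop in _rewrite_model above.
    while isinstance(
        model,
        (DistributedModelParallel, DistributedDataParallel, FullyShardedDataParallel),
    ):
        model = model.module
    return model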
