Skip to content

Commit

Permalink
feat: support tensor parallel using Pytorch 2.0
Browse files Browse the repository at this point in the history
Signed-off-by: Mehant Kammakomati <[email protected]>
  • Loading branch information
kmehant committed Nov 4, 2024
1 parent bf4572b commit c096d40
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 3 deletions.
9 changes: 9 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
ProjectConfiguration,
RNGType,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
apply_fp8_autowrap,
check_os_kernel,
clean_state_dict_for_safetensors,
Expand Down Expand Up @@ -188,6 +189,9 @@ class Accelerator:
fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
Tweak your torch tensor parallel. This argument is optional and can be configured directly using
*accelerate config*
megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
Expand Down Expand Up @@ -254,6 +258,7 @@ def __init__(
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
Expand Down Expand Up @@ -418,6 +423,7 @@ def __init__(
dynamo_plugin=dynamo_plugin,
deepspeed_plugin=deepspeed_plugins,
fsdp_plugin=fsdp_plugin,
torch_tp_plugin=torch_tp_plugin,
megatron_lm_plugin=megatron_lm_plugin,
_from_accelerator=True,
**kwargs,
Expand Down Expand Up @@ -1461,6 +1467,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
Expand Down Expand Up @@ -2117,6 +2125,7 @@ def prepare_data_loader(
data_seed=self.dataloader_config.data_seed,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
Expand Down
37 changes: 35 additions & 2 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def __init__(
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
torch_device_mesh=None,
**kwargs,
):
shuffle = False
Expand All @@ -732,15 +733,37 @@ def __init__(
self._drop_last = _drop_last
self._non_blocking = _non_blocking
self.skip_batches = skip_batches
self.torch_device_mesh = torch_device_mesh

self.slice_fn = slice_tensors if slice_fn is None else slice_fn
self.iteration = 0

# if a device mesh is provided extract each dimension (tp and dp)
# device mesh will be used only if there is tp involved
# otherwise the default behavour should be sufficient
self.submesh_tp = None
self.submesh_dp = None
if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
# extract torch sub device mesh objects
self.submesh_tp = self.torch_device_mesh["tp"]
if "dp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_dp = self.torch_device_mesh["dp"]
if self.submesh_tp and self.submesh_dp:
raise ValueError("TP + DDP / TP + FSDP is not yet supported")

def _fetch_batches(self, iterator):
batches, batch = None, None
# On process 0, we gather the batch to dispatch.
if self.state.process_index == 0:
# Procedure to support TP only is simpler
# since we want to dispatch the same batch of samples across all ranks
# this removes complexity of handling multiple tp rank groups when TP + DP
# combination is involved.

try:
# for TP case avoid using split_batches
# since it would mean that the dataloader should be spilling out
# duplicates of batches.
if self.split_batches:
# One batch of the main iterator is dispatched and split.
self._update_state_dict()
Expand All @@ -749,9 +772,15 @@ def _fetch_batches(self, iterator):
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
if self.submesh_tp:
# when tp, extract single batch and then replicate
self._update_state_dict()
batches.append(next(iterator))
batch = next(iterator)
batches = [batch] * self.state.num_processes
else:
for _ in range(self.state.num_processes):
self._update_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
except RuntimeError as e:
Expand Down Expand Up @@ -942,6 +971,7 @@ def prepare_data_loader(
data_seed: Optional[int] = None,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
torch_device_mesh: torch.distributed.DeviceMesh = None,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
Expand Down Expand Up @@ -1009,6 +1039,8 @@ def prepare_data_loader(
"If set to true, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
PyTorch device mesh.
Returns:
Expand Down Expand Up @@ -1144,6 +1176,7 @@ def prepare_data_loader(
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
torch_device_mesh=torch_device_mesh,
**kwargs,
)
elif sampler_is_batch_sampler:
Expand Down
4 changes: 4 additions & 0 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,7 @@ def __init__(
dynamo_plugin=None,
deepspeed_plugin=None,
fsdp_plugin=None,
torch_tp_plugin=None,
megatron_lm_plugin=None,
_from_accelerator: bool = False,
**kwargs,
Expand All @@ -864,6 +865,7 @@ def __init__(
if not self.initialized:
self.deepspeed_plugins = None
self.use_ipex = None
self.torch_tp_plugin = torch_tp_plugin
mixed_precision = (
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
if mixed_precision is None
Expand Down Expand Up @@ -921,6 +923,8 @@ def __init__(
self.distributed_type = DistributedType.MEGATRON_LM
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
self.megatron_lm_plugin = megatron_lm_plugin
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():
# check if user disables it explicitly
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
SageMakerDistributedType,
TensorInformation,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
add_model_config_to_megatron_parser,
)
from .environment import (
Expand Down
3 changes: 2 additions & 1 deletion src/accelerate/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

Expand Down Expand Up @@ -76,7 +77,7 @@
"master_port",
]

CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM"]
CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
"MULTI_NPU",
"MULTI_MLU",
Expand Down
28 changes: 28 additions & 0 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import torch

from .constants import (
BETA_TP_AVAILABLE_PYTORCH_VERSION,
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
FSDP_SHARDING_STRATEGY,
Expand Down Expand Up @@ -540,6 +541,7 @@ class DistributedType(str, enum.Enum):
MULTI_XPU = "MULTI_XPU"
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"
TP = "TP"
XLA = "XLA"
MEGATRON_LM = "MEGATRON_LM"

Expand Down Expand Up @@ -1810,6 +1812,32 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F
self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)


@dataclass
class TorchTensorParallelPlugin:
"""
This plugin is used to enable tensor parallelism using PyTorch >= 2.0.
"""

tp_size: int = field(
default=1,
metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
)

# type has to be "torch.distributed.DeviceMesh"
torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)

def __post_init__(self):
if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(
f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
)
from torch.distributed.device_mesh import init_device_mesh

mesh_dim_name = "tp"
device = "cuda" # support for other devices has to be investigated
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))


@dataclass
class MegatronLMPlugin:
"""
Expand Down

0 comments on commit c096d40

Please sign in to comment.