Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support tensor parallel & Data loader #3173

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
ProjectConfiguration,
RNGType,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
apply_fp8_autowrap,
check_os_kernel,
clean_state_dict_for_safetensors,
Expand Down Expand Up @@ -107,7 +108,7 @@
save_fsdp_optimizer,
wait_for_everyone,
)
from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
from .utils.constants import BETA_TP_AVAILABLE_PYTORCH_VERSION, FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
from .utils.modeling import get_state_dict_offloaded_model
from .utils.other import is_compiled_module

Expand Down Expand Up @@ -188,6 +189,9 @@ class Accelerator:
fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
Tweak your torch tensor parallel. This argument is optional and can be configured directly using
*accelerate config*
megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
Expand Down Expand Up @@ -254,6 +258,7 @@ def __init__(
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
Expand Down Expand Up @@ -344,6 +349,12 @@ def __init__(
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
torch_tp_plugin, TorchTensorParallelPlugin
):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")

if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
Expand All @@ -353,6 +364,15 @@ def __init__(
raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided

if torch_tp_plugin is None:
torch_tp_plugin = (
TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
)
else:
if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
os.environ["ACCELERATE_USE_TP"] = "true"

if megatron_lm_plugin is None: # init from env variables
megatron_lm_plugin = (
MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
Expand Down Expand Up @@ -418,6 +438,7 @@ def __init__(
dynamo_plugin=dynamo_plugin,
deepspeed_plugin=deepspeed_plugins,
fsdp_plugin=fsdp_plugin,
torch_tp_plugin=torch_tp_plugin,
megatron_lm_plugin=megatron_lm_plugin,
_from_accelerator=True,
**kwargs,
Expand Down Expand Up @@ -1461,6 +1482,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if not model.supports_tp_plan:
raise NotImplementedError("Provided model does not support tensor parallelism")
model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
Expand Down Expand Up @@ -2117,6 +2142,7 @@ def prepare_data_loader(
data_seed=self.dataloader_config.data_seed,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
Expand Down
18 changes: 17 additions & 1 deletion src/accelerate/commands/config/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def get_cluster_input():
)

fsdp_config = {}
tp_config = {}
if distributed_type in [
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
Expand Down Expand Up @@ -475,7 +476,21 @@ def get_cluster_input():
default=False,
error_message="Please enter yes or no.",
)

if not use_fsdp:
use_tp = _ask_field(
"Do you want to use TensorParallel? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)
if use_tp:
distributed_type = DistributedType.TP
if distributed_type == DistributedType.TP:
tp_config["tp_size"] = _ask_field(
"What should be your Tensor Parallel degree? [1]: ",
int,
default=1,
)
megatron_lm_config = {}
if distributed_type in [DistributedType.MULTI_GPU]:
use_megatron_lm = _ask_field(
Expand Down Expand Up @@ -808,6 +823,7 @@ def get_cluster_input():
fp8_config=fp8_config,
deepspeed_config=deepspeed_config,
fsdp_config=fsdp_config,
tp_config=tp_config,
megatron_lm_config=megatron_lm_config,
ipex_config=ipex_config,
mpirun_config=mpirun_config,
Expand Down
4 changes: 4 additions & 0 deletions src/accelerate/commands/config/config_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ class ClusterConfig(BaseConfig):
deepspeed_config: dict = None
# args for fsdp
fsdp_config: dict = None
# args for tp
tp_config: dict = None
# args for megatron_lm
megatron_lm_config: dict = None
# args for ipex
Expand Down Expand Up @@ -221,6 +223,8 @@ def __post_init__(self):
self.deepspeed_config = {}
if self.fsdp_config is None:
self.fsdp_config = {}
if self.tp_config is None:
self.tp_config = {}
if self.megatron_lm_config is None:
self.megatron_lm_config = {}
if self.ipex_config is None:
Expand Down
26 changes: 24 additions & 2 deletions src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"tpu": "TPU",
"use_deepspeed": "DeepSpeed Arguments",
"use_fsdp": "FSDP Arguments",
"use_tp": "PyTorch TP Arguments",
"use_megatron_lm": "Megatron-LM Arguments",
"fp8_backend": "FP8 Arguments",
}
Expand Down Expand Up @@ -262,6 +263,12 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether to use fsdp.",
)
paradigm_args.add_argument(
"--use_tp",
default=False,
action="store_true",
help="Whether to use PyTorch TP.",
)
paradigm_args.add_argument(
"--use_megatron_lm",
default=False,
Expand Down Expand Up @@ -589,6 +596,15 @@ def launch_command_parser(subparsers=None):
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
)

# tp args
tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyToch.")
tp_args.add_argument(
"--tp_size",
default=1,
type=int,
help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
)

# megatron_lm args
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
megatron_lm_args.add_argument(
Expand Down Expand Up @@ -965,9 +981,9 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):

def _validate_launch_command(args):
# Sanity checks
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp, args.use_tp]) > 1:
raise ValueError(
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`, `--use_tp` at a time."
)
if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
Expand All @@ -984,6 +1000,7 @@ def _validate_launch_command(args):
and not args.tpu_use_cluster
and not args.use_deepspeed
and not args.use_fsdp
and not args.use_tp
and not args.use_megatron_lm
):
args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
Expand All @@ -1001,6 +1018,7 @@ def _validate_launch_command(args):
)
args.tpu = defaults.distributed_type == DistributedType.XLA
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
args.use_tp = defaults.distributed_type == DistributedType.TP
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
if args.gpu_ids is None:
Expand Down Expand Up @@ -1028,6 +1046,8 @@ def _validate_launch_command(args):
if "fsdp" not in arg_to_set:
arg_to_set = "fsdp_" + arg_to_set
setattr(args, arg_to_set, defaults.fsdp_config[k])
for k in defaults.tp_config:
setattr(args, k, defaults.tp_config[k])
for k in defaults.megatron_lm_config:
setattr(args, k, defaults.megatron_lm_config[k])
for k in defaults.dynamo_config:
Expand Down Expand Up @@ -1153,6 +1173,8 @@ def launch_command(args):
deepspeed_launcher(args)
elif args.use_fsdp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_tp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_megatron_lm and not args.cpu:
multi_gpu_launcher(args)
elif args.multi_gpu and not args.cpu:
Expand Down
75 changes: 73 additions & 2 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def __init__(
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
torch_device_mesh=None,
**kwargs,
):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
Expand Down Expand Up @@ -713,6 +714,7 @@ def __init__(
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
torch_device_mesh=None,
**kwargs,
):
shuffle = False
Expand All @@ -732,26 +734,68 @@ def __init__(
self._drop_last = _drop_last
self._non_blocking = _non_blocking
self.skip_batches = skip_batches
self.torch_device_mesh = torch_device_mesh

self.slice_fn = slice_tensors if slice_fn is None else slice_fn
self.iteration = 0

# if a device mesh is provided extract each dimension (dp, fsdp, tp)
# device mesh may hold any number of dimensions, however,
# below code is for targetted support for dp, fsdp and tp

# device mesh will be used only if there is tp involved
# or any multi-dimensional parallelism involving tp
# (dp, tp) (fsdp, tp) (dp, fsdp, tp)
# otherwise the default behavour not using device mesh should be sufficient
# since multi dimensional parallelism devoid of tp would anyway need
# different batches for each process irrespective of dp or fsdp
self.submesh_tp = None
self.submesh_dp = None
self.submesh_fsdp = None
if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_tp = self.torch_device_mesh["tp"]
if "dp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_dp = self.torch_device_mesh["dp"]
if "fsdp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_fsdp = self.torch_device_mesh["fsdp"]
if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")

def _fetch_batches(self, iterator):
batches, batch = None, None
# On process 0, we gather the batch to dispatch.
if self.state.process_index == 0:
# Procedure to support TP only is simpler
# since we want to dispatch the same batch of samples across all ranks
# this removes complexity of handling multiple tp rank groups when TP + DP
# combination is involved.

try:
# for TP case avoid using split_batches
# since it would mean that the dataloader should be spilling out
# duplicates of batches.
if self.split_batches:
# One batch of the main iterator is dispatched and split.
if self.submesh_tp:
logger.warning(
"Use of split_batches for TP would need the dataloader to produce duplicate batches,"
"otherwise, use dispatch_batches=True instead."
)
self._update_state_dict()
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
if self.submesh_tp:
# when tp, extract single batch and then replicate
self._update_state_dict()
batches.append(next(iterator))
batch = next(iterator)
batches = [batch] * self.state.num_processes
else:
for _ in range(self.state.num_processes):
self._update_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
except RuntimeError as e:
Expand Down Expand Up @@ -942,6 +986,7 @@ def prepare_data_loader(
data_seed: Optional[int] = None,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
torch_device_mesh: torch.distributed.DeviceMesh = None,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
Expand Down Expand Up @@ -1009,6 +1054,8 @@ def prepare_data_loader(
"If set to true, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
PyTorch device mesh.


Returns:
Expand All @@ -1033,8 +1080,31 @@ def prepare_data_loader(
state = PartialState()
if num_processes is None:
num_processes = state.num_processes

# when device mesh is used, specifically with TP
# then there is need to update process_index and num_processes
# to bring in the effect of generating same batch across TP ranks
# and different batch across FSDP and DP ranks.
# Example:
# if device mesh is (dp,fsdp,tp) = (2, 2, 3)
# ranks would range from 0...11
# from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
# processes with same ranks/ids would receive the same batch
submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
if torch_device_mesh:
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
num_processes = submesh_fsdp_size * submesh_dp_size
if process_index is None:
process_index = state.process_index
if torch_device_mesh:
process_index = process_index // submesh_tp_size

# Sanity check
if split_batches:
Expand Down Expand Up @@ -1144,6 +1214,7 @@ def prepare_data_loader(
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
torch_device_mesh=torch_device_mesh,
**kwargs,
)
elif sampler_is_batch_sampler:
Expand Down
4 changes: 4 additions & 0 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,7 @@ def __init__(
dynamo_plugin=None,
deepspeed_plugin=None,
fsdp_plugin=None,
torch_tp_plugin=None,
megatron_lm_plugin=None,
_from_accelerator: bool = False,
**kwargs,
Expand All @@ -864,6 +865,7 @@ def __init__(
if not self.initialized:
self.deepspeed_plugins = None
self.use_ipex = None
self.torch_tp_plugin = torch_tp_plugin
mixed_precision = (
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
if mixed_precision is None
Expand Down Expand Up @@ -921,6 +923,8 @@ def __init__(
self.distributed_type = DistributedType.MEGATRON_LM
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
self.megatron_lm_plugin = megatron_lm_plugin
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():
# check if user disables it explicitly
Expand Down
Loading
Loading