Skip to content

Commit

Permalink
feat: add support for tensor parallel flow using accelerate
Browse files Browse the repository at this point in the history
Signed-off-by: Mehant Kammakomati <[email protected]>
  • Loading branch information
kmehant committed Jan 23, 2025
1 parent d3af76d commit aa7d872
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@
AutocastKwargs,
DistributedDataParallelKwargs,
DistributedType,
TorchTensorParallelPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
Expand Down Expand Up @@ -5076,6 +5077,11 @@ def create_accelerator_and_postprocess(self):
args["dataloader_config"] = dataloader_config
else:
args.update(accelerator_config)
# tp is initialized at Accelerator init phase so
# args should be prepared here
if self.args.tp_size > 1:
self.is_tp_enabled = True
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)

# create accelerator object
self.accelerator = Accelerator(**args)
Expand All @@ -5090,7 +5096,7 @@ def create_accelerator_and_postprocess(self):
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

self.is_tp_enabled = getattr(self.accelerator.state, "tp_plugin", None) is not None
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
Expand Down
16 changes: 15 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
tp_size (`int`, *optional*):
Use tp_size to enable pytorch 2.0 tensor parallelism. Set a value greater than 1 to activate TP. The same is
used to prepare device mesh internally.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
Expand Down Expand Up @@ -1240,6 +1242,16 @@ class TrainingArguments:
)
},
)
tp_size: Optional[int] = field(
default=0,
metadata={
"help": (
"Use tp_size to enable pytorch 2.0 tensor parallelism."
"Set a value greater than 1 to activate TP."
"The same is used to prepare device mesh internally."
)
},
)
fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
default=None,
metadata={
Expand Down Expand Up @@ -1957,6 +1969,8 @@ def __post_init__(self):
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

if self.tp_size > 1:
os.environ["ACCELERATE_USE_TP"] = "true"
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"
Expand Down

0 comments on commit aa7d872

Please sign in to comment.