From aa7d87241f150a8a324d84c3f2132ec1c4857323 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Tue, 15 Oct 2024 10:15:22 +0530
Subject: [PATCH 1/2] feat: add support for tensor parallel flow using accelerate

Signed-off-by: Mehant Kammakomati
---
 src/transformers/trainer.py       |  8 +++++++-
 src/transformers/training_args.py | 16 +++++++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f45ff46bdd8..404bd2ce5d1 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -233,6 +233,7 @@
         AutocastKwargs,
         DistributedDataParallelKwargs,
         DistributedType,
+        TorchTensorParallelPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
         save_fsdp_model,
@@ -5076,6 +5077,11 @@ def create_accelerator_and_postprocess(self):
             args["dataloader_config"] = dataloader_config
         else:
             args.update(accelerator_config)
+        # tp is initialized at Accelerator init phase so
+        # args should be prepared here
+        if self.args.tp_size > 1:
+            self.is_tp_enabled = True
+            args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
 
         # create accelerator object
         self.accelerator = Accelerator(**args)
@@ -5090,7 +5096,7 @@
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
+        self.is_tp_enabled = getattr(self.accelerator.state, "tp_plugin", None) is not None
         # post accelerator creation setup
         if self.is_fsdp_enabled:
             fsdp_plugin = self.accelerator.state.fsdp_plugin
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 00b9c82ec28..abbf763b993 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -566,7 +566,9 @@ class TrainingArguments:
                Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
                used when the xla flag is set to true, and an auto wrapping policy is specified through
                fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
-
+        tp_size (`int`, *optional*):
+            Use tp_size to enable pytorch 2.0 tensor parallelism. Set a value greater than 1 to activate TP. The same is
+            used to prepare device mesh internally.
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
             evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -1240,6 +1242,16 @@ class TrainingArguments:
             )
         },
     )
+    tp_size: Optional[int] = field(
+        default=0,
+        metadata={
+            "help": (
+                "Use tp_size to enable pytorch 2.0 tensor parallelism. "
+                "Set a value greater than 1 to activate TP. "
+                "The same is used to prepare device mesh internally."
+            )
+        },
+    )
     fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
         default=None,
         metadata={
@@ -1957,6 +1969,8 @@ def __post_init__(self):
             if self.fsdp_config["xla_fsdp_grad_ckpt"]:
                 warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
 
+        if self.tp_size > 1:
+            os.environ["ACCELERATE_USE_TP"] = "true"
         # accelerate integration for FSDP
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
             os.environ["ACCELERATE_USE_FSDP"] = "true"

From e79f4d17767162e9a14b155c7d45f6fece509eed Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Fri, 13 Dec 2024 20:41:56 +0530
Subject: [PATCH 2/2] fix: add tp degree to env variable

Signed-off-by: Mehant Kammakomati
---
 src/transformers/training_args.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index abbf763b993..79cf7f9c94d 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -567,7 +567,7 @@ class TrainingArguments:
                used when the xla flag is set to true, and an auto wrapping policy is specified through
                fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
         tp_size (`int`, *optional*):
-            Use tp_size to enable pytorch 2.0 tensor parallelism. Set a value greater than 1 to activate TP. The same is
+            Use tp_size to enable PyTorch tensor parallelism. Set a value greater than 1 to activate TP. The same is
             used to prepare device mesh internally.
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
@@ -1246,7 +1246,7 @@ class TrainingArguments:
         default=0,
         metadata={
             "help": (
-                "Use tp_size to enable pytorch 2.0 tensor parallelism. "
+                "Use tp_size to enable pytorch tensor parallelism. "
                 "Set a value greater than 1 to activate TP. "
                 "The same is used to prepare device mesh internally."
             )
@@ -1971,6 +1971,7 @@ def __post_init__(self):
 
         if self.tp_size > 1:
             os.environ["ACCELERATE_USE_TP"] = "true"
+            os.environ["TP_SIZE"] = str(self.tp_size)
         # accelerate integration for FSDP
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
             os.environ["ACCELERATE_USE_FSDP"] = "true"
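
For reviewers, a minimal usage sketch of what these two patches enable, assuming the patched transformers plus an accelerate build that ships TorchTensorParallelPlugin. The model name, dataset, output directory, and script name below are placeholders, not part of the patch, and whether a given model can actually be sharded still depends on accelerate's TP support for that architecture.

    # Hypothetical sketch: enable 2-way tensor parallelism through the new
    # `tp_size` training argument added by this series. Values > 1 turn TP on
    # and are also exported as ACCELERATE_USE_TP / TP_SIZE in __post_init__.
    from datasets import load_dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    )

    model_name = "gpt2"  # placeholder decoder-only model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Tiny slice of a public dataset, just to keep the sketch self-contained.
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
    dataset = dataset.map(
        lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
        batched=True,
        remove_columns=["text"],
    )

    args = TrainingArguments(
        output_dir="tp_out",  # placeholder output path
        per_device_train_batch_size=1,
        max_steps=10,
        tp_size=2,  # new flag from this series: > 1 enables TP and sizes the device mesh
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()

Run the script across the tensor-parallel ranks, for example with `torchrun --nproc-per-node 2 train_tp.py` or `accelerate launch`, so that two processes are available to form the TP device mesh that TorchTensorParallelPlugin prepares from `tp_size`.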