feat: add support for CLI usage
Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant committed Dec 13, 2024
1 parent 0991429 commit 780ae7b
Showing 5 changed files with 60 additions and 4 deletions.
13 changes: 12 additions & 1 deletion src/accelerate/accelerator.py
@@ -108,7 +108,7 @@
save_fsdp_optimizer,
wait_for_everyone,
)
from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME, BETA_TP_AVAILABLE_PYTORCH_VERSION
from .utils.modeling import get_state_dict_offloaded_model
from .utils.other import is_compiled_module

@@ -349,6 +349,10 @@ def __init__(
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")

if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
@@ -358,6 +362,13 @@
raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided

if torch_tp_plugin is None:
torch_tp_plugin = (
TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
)
else:
if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
os.environ["ACCELERATE_USE_TP"] = "true"

if megatron_lm_plugin is None: # init from env variables
megatron_lm_plugin = (
MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
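The guard and plugin wiring above mirror the existing FSDP path: TP is turned on either through the `ACCELERATE_USE_TP` environment variable or by passing a `TorchTensorParallelPlugin` explicitly, in which case `__init__` sets the env var for you. A minimal sketch of the two entry points, meant to be run under a distributed launcher with enough GPUs, and assuming `TorchTensorParallelPlugin` is importable from `accelerate.utils` and accepts a `tp_size` argument (neither detail appears in this diff):

from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin  # assumed import path

# Option 1: export ACCELERATE_USE_TP=true (what the launcher does further below) and
# call Accelerator() with no plugin; __init__ then builds a default TorchTensorParallelPlugin.

# Option 2: pass the plugin explicitly; __init__ sets ACCELERATE_USE_TP itself.
tp_plugin = TorchTensorParallelPlugin(tp_size=2)  # tp_size is an assumed argument
accelerator = Accelerator(torch_tp_plugin=tp_plugin)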
18 changes: 17 additions & 1 deletion src/accelerate/commands/config/cluster.py
@@ -376,6 +376,7 @@ def get_cluster_input():
)

fsdp_config = {}
tp_config = {}
if distributed_type in [
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
@@ -475,7 +476,21 @@
default=False,
error_message="Please enter yes or no.",
)

if not use_fsdp:
use_tp = _ask_field(
"Do you want to use TensorParallel? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)
if use_tp:
distributed_type = DistributedType.TP
if distributed_type == DistributedType.TP:
tp_config["tp_size"] = _ask_field(
"What should be your Tensor Parallel degree? [1e8]: ",
int,
default=1,
)
megatron_lm_config = {}
if distributed_type in [DistributedType.MULTI_GPU]:
use_megatron_lm = _ask_field(
@@ -808,6 +823,7 @@ def get_cluster_input():
fp8_config=fp8_config,
deepspeed_config=deepspeed_config,
fsdp_config=fsdp_config,
tp_config=tp_config,
megatron_lm_config=megatron_lm_config,
ipex_config=ipex_config,
mpirun_config=mpirun_config,
4 changes: 4 additions & 0 deletions src/accelerate/commands/config/config_args.py
@@ -194,6 +194,8 @@ class ClusterConfig(BaseConfig):
deepspeed_config: dict = None
# args for fsdp
fsdp_config: dict = None
# args for tp
tp_config: dict = None
# args for megatron_lm
megatron_lm_config: dict = None
# args for ipex
@@ -221,6 +223,8 @@ def __post_init__(self):
self.deepspeed_config = {}
if self.fsdp_config is None:
self.fsdp_config = {}
if self.tp_config is None:
self.tp_config = {}
if self.megatron_lm_config is None:
self.megatron_lm_config = {}
if self.ipex_config is None:
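Like `fsdp_config`, the new `tp_config` field defaults to `None` and is normalized to an empty dict in `__post_init__`, so the `tp_size` answer collected by `accelerate config` round-trips through the saved config file unchanged. A stand-alone sketch of that normalization (the dataclass here is a simplified stand-in, not the real `ClusterConfig`):

from dataclasses import dataclass


@dataclass
class ConfigSketch:  # simplified stand-in for ClusterConfig
    tp_config: dict = None

    def __post_init__(self):
        if self.tp_config is None:
            self.tp_config = {}


print(ConfigSketch(tp_config={"tp_size": 2}).tp_config)  # {'tp_size': 2}
print(ConfigSketch().tp_config)  # {} -- an unset field becomes an empty dict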
26 changes: 24 additions & 2 deletions src/accelerate/commands/launch.py
@@ -74,6 +74,7 @@
"tpu": "TPU",
"use_deepspeed": "DeepSpeed Arguments",
"use_fsdp": "FSDP Arguments",
"use_tp": "PyTorch TP Arguments",
"use_megatron_lm": "Megatron-LM Arguments",
"fp8_backend": "FP8 Arguments",
}
@@ -262,6 +263,12 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether to use fsdp.",
)
paradigm_args.add_argument(
"--use_tp",
default=False,
action="store_true",
help="Whether to use PyTorch TP.",
)
paradigm_args.add_argument(
"--use_megatron_lm",
default=False,
@@ -588,6 +595,15 @@ def launch_command_parser(subparsers=None):
type=str,
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
)

# tp args
tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyTorch.")
tp_args.add_argument(
"--tp_size",
default=1,
type=int,
help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
)

# megatron_lm args
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
@@ -965,9 +981,9 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):

def _validate_launch_command(args):
# Sanity checks
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp, args.use_tp]) > 1:
raise ValueError(
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`, `--use_tp` at a time."
)
if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
@@ -984,6 +1000,7 @@ def _validate_launch_command(args):
and not args.tpu_use_cluster
and not args.use_deepspeed
and not args.use_fsdp
and not args.use_tp
and not args.use_megatron_lm
):
args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
@@ -1001,6 +1018,7 @@
)
args.tpu = defaults.distributed_type == DistributedType.XLA
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
args.use_tp = defaults.distributed_type == DistributedType.TP
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
if args.gpu_ids is None:
@@ -1028,6 +1046,8 @@ def _validate_launch_command(args):
if "fsdp" not in arg_to_set:
arg_to_set = "fsdp_" + arg_to_set
setattr(args, arg_to_set, defaults.fsdp_config[k])
for k in defaults.tp_config:
setattr(args, k, defaults.tp_config[k])
for k in defaults.megatron_lm_config:
setattr(args, k, defaults.megatron_lm_config[k])
for k in defaults.dynamo_config:
@@ -1153,6 +1173,8 @@ def launch_command(args):
deepspeed_launcher(args)
elif args.use_fsdp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_tp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_megatron_lm and not args.cpu:
multi_gpu_launcher(args)
elif args.multi_gpu and not args.cpu:
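With these changes, a command such as `accelerate launch --use_tp --tp_size 2 train.py` is validated alongside the other paradigms and then routed to `multi_gpu_launcher`, just like FSDP and Megatron-LM. A condensed sketch of that dispatch (branch order simplified; the TPU, SageMaker, and CPU-only paths are omitted, and the returned names are only illustrative labels):

import argparse


def pick_launcher(args: argparse.Namespace) -> str:
    # Condensed mirror of the bottom of launch_command(); not the real function.
    if args.use_deepspeed and not args.cpu:
        return "deepspeed_launcher"
    if (args.use_fsdp or args.use_tp or args.use_megatron_lm or args.multi_gpu) and not args.cpu:
        return "multi_gpu_launcher"
    return "simple_launcher"


example_args = argparse.Namespace(
    use_deepspeed=False, use_fsdp=False, use_tp=True, use_megatron_lm=False, multi_gpu=False, cpu=False
)
print(pick_launcher(example_args))  # multi_gpu_launcher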
3 changes: 3 additions & 0 deletions src/accelerate/utils/launch.py
@@ -284,6 +284,9 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()

if args.use_tp:
current_env["ACCELERATE_USE_TP"] = "true"

if args.use_megatron_lm:
prefix = "MEGATRON_LM_"
current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"
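End to end, `--use_tp` therefore travels as an environment variable: `prepare_multi_gpu_env` exports `ACCELERATE_USE_TP=true` into the launched processes, which is exactly what the new check in `Accelerator.__init__` reads before building the TP plugin. A tiny sketch of that hand-off (the namespace stands in for the parsed launcher args):

import argparse

args = argparse.Namespace(use_tp=True, tp_size=2)  # stand-in for parsed CLI flags

current_env = {}
if args.use_tp:
    # Mirrors the hunk above: the flag becomes an env var that
    # Accelerator.__init__ later reads via os.environ.get("ACCELERATE_USE_TP").
    current_env["ACCELERATE_USE_TP"] = "true"

print(current_env)  # {'ACCELERATE_USE_TP': 'true'}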
