From d70e0dd2af52ad591ca3f62f371a67719f58c545 Mon Sep 17 00:00:00 2001
From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com>
Date: Tue, 17 Sep 2024 14:49:49 -0400
Subject: [PATCH] Expose scattermoe (#13)

Signed-off-by: Mayank Mishra
---
 dolomite_engine/arguments.py                  | 5 ++++-
 dolomite_engine/enums.py                      | 9 +++++++++
 dolomite_engine/model_wrapper/__init__.py     | 1 +
 dolomite_engine/model_wrapper/base.py         | 6 +++++-
 dolomite_engine/model_wrapper/distillation.py | 3 ++-
 dolomite_engine/model_wrapper/pretraining.py  | 4 +++-
 6 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/dolomite_engine/arguments.py b/dolomite_engine/arguments.py
index 6728b48a..b8614f30 100644
--- a/dolomite_engine/arguments.py
+++ b/dolomite_engine/arguments.py
@@ -19,6 +19,7 @@
     LossMask,
     LRDecaySchedule,
     Mode,
+    MoEImplementation,
     ParamsGroupMethod,
     TuningMethod,
 )
@@ -51,8 +52,10 @@ class ModelArgs(BaseArgs):
     model_class: str = None
     # trust remote code for models that are not directly supported by HuggingFace yet
     trust_remote_code: bool = False
-    # attention implementation (only works with GPTDolomiteForCausalLM)
+    # attention implementation
     attention_implementation: AttentionImplementation | None = None
+    # moe implementation (only works with MoEDolomiteForCausalLM)
+    moe_implementation: MoEImplementation | None = None
     # whether to use padding free transformer: https://huggingface.co/blog/mayank-mishra/padding-free-transformer
     use_padding_free_transformer: bool = False
     # use lower memory to initialize model
diff --git a/dolomite_engine/enums.py b/dolomite_engine/enums.py
index c29be751..a85a260b 100644
--- a/dolomite_engine/enums.py
+++ b/dolomite_engine/enums.py
@@ -27,6 +27,15 @@ class AttentionImplementation(Enum):
     flash_attention_2 = "flash_attention_2"
 
 
+class MoEImplementation(Enum):
+    """
+    Enum class for MoE implementation
+    """
+
+    eager = "eager"
+    scattermoe = "scattermoe"
+
+
 class DatasetSplit(str, Enum):
     """dataset split"""
 
diff --git a/dolomite_engine/model_wrapper/__init__.py b/dolomite_engine/model_wrapper/__init__.py
index fba20e79..122408f9 100644
--- a/dolomite_engine/model_wrapper/__init__.py
+++ b/dolomite_engine/model_wrapper/__init__.py
@@ -30,6 +30,7 @@ def get_model(args: TrainingArgs | InferenceArgs | UnshardingArgs | Distillation
         "dtype": args.mixed_precision_args.dtype,
         "efficient_initialization": args.model_args.efficient_initialization,
         "attention_implementation": args.model_args.attention_implementation,
+        "moe_implementation": args.model_args.moe_implementation,
         "use_padding_free_transformer": args.model_args.use_padding_free_transformer,
         "tensor_parallel_word_embeddings": args.distributed_args.tensor_parallel_word_embeddings,
         "sequence_parallel": args.distributed_args.sequence_parallel,
diff --git a/dolomite_engine/model_wrapper/base.py b/dolomite_engine/model_wrapper/base.py
index e9868d77..9e4bf834 100644
--- a/dolomite_engine/model_wrapper/base.py
+++ b/dolomite_engine/model_wrapper/base.py
@@ -5,7 +5,7 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
 from transformers.integrations import HfDeepSpeedConfig
 
-from ..enums import AttentionImplementation, DistributedBackend, Mode
+from ..enums import AttentionImplementation, DistributedBackend, Mode, MoEImplementation
 from ..hf_models import get_tensor_parallel_class, is_custom_model, is_tensor_parallel_compatible_model
 from ..utils import ProcessGroupManager, SafeTensorsWeightsManager, log_rank_0, string_to_torch_dtype
 
@@ -22,6 +22,7 @@ def __init__(
         dtype: torch.dtype,
         efficient_initialization: bool,
         attention_implementation: AttentionImplementation,
+        moe_implementation: MoEImplementation,
         use_padding_free_transformer: bool,
         tensor_parallel_word_embeddings: bool,
         sequence_parallel: bool,
@@ -60,6 +61,7 @@ def __init__(
         self.efficient_initialization = efficient_initialization
         self.dtype = dtype
         self.attention_implementation = attention_implementation
+        self.moe_implementation = moe_implementation
         self.use_padding_free_transformer = use_padding_free_transformer
         self.tensor_parallel_word_embeddings = tensor_parallel_word_embeddings
         self.sequence_parallel = sequence_parallel
@@ -175,6 +177,8 @@ def _setup_model(self) -> None:
 
         if self.attention_implementation is not None:
             model_kwargs["attn_implementation"] = self.attention_implementation.value
+        if self.moe_implementation is not None:
+            model_kwargs["moe_implementation"] = self.moe_implementation.value
         if self.use_padding_free_transformer:
             model_kwargs["use_padding_free_transformer"] = True
         if self.tensor_parallel_word_embeddings:
diff --git a/dolomite_engine/model_wrapper/distillation.py b/dolomite_engine/model_wrapper/distillation.py
index d344e4d3..f1066e16 100644
--- a/dolomite_engine/model_wrapper/distillation.py
+++ b/dolomite_engine/model_wrapper/distillation.py
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM
 
-from ..enums import AttentionImplementation, DistributedBackend, KLDivergenceMethod, Mode
+from ..enums import AttentionImplementation, DistributedBackend, KLDivergenceMethod, Mode, MoEImplementation
 from ..utils import log_rank_0, string_to_torch_dtype
 from .pretraining import ModelWrapperForPretraining
 
@@ -20,6 +20,7 @@ def __init__(
         dtype: torch.dtype,
         efficient_initialization: bool,
         attention_implementation: AttentionImplementation,
+        moe_implementation: MoEImplementation,
         use_padding_free_transformer: bool,
         tensor_parallel_word_embeddings: bool,
         sequence_parallel: bool,
diff --git a/dolomite_engine/model_wrapper/pretraining.py b/dolomite_engine/model_wrapper/pretraining.py
index 203173bd..b010df27 100644
--- a/dolomite_engine/model_wrapper/pretraining.py
+++ b/dolomite_engine/model_wrapper/pretraining.py
@@ -7,7 +7,7 @@
 from torch.distributed.tensor.parallel import loss_parallel
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
 
-from ..enums import AttentionImplementation, DistributedBackend, Mode
+from ..enums import AttentionImplementation, DistributedBackend, Mode, MoEImplementation
 from ..hf_models.modeling_utils_TP import tensor_to_dtensor
 from ..utils import ProcessGroupManager
 from .base import ModelWrapper
@@ -23,6 +23,7 @@ def __init__(
         dtype: torch.dtype,
         efficient_initialization: bool,
         attention_implementation: AttentionImplementation,
+        moe_implementation: MoEImplementation,
         use_padding_free_transformer: bool,
         tensor_parallel_word_embeddings: bool,
         sequence_parallel: bool,
@@ -73,6 +74,7 @@ def __init__(
             dtype=dtype,
             efficient_initialization=efficient_initialization,
             attention_implementation=attention_implementation,
+            moe_implementation=moe_implementation,
             use_padding_free_transformer=use_padding_free_transformer,
             tensor_parallel_word_embeddings=tensor_parallel_word_embeddings,
             sequence_parallel=sequence_parallel,
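
For reviewers, a minimal sketch of how the new option is intended to flow through the code, based only on the hunks above. The `ModelArgs` construction is illustrative: the checkpoint name and the `model_name` field are assumptions (in practice these fields are populated from the training config), while the kwargs-forwarding step mirrors the branch added to `ModelWrapper._setup_model`.

```python
from dolomite_engine.arguments import ModelArgs
from dolomite_engine.enums import MoEImplementation

# Hypothetical, hand-built args for illustration only; normally ModelArgs is
# filled in from the training config, and the model name is a placeholder.
model_args = ModelArgs(
    model_name="some-org/some-moe-checkpoint",
    model_class="AutoModelForCausalLM",
    # new field added by this patch; per the comment in arguments.py it only
    # works with MoEDolomiteForCausalLM
    moe_implementation=MoEImplementation.scattermoe,
)

# Mirrors the branch added to ModelWrapper._setup_model: the enum's value is
# forwarded to the model constructor only when the option is set.
model_kwargs = {}
if model_args.moe_implementation is not None:
    model_kwargs["moe_implementation"] = model_args.moe_implementation.value  # -> "scattermoe"
```

Leaving `moe_implementation` unset keeps the previous behavior, matching how `attention_implementation` is handled.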