
Expose scattermoe (#13)
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 authored Sep 17, 2024
1 parent d0cb042 commit d70e0dd
Showing 6 changed files with 24 additions and 4 deletions.
dolomite_engine/arguments.py (5 changes: 4 additions & 1 deletion)
@@ -19,6 +19,7 @@
     LossMask,
     LRDecaySchedule,
     Mode,
+    MoEImplementation,
     ParamsGroupMethod,
     TuningMethod,
 )
@@ -51,8 +52,10 @@ class ModelArgs(BaseArgs):
     model_class: str = None
     # trust remote code for models that are not directly supported by HuggingFace yet
     trust_remote_code: bool = False
-    # attention implementation (only works with GPTDolomiteForCausalLM)
+    # attention implementation
     attention_implementation: AttentionImplementation | None = None
+    # moe implementation (only works with MoEDolomiteForCausalLM)
+    moe_implementation: MoEImplementation | None = None
     # whether to use padding free transformer: https://huggingface.co/blog/mayank-mishra/padding-free-transformer
     use_padding_free_transformer: bool = False
     # use lower memory to initialize model
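For orientation, a minimal sketch of how the new field could be set when building ModelArgs directly in Python. Only the field and enum names come from the diff above; treating ModelArgs as keyword-constructible and omitting its other fields are assumptions:

from dolomite_engine.arguments import ModelArgs
from dolomite_engine.enums import AttentionImplementation, MoEImplementation

# hypothetical construction; any other required ModelArgs fields are omitted for brevity
model_args = ModelArgs(
    attention_implementation=AttentionImplementation.flash_attention_2,
    moe_implementation=MoEImplementation.scattermoe,  # route MoE layers through scattermoe
)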
dolomite_engine/enums.py (9 changes: 9 additions & 0 deletions)
@@ -27,6 +27,15 @@ class AttentionImplementation(Enum):
     flash_attention_2 = "flash_attention_2"
 
 
+class MoEImplementation(Enum):
+    """
+    Enum class for MoE implementation
+    """
+
+    eager = "eager"
+    scattermoe = "scattermoe"
+
+
 class DatasetSplit(str, Enum):
     """dataset split"""
 
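As a quick sanity check, the new enum round-trips between the config-facing string and the value later handed to the model; this is standard Enum behavior, and the class is redeclared here only so the snippet runs standalone:

from enum import Enum


class MoEImplementation(Enum):
    """Enum class for MoE implementation"""

    eager = "eager"
    scattermoe = "scattermoe"


# a string from a config or CLI maps onto the enum member ...
impl = MoEImplementation("scattermoe")
assert impl is MoEImplementation.scattermoe
# ... and .value is the plain string that ends up in model_kwargs
assert impl.value == "scattermoe"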
dolomite_engine/model_wrapper/__init__.py (1 change: 1 addition & 0 deletions)
@@ -30,6 +30,7 @@ def get_model(args: TrainingArgs | InferenceArgs | UnshardingArgs | Distillation
         "dtype": args.mixed_precision_args.dtype,
         "efficient_initialization": args.model_args.efficient_initialization,
         "attention_implementation": args.model_args.attention_implementation,
+        "moe_implementation": args.model_args.moe_implementation,
         "use_padding_free_transformer": args.model_args.use_padding_free_transformer,
         "tensor_parallel_word_embeddings": args.distributed_args.tensor_parallel_word_embeddings,
         "sequence_parallel": args.distributed_args.sequence_parallel,
dolomite_engine/model_wrapper/base.py (6 changes: 5 additions & 1 deletion)
@@ -5,7 +5,7 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
 from transformers.integrations import HfDeepSpeedConfig
 
-from ..enums import AttentionImplementation, DistributedBackend, Mode
+from ..enums import AttentionImplementation, DistributedBackend, Mode, MoEImplementation
 from ..hf_models import get_tensor_parallel_class, is_custom_model, is_tensor_parallel_compatible_model
 from ..utils import ProcessGroupManager, SafeTensorsWeightsManager, log_rank_0, string_to_torch_dtype
 
@@ -22,6 +22,7 @@ def __init__(
         dtype: torch.dtype,
         efficient_initialization: bool,
         attention_implementation: AttentionImplementation,
+        moe_implementation: MoEImplementation,
         use_padding_free_transformer: bool,
         tensor_parallel_word_embeddings: bool,
         sequence_parallel: bool,
@@ -60,6 +61,7 @@ def __init__(
         self.efficient_initialization = efficient_initialization
         self.dtype = dtype
         self.attention_implementation = attention_implementation
+        self.moe_implementation = moe_implementation
         self.use_padding_free_transformer = use_padding_free_transformer
         self.tensor_parallel_word_embeddings = tensor_parallel_word_embeddings
         self.sequence_parallel = sequence_parallel
@@ -175,6 +177,8 @@ def _setup_model(self) -> None:
 
         if self.attention_implementation is not None:
             model_kwargs["attn_implementation"] = self.attention_implementation.value
+        if self.moe_implementation is not None:
+            model_kwargs["moe_implementation"] = self.moe_implementation.value
         if self.use_padding_free_transformer:
             model_kwargs["use_padding_free_transformer"] = True
         if self.tensor_parallel_word_embeddings:
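Taken together, the base-wrapper hunks follow a conditional-kwargs pattern: an optional enum stored on the wrapper becomes a plain string kwarg only when it was actually set. A standalone sketch of that pattern follows; the hypothetical build_model_kwargs helper is illustrative, while the real logic lives inside _setup_model:

from dolomite_engine.enums import AttentionImplementation, MoEImplementation


def build_model_kwargs(
    attention_implementation: AttentionImplementation | None = None,
    moe_implementation: MoEImplementation | None = None,
    use_padding_free_transformer: bool = False,
) -> dict:
    # mirror the diff above: forward a kwarg only when it was set,
    # and pass the enum's string value rather than the enum itself
    model_kwargs = {}
    if attention_implementation is not None:
        model_kwargs["attn_implementation"] = attention_implementation.value
    if moe_implementation is not None:
        model_kwargs["moe_implementation"] = moe_implementation.value
    if use_padding_free_transformer:
        model_kwargs["use_padding_free_transformer"] = True
    return model_kwargs


# build_model_kwargs(moe_implementation=MoEImplementation.scattermoe)
# -> {"moe_implementation": "scattermoe"}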
dolomite_engine/model_wrapper/distillation.py (3 changes: 2 additions & 1 deletion)
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM
 
-from ..enums import AttentionImplementation, DistributedBackend, KLDivergenceMethod, Mode
+from ..enums import AttentionImplementation, DistributedBackend, KLDivergenceMethod, Mode, MoEImplementation
 from ..utils import log_rank_0, string_to_torch_dtype
 from .pretraining import ModelWrapperForPretraining
 
@@ -20,6 +20,7 @@ def __init__(
         dtype: torch.dtype,
         efficient_initialization: bool,
         attention_implementation: AttentionImplementation,
+        moe_implementation: MoEImplementation,
         use_padding_free_transformer: bool,
         tensor_parallel_word_embeddings: bool,
         sequence_parallel: bool,
dolomite_engine/model_wrapper/pretraining.py (4 changes: 3 additions & 1 deletion)
@@ -7,7 +7,7 @@
 from torch.distributed.tensor.parallel import loss_parallel
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
 
-from ..enums import AttentionImplementation, DistributedBackend, Mode
+from ..enums import AttentionImplementation, DistributedBackend, Mode, MoEImplementation
 from ..hf_models.modeling_utils_TP import tensor_to_dtensor
 from ..utils import ProcessGroupManager
 from .base import ModelWrapper
@@ -23,6 +23,7 @@ def __init__(
         dtype: torch.dtype,
         efficient_initialization: bool,
         attention_implementation: AttentionImplementation,
+        moe_implementation: MoEImplementation,
         use_padding_free_transformer: bool,
         tensor_parallel_word_embeddings: bool,
         sequence_parallel: bool,
@@ -73,6 +74,7 @@ def __init__(
             dtype=dtype,
             efficient_initialization=efficient_initialization,
             attention_implementation=attention_implementation,
+            moe_implementation=moe_implementation,
             use_padding_free_transformer=use_padding_free_transformer,
             tensor_parallel_word_embeddings=tensor_parallel_word_embeddings,
             sequence_parallel=sequence_parallel,
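The remaining hunks in distillation.py and pretraining.py are plumbing: each wrapper subclass accepts the new moe_implementation argument and threads it through to the base wrapper unchanged. A stripped-down sketch of that pass-through; the class bodies are stand-ins and all other constructor arguments are elided:

from dolomite_engine.enums import MoEImplementation


class ModelWrapper:  # stand-in for dolomite_engine.model_wrapper.base.ModelWrapper
    def __init__(self, moe_implementation: MoEImplementation | None = None, **kwargs) -> None:
        self.moe_implementation = moe_implementation


class ModelWrapperForPretraining(ModelWrapper):  # stand-in for the pretraining wrapper
    def __init__(self, moe_implementation: MoEImplementation | None = None, **kwargs) -> None:
        # the new argument is simply forwarded to the base wrapper's constructor
        super().__init__(moe_implementation=moe_implementation, **kwargs)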
