Scattermoe TP and SP (#29)
Signed-off-by: Shawn Tan <[email protected]>
Signed-off-by: Mayank Mishra <[email protected]>
Co-authored-by: Mayank Mishra <[email protected]>
shawntan and mayank31398 authored Oct 4, 2024
1 parent 64000aa commit 843ecd4
Showing 37 changed files with 755 additions and 134 deletions.
5 changes: 4 additions & 1 deletion Makefile
@@ -11,7 +11,10 @@ install-dev:
cd ..

test:
pytest tests
RUN_SLOW=True pytest tests

test-fast:
RUN_SLOW=False pytest tests

update-precommit:
pre-commit autoupdate
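A side note on the new targets: a minimal sketch (assumed, not part of this commit) of how the RUN_SLOW switch exported by these Makefile targets could be honored on the pytest side, e.g. in a conftest.py. The "slow" marker and the hook below are assumptions about how the test suite consumes the variable.

# conftest.py -- hypothetical sketch; not code from this PR.
import os

import pytest


def pytest_collection_modifyitems(config, items):
    if os.environ.get("RUN_SLOW", "False").lower() in ("1", "true"):
        return  # `make test` exports RUN_SLOW=True, so nothing is skipped
    skip_slow = pytest.mark.skip(reason="slow test; use `make test` (RUN_SLOW=True) to run it")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)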
2 changes: 1 addition & 1 deletion dolomite_engine/checkpointing.py
@@ -28,7 +28,7 @@
from .arguments import InferenceArgs, TrainingArgs, UnshardingArgs
from .data import ResumableDataLoader
from .enums import DistributedBackend, Mode, TuningMethod
from .hf_models.models.gpt_dolomite_TP import fix_unsharded_state_dict
from .hf_models import fix_unsharded_state_dict
from .model_wrapper import ModelWrapper, get_model
from .optimization import get_scheduler
from .utils import ExperimentsTracker, ProcessGroupManager, load_yaml, log_rank_0, run_rank_n, string_to_torch_dtype
10 changes: 4 additions & 6 deletions dolomite_engine/hf_models/__init__.py
@@ -11,18 +11,16 @@
GPTDolomiteModel_TP,
MoEDolomiteConfig,
MoEDolomiteForCausalLM,
MoEDolomiteForCausalLM_TP,
MoEDolomiteModel,
MoEDolomiteModel_TP,
RNNDolomiteConfig,
RNNDolomiteForCausalLM,
RNNDolomiteModel,
convert_gpt_dolomite_to_gpt_crosslayer,
)
from .register_hf import (
get_tensor_parallel_class,
is_custom_model,
is_tensor_parallel_compatible_model,
register_model_classes,
)
from .register_hf import get_tensor_parallel_class, is_custom_model, register_model_classes
from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts
from .utils import convert_padding_free_lists_to_tensors


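For orientation: the new top-level fix_unsharded_state_dict imported from .unshard presumably dispatches on the model type to the per-model helpers renamed in this PR. A rough sketch of that dispatch follows; the signature and the model_type strings are assumptions, only the per-model function names come from this diff.

# Hypothetical sketch of the dispatch in dolomite_engine/hf_models/unshard.py.
from .models import (
    fix_gpt_dolomite_unsharded_state_dict,
    fix_moe_dolomite_unsharded_state_dict,
)


def fix_unsharded_state_dict(config, state_dict: dict, tensor_parallel_size: int, prefix: str = "") -> dict:
    if config.model_type == "gpt_dolomite":
        return fix_gpt_dolomite_unsharded_state_dict(config, state_dict, tensor_parallel_size, prefix=prefix)
    elif config.model_type == "moe_dolomite":
        return fix_moe_dolomite_unsharded_state_dict(config, state_dict, tensor_parallel_size, prefix=prefix)
    raise NotImplementedError(f"unsharding is not implemented for model_type `{config.model_type}`")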
3 changes: 2 additions & 1 deletion dolomite_engine/hf_models/mixins/__init__.py
@@ -1,3 +1,4 @@
from .dense import BaseModelMixin, CausalLMModelMixin, PreTrainedModelMixin
from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP
from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin
from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin, PreTrainedMoEModelMixin
from .moe_TP import BaseMoEModelMixin_TP, CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP
2 changes: 1 addition & 1 deletion dolomite_engine/hf_models/mixins/moe/__init__.py
@@ -1,2 +1,2 @@
from .base import BaseMoEModelMixin
from .base import BaseMoEModelMixin, PreTrainedMoEModelMixin
from .main import CausalLMMoEModelMixin
57 changes: 56 additions & 1 deletion dolomite_engine/hf_models/mixins/moe/base.py
@@ -1,18 +1,73 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
from transformers import DynamicCache
from transformers.modeling_outputs import MoeModelOutputWithPast

from ..dense import BaseModelMixin
from ...config import CommonConfig
from ...enums import AttentionHeadType, PositionEmbeddingType
from ...modeling_utils import ParameterizedEmbedding, get_normalization_function
from ..dense import BaseModelMixin, PreTrainedModelMixin


@dataclass
class MoeModelOutputWithPastAndAuxLoss(MoeModelOutputWithPast):
aux_loss: torch.Tensor | None = None


class PreTrainedMoEModelMixin(PreTrainedModelMixin):
def __init__(self, config: CommonConfig, *args, **kwargs) -> None:
self.moe_implementation = kwargs.get("moe_implementation", "eager")
assert self.moe_implementation in ["eager", "scattermoe"]

super().__init__(config, *args, **kwargs)


class BaseMoEModelMixin(BaseModelMixin):
def _init_model(self, config: CommonConfig, **kwargs) -> None:
self.attention_head_type = AttentionHeadType(config.attention_head_type)
self.embed_dim = config.n_embd
self.num_heads = config.n_head
self.m_emb = config.m_emb
self.initializer_range = config.initializer_range
self.mask_value = None

assert (
self.embed_dim % self.num_heads == 0
), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})"

self.head_dim = self.embed_dim // self.num_heads

self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range)

self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList(
[
self.layer_class(
config,
normalization_implementation=self.normalization_implementation,
attention_implementation=self.attention_implementation,
use_padding_free_transformer=self._use_padding_free_transformer,
moe_implementation=self.moe_implementation,
layer_idx=i,
)
for i in range(config.n_layer)
]
)
self.ln_f = get_normalization_function(
config.normalization_function,
self.embed_dim,
eps=config.layer_norm_epsilon,
normalization_implementation=self.normalization_implementation,
)

self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
self._setup_positional_encoding()

# Initialize weights and apply final processing
self.post_init()

def forward(
self,
input_ids: torch.Tensor | None = None,
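A usage sketch for the new moe_implementation switch validated in PreTrainedMoEModelMixin above. Illustrative only: the config field values and passing the kwarg directly to the constructor are assumptions, mirroring how kwargs already flow into __init__.

# Illustrative only -- not code from this PR.
from dolomite_engine.hf_models import MoEDolomiteConfig, MoEDolomiteForCausalLM

config = MoEDolomiteConfig(n_layer=2, n_embd=256, n_head=4)
# "eager" is the default; "scattermoe" selects the fused scattermoe kernels.
model = MoEDolomiteForCausalLM(config, moe_implementation="scattermoe")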
8 changes: 8 additions & 0 deletions dolomite_engine/hf_models/mixins/moe/main.py
@@ -1,11 +1,19 @@
import torch
from transformers.modeling_outputs import MoeCausalLMOutputWithPast

from ...config import CommonConfig
from ..dense import CausalLMModelMixin
from .base import MoeModelOutputWithPastAndAuxLoss


class CausalLMMoEModelMixin(CausalLMModelMixin):
def __init__(self, config: CommonConfig, **kwargs) -> None:
super().__init__(config, **kwargs)

self.router_aux_loss_coef = config.router_aux_loss_coef
self.num_experts = config.num_experts
self.num_experts_per_tok = config.num_experts_per_tok

def forward(
self,
input_ids: torch.Tensor | list[list[int]] | None = None,
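For context, the conventional way the fields stored here are combined during training: a generic Switch/Mixtral-style MoE objective, not necessarily the exact code in this mixin's forward.

# Generic MoE loss combination; names are illustrative.
import torch


def combine_moe_losses(lm_loss: torch.Tensor, aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
    # Cross-entropy LM loss plus a scaled load-balancing penalty from the routers.
    return lm_loss + router_aux_loss_coef * aux_loss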
2 changes: 2 additions & 0 deletions dolomite_engine/hf_models/mixins/moe_TP/__init__.py
@@ -0,0 +1,2 @@
from .base import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP
from .main import CausalLMMoEModelMixin_TP
75 changes: 75 additions & 0 deletions dolomite_engine/hf_models/mixins/moe_TP/base.py
@@ -0,0 +1,75 @@
import torch.nn as nn

from ....utils import ProcessGroupManager
from ...config import CommonConfig
from ...enums import AttentionHeadType, PositionEmbeddingType
from ...modeling_utils_TP import Dropout_TP, Embedding_TP, get_normalization_function_TP
from ..dense_TP import BaseModelMixin_TP, PreTrainedModelMixin_TP
from ..moe import BaseMoEModelMixin, PreTrainedMoEModelMixin


class PreTrainedMoEModelMixin_TP(PreTrainedMoEModelMixin, PreTrainedModelMixin_TP):
def __init__(self, config: CommonConfig, *args, **kwargs):
self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False)
self.sequence_parallel = kwargs.get("sequence_parallel", False)

super().__init__(config, *args, **kwargs)


class BaseMoEModelMixin_TP(BaseMoEModelMixin, BaseModelMixin_TP):
def _init_model(self, config: CommonConfig, **kwargs) -> None:
self.attention_head_type = AttentionHeadType(config.attention_head_type)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.m_emb = config.m_emb
self.initializer_range = config.initializer_range
self.head_dim = self.embed_dim // self.num_heads

self.tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size()
self.wte = Embedding_TP(
config.vocab_size,
self.embed_dim,
std=self.initializer_range,
tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings,
use_padding_free_transformer=self._use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
)

self.drop = (
nn.Identity()
if config.embd_pdrop == 0
else Dropout_TP(
config.embd_pdrop,
use_padding_free_transformer=self._use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
)
)
self.h = nn.ModuleList(
[
self.layer_class(
config,
normalization_implementation=self.normalization_implementation,
attention_implementation=self.attention_implementation,
use_padding_free_transformer=self._use_padding_free_transformer,
moe_implementation=self.moe_implementation,
layer_idx=i,
sequence_parallel=self.sequence_parallel,
)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = get_normalization_function_TP(
config.normalization_function,
self.embed_dim,
eps=config.layer_norm_epsilon,
normalization_implementation=self.normalization_implementation,
use_padding_free_transformer=self._use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
)

self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
self._setup_positional_encoding()

# Initialize weights and apply final processing
self.post_init()
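A shapes-only illustration of what the sequence_parallel flag threaded through Embedding_TP, Dropout_TP and the layers above typically implies: between tensor-parallel regions, each rank holds a 1/TP slice of the sequence dimension. No distributed setup below; the names and the sharded dimension are assumptions.

import torch

batch, seq_len, hidden = 2, 1024, 4096
tp_world_size = 4

full = torch.randn(batch, seq_len, hidden)
# Under sequence parallelism, each rank keeps seq_len // tp_world_size tokens
# and all-gathers/reduce-scatters around the tensor-parallel regions.
local = full.chunk(tp_world_size, dim=1)[0]
print(local.shape)  # torch.Size([2, 256, 4096])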
5 changes: 5 additions & 0 deletions dolomite_engine/hf_models/mixins/moe_TP/main.py
@@ -0,0 +1,5 @@
from ..dense_TP import CausalLMModelMixin_TP
from ..moe import CausalLMMoEModelMixin


class CausalLMMoEModelMixin_TP(CausalLMMoEModelMixin, CausalLMModelMixin_TP): ...
9 changes: 6 additions & 3 deletions dolomite_engine/hf_models/modeling_utils_TP/TP.py
@@ -70,13 +70,16 @@ def dtensor_to_tensor(

@torch.no_grad()
def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefix: str, strip_keys: bool) -> dict:
module_state_dict = module.state_dict()

result = {}
for key, tensor in state_dict.items():
if key.startswith(prefix):
striped_key = key.split(prefix)[1] if strip_keys else key
stripped_key = key.split(prefix)[1] if strip_keys and prefix != "" else key

device_mesh = getattr(module, striped_key).device_mesh
placements = getattr(module, striped_key).placements
param = module_state_dict[stripped_key]
device_mesh = param.device_mesh
placements = param.placements
result[key] = DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements)
return result

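A tiny illustration (not part of the commit) of one reason the lookup above moved from getattr(module, stripped_key) to a module.state_dict() lookup: nested parameters have dotted keys that getattr cannot traverse. The added prefix != "" guard is also observable in plain Python, since str.split("") raises ValueError.

import torch.nn as nn

m = nn.Sequential(nn.Linear(4, 4))
key = "0.weight"
print(key in m.state_dict())  # True: state_dict flattens nested names with dots
print(hasattr(m, key))        # False: getattr/hasattr cannot resolve dotted paths
# "abc".split("") raises ValueError("empty separator"), hence the prefix check.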
13 changes: 12 additions & 1 deletion dolomite_engine/hf_models/models/__init__.py
@@ -5,6 +5,17 @@
convert_gpt_dolomite_to_gpt_crosslayer,
)
from .gpt_dolomite import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel
from .gpt_dolomite_TP import GPTDolomiteForCausalLM_TP, GPTDolomiteModel_TP
from .gpt_dolomite_TP import (
GPTDolomiteForCausalLM_TP,
GPTDolomiteModel_TP,
fix_gpt_dolomite_unsharded_state_dict,
unshard_gpt_dolomite_tensor_parallel_state_dicts,
)
from .moe_dolomite import MoEDolomiteConfig, MoEDolomiteForCausalLM, MoEDolomiteModel
from .moe_dolomite_TP import (
MoEDolomiteForCausalLM_TP,
MoEDolomiteModel_TP,
fix_moe_dolomite_unsharded_state_dict,
unshard_moe_dolomite_tensor_parallel_state_dicts,
)
from .rnn_dolomite import RNNDolomiteConfig, RNNDolomiteForCausalLM, RNNDolomiteModel
@@ -1,3 +1,3 @@
from .base import GPTDolomiteModel_TP
from .main import GPTDolomiteForCausalLM_TP
from .weights import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts
from .weights import fix_gpt_dolomite_unsharded_state_dict, unshard_gpt_dolomite_tensor_parallel_state_dicts
@@ -1,2 +1,2 @@
from .shard import get_gpt_dolomite_tp_state_dict
from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts
from .unshard import fix_gpt_dolomite_unsharded_state_dict, unshard_gpt_dolomite_tensor_parallel_state_dicts
@@ -6,7 +6,7 @@
from ...gpt_dolomite import GPTDolomiteConfig


def unshard_tensor_parallel_state_dicts(
def unshard_gpt_dolomite_tensor_parallel_state_dicts(
config: GPTDolomiteConfig,
tensor_parallel_state_dicts: list[dict],
tensor_parallel_word_embeddings: bool,
@@ -106,7 +106,7 @@ def unshard_tensor_parallel_state_dicts(
return output_state_dict


def fix_unsharded_state_dict(
def fix_gpt_dolomite_unsharded_state_dict(
config: GPTDolomiteConfig, state_dict: dict, tensor_parallel_size: int, prefix: str = ""
) -> dict:
state_dict[prefix + "transformer.wte.weight"] = state_dict[prefix + "transformer.wte.weight"][
58 changes: 3 additions & 55 deletions dolomite_engine/hf_models/models/moe_dolomite/base.py
@@ -1,64 +1,12 @@
import torch.nn as nn

from ...enums import AttentionHeadType, PositionEmbeddingType
from ...mixins import BaseMoEModelMixin, PreTrainedModelMixin
from ...modeling_utils import ParameterizedEmbedding, get_normalization_function
from ...mixins import BaseMoEModelMixin, PreTrainedMoEModelMixin
from .config import MoEDolomiteConfig
from .layer import SparseMoEBlock


class MoEDolomitePreTrainedModel(PreTrainedModelMixin):
class MoEDolomitePreTrainedModel(PreTrainedMoEModelMixin):
config_class = MoEDolomiteConfig
layer_class = SparseMoEBlock
_no_split_modules = ["SparseMoEBlock"]

def __init__(self, config: MoEDolomiteConfig, *args, **kwargs) -> None:
self.moe_implementation = kwargs.get("moe_implementation", "eager")
assert self.moe_implementation in ["eager", "scattermoe"]

super().__init__(config, *args, **kwargs)


class MoEDolomiteModel(MoEDolomitePreTrainedModel, BaseMoEModelMixin):
def _init_model(self, config: MoEDolomiteConfig, **kwargs) -> None:
self.attention_head_type = AttentionHeadType(config.attention_head_type)
self.embed_dim = config.n_embd
self.num_heads = config.n_head
self.m_emb = config.m_emb
self.initializer_range = config.initializer_range
self.mask_value = None

assert (
self.embed_dim % self.num_heads == 0
), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})"

self.head_dim = self.embed_dim // self.num_heads

self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range)

self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList(
[
self.layer_class(
config,
normalization_implementation=self.normalization_implementation,
attention_implementation=self.attention_implementation,
use_padding_free_transformer=self._use_padding_free_transformer,
moe_implementation=self.moe_implementation,
layer_idx=i,
)
for i in range(config.n_layer)
]
)
self.ln_f = get_normalization_function(
config.normalization_function,
self.embed_dim,
eps=config.layer_norm_epsilon,
normalization_implementation=self.normalization_implementation,
)

self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
self._setup_positional_encoding()

# Initialize weights and apply final processing
self.post_init()
class MoEDolomiteModel(MoEDolomitePreTrainedModel, BaseMoEModelMixin): ...
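After this refactor, a concrete MoE model only declares its config/layer classes and inherits _init_model from the mixins. A hypothetical new variant would look roughly like the following; the MyMoE* names are invented, and the block class is reused from moe_dolomite purely for the sketch.

# Hypothetical example of the mixin pattern; the MyMoE* classes do not exist in the repo.
from dolomite_engine.hf_models.mixins import (
    BaseMoEModelMixin,
    CausalLMMoEModelMixin,
    PreTrainedMoEModelMixin,
)
from dolomite_engine.hf_models.models.moe_dolomite import MoEDolomiteConfig
from dolomite_engine.hf_models.models.moe_dolomite.layer import SparseMoEBlock


class MyMoEConfig(MoEDolomiteConfig):
    model_type = "my_moe_dolomite"


class MyMoEPreTrainedModel(PreTrainedMoEModelMixin):
    config_class = MyMoEConfig
    layer_class = SparseMoEBlock
    _no_split_modules = ["SparseMoEBlock"]


class MyMoEModel(MyMoEPreTrainedModel, BaseMoEModelMixin): ...


class MyMoEForCausalLM(MyMoEPreTrainedModel, CausalLMMoEModelMixin):
    base_model_class = MyMoEModel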
8 changes: 0 additions & 8 deletions dolomite_engine/hf_models/models/moe_dolomite/main.py
@@ -1,14 +1,6 @@
from ...mixins import CausalLMMoEModelMixin
from .base import MoEDolomiteModel, MoEDolomitePreTrainedModel
from .config import MoEDolomiteConfig


class MoEDolomiteForCausalLM(MoEDolomitePreTrainedModel, CausalLMMoEModelMixin):
base_model_class = MoEDolomiteModel

def __init__(self, config: MoEDolomiteConfig, **kwargs) -> None:
super().__init__(config, **kwargs)

self.router_aux_loss_coef = config.router_aux_loss_coef
self.num_experts = config.num_experts
self.num_experts_per_tok = config.num_experts_per_tok
3 changes: 3 additions & 0 deletions dolomite_engine/hf_models/models/moe_dolomite_TP/__init__.py
@@ -0,0 +1,3 @@
from .base import MoEDolomiteModel_TP
from .main import MoEDolomiteForCausalLM_TP
from .weights import fix_moe_dolomite_unsharded_state_dict, unshard_moe_dolomite_tensor_parallel_state_dicts
12 changes: 12 additions & 0 deletions dolomite_engine/hf_models/models/moe_dolomite_TP/base.py
@@ -0,0 +1,12 @@
from ...mixins import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP
from ..moe_dolomite import MoEDolomiteConfig
from .layer import SparseMoEBlock_TP


class MoEDolomitePreTrainedModel_TP(PreTrainedMoEModelMixin_TP):
config_class = MoEDolomiteConfig
layer_class = SparseMoEBlock_TP
_no_split_modules = ["SparseMoEBlock_TP"]


class MoEDolomiteModel_TP(MoEDolomitePreTrainedModel_TP, BaseMoEModelMixin_TP): ...