Commit 843ecd4
Signed-off-by: Shawn Tan <[email protected]>
Signed-off-by: Mayank Mishra <[email protected]>
Co-authored-by: Mayank Mishra <[email protected]>
1 parent: 64000aa

Showing 37 changed files with 755 additions and 134 deletions.
@@ -1,3 +1,4 @@
 from .dense import BaseModelMixin, CausalLMModelMixin, PreTrainedModelMixin
 from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP
-from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin
+from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin, PreTrainedMoEModelMixin
+from .moe_TP import BaseMoEModelMixin_TP, CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP
@@ -1,2 +1,2 @@
-from .base import BaseMoEModelMixin
+from .base import BaseMoEModelMixin, PreTrainedMoEModelMixin
 from .main import CausalLMMoEModelMixin
@@ -0,0 +1,2 @@
+from .base import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP
+from .main import CausalLMMoEModelMixin_TP
@@ -0,0 +1,75 @@
+import torch.nn as nn
+
+from ....utils import ProcessGroupManager
+from ...config import CommonConfig
+from ...enums import AttentionHeadType, PositionEmbeddingType
+from ...modeling_utils_TP import Dropout_TP, Embedding_TP, get_normalization_function_TP
+from ..dense_TP import BaseModelMixin_TP, PreTrainedModelMixin_TP
+from ..moe import BaseMoEModelMixin, PreTrainedMoEModelMixin
+
+
+class PreTrainedMoEModelMixin_TP(PreTrainedMoEModelMixin, PreTrainedModelMixin_TP):
+    def __init__(self, config: CommonConfig, *args, **kwargs):
+        self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False)
+        self.sequence_parallel = kwargs.get("sequence_parallel", False)
+
+        super().__init__(config, *args, **kwargs)
+
+
+class BaseMoEModelMixin_TP(BaseMoEModelMixin, BaseModelMixin_TP):
+    def _init_model(self, config: CommonConfig, **kwargs) -> None:
+        self.attention_head_type = AttentionHeadType(config.attention_head_type)
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.m_emb = config.m_emb
+        self.initializer_range = config.initializer_range
+        self.head_dim = self.embed_dim // self.num_heads
+
+        self.tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size()
+        self.wte = Embedding_TP(
+            config.vocab_size,
+            self.embed_dim,
+            std=self.initializer_range,
+            tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings,
+            use_padding_free_transformer=self._use_padding_free_transformer,
+            sequence_parallel=self.sequence_parallel,
+        )
+
+        self.drop = (
+            nn.Identity()
+            if config.embd_pdrop == 0
+            else Dropout_TP(
+                config.embd_pdrop,
+                use_padding_free_transformer=self._use_padding_free_transformer,
+                sequence_parallel=self.sequence_parallel,
+            )
+        )
+        self.h = nn.ModuleList(
+            [
+                self.layer_class(
+                    config,
+                    normalization_implementation=self.normalization_implementation,
+                    attention_implementation=self.attention_implementation,
+                    use_padding_free_transformer=self._use_padding_free_transformer,
+                    moe_implementation=self.moe_implementation,
+                    layer_idx=i,
+                    sequence_parallel=self.sequence_parallel,
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = get_normalization_function_TP(
+            config.normalization_function,
+            self.embed_dim,
+            eps=config.layer_norm_epsilon,
+            normalization_implementation=self.normalization_implementation,
+            use_padding_free_transformer=self._use_padding_free_transformer,
+            sequence_parallel=self.sequence_parallel,
+        )
+
+        self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
+        self._setup_positional_encoding()
+
+        # Initialize weights and apply final processing
+        self.post_init()
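The new TP MoE mixins above lean on Python's cooperative multiple inheritance: `PreTrainedMoEModelMixin_TP` inherits from both the non-TP MoE mixin and the dense TP mixin, reads its own kwargs, and forwards the rest with a single `super().__init__()` call. The sketch below is a minimal, self-contained toy (class names, kwargs, and prints are invented for illustration, not the repo's real classes) showing how that one call walks the MRO so each parent initializes exactly once.

```python
# Toy illustration of cooperative super() across a mixin diamond (invented classes).
class PreTrainedMixin:
    def __init__(self, config, *args, **kwargs):
        self.config = config
        print("PreTrainedMixin.__init__")

class PreTrainedMixin_TP(PreTrainedMixin):
    def __init__(self, config, *args, **kwargs):
        self.tp_world_size = kwargs.get("tp_world_size", 1)
        print("PreTrainedMixin_TP.__init__")
        super().__init__(config, *args, **kwargs)

class PreTrainedMoEMixin(PreTrainedMixin):
    def __init__(self, config, *args, **kwargs):
        self.moe_implementation = kwargs.get("moe_implementation", "eager")
        print("PreTrainedMoEMixin.__init__")
        super().__init__(config, *args, **kwargs)

class PreTrainedMoEMixin_TP(PreTrainedMoEMixin, PreTrainedMixin_TP):
    def __init__(self, config, *args, **kwargs):
        self.sequence_parallel = kwargs.get("sequence_parallel", False)
        print("PreTrainedMoEMixin_TP.__init__")
        super().__init__(config, *args, **kwargs)

# MRO: MoEMixin_TP -> MoEMixin -> Mixin_TP -> Mixin, so the single super() chain
# visits both the MoE parent and the TP parent exactly once.
model = PreTrainedMoEMixin_TP({"hidden_size": 8}, moe_implementation="eager", tp_world_size=2)
```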
@@ -0,0 +1,5 @@
+from ..dense_TP import CausalLMModelMixin_TP
+from ..moe import CausalLMMoEModelMixin
+
+
+class CausalLMMoEModelMixin_TP(CausalLMMoEModelMixin, CausalLMModelMixin_TP): ...
@@ -1,3 +1,3 @@
 from .base import GPTDolomiteModel_TP
 from .main import GPTDolomiteForCausalLM_TP
-from .weights import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts
+from .weights import fix_gpt_dolomite_unsharded_state_dict, unshard_gpt_dolomite_tensor_parallel_state_dicts
dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
 from .shard import get_gpt_dolomite_tp_state_dict
-from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts
+from .unshard import fix_gpt_dolomite_unsharded_state_dict, unshard_gpt_dolomite_tensor_parallel_state_dicts
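The renamed helpers carry an architecture prefix, so the GPT and MoE variants of the unshard utilities can coexist in shared namespaces. As a hedged aside, the toy below (invented stand-in functions, not code from this commit) shows the shadowing problem that generic names invite: whichever binding is defined or imported last silently wins.

```python
# Toy illustration of name shadowing (invented stand-ins, not repo code).
def fix_unsharded_state_dict(state_dict):
    # pretend this is the gpt_dolomite helper
    return {"fixed_by": "gpt_dolomite", **state_dict}

def fix_unsharded_state_dict(state_dict):  # noqa: F811 -- deliberately shadows the first binding
    # pretend this is the moe_dolomite helper
    return {"fixed_by": "moe_dolomite", **state_dict}

print(fix_unsharded_state_dict({})["fixed_by"])  # "moe_dolomite": the first helper is unreachable
```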
@@ -1,64 +1,12 @@
-import torch.nn as nn
-
-from ...enums import AttentionHeadType, PositionEmbeddingType
-from ...mixins import BaseMoEModelMixin, PreTrainedModelMixin
-from ...modeling_utils import ParameterizedEmbedding, get_normalization_function
+from ...mixins import BaseMoEModelMixin, PreTrainedMoEModelMixin
 from .config import MoEDolomiteConfig
 from .layer import SparseMoEBlock


-class MoEDolomitePreTrainedModel(PreTrainedModelMixin):
+class MoEDolomitePreTrainedModel(PreTrainedMoEModelMixin):
     config_class = MoEDolomiteConfig
     layer_class = SparseMoEBlock
     _no_split_modules = ["SparseMoEBlock"]

-    def __init__(self, config: MoEDolomiteConfig, *args, **kwargs) -> None:
-        self.moe_implementation = kwargs.get("moe_implementation", "eager")
-        assert self.moe_implementation in ["eager", "scattermoe"]
-
-        super().__init__(config, *args, **kwargs)
-

-class MoEDolomiteModel(MoEDolomitePreTrainedModel, BaseMoEModelMixin):
-    def _init_model(self, config: MoEDolomiteConfig, **kwargs) -> None:
-        self.attention_head_type = AttentionHeadType(config.attention_head_type)
-        self.embed_dim = config.n_embd
-        self.num_heads = config.n_head
-        self.m_emb = config.m_emb
-        self.initializer_range = config.initializer_range
-        self.mask_value = None
-
-        assert (
-            self.embed_dim % self.num_heads == 0
-        ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})"
-
-        self.head_dim = self.embed_dim // self.num_heads
-
-        self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range)
-
-        self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList(
-            [
-                self.layer_class(
-                    config,
-                    normalization_implementation=self.normalization_implementation,
-                    attention_implementation=self.attention_implementation,
-                    use_padding_free_transformer=self._use_padding_free_transformer,
-                    moe_implementation=self.moe_implementation,
-                    layer_idx=i,
-                )
-                for i in range(config.n_layer)
-            ]
-        )
-        self.ln_f = get_normalization_function(
-            config.normalization_function,
-            self.embed_dim,
-            eps=config.layer_norm_epsilon,
-            normalization_implementation=self.normalization_implementation,
-        )
-
-        self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
-        self._setup_positional_encoding()
-
-        # Initialize weights and apply final processing
-        self.post_init()
+class MoEDolomiteModel(MoEDolomitePreTrainedModel, BaseMoEModelMixin): ...
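This refactor shrinks `MoEDolomiteModel` to a one-line subclass because the construction logic now lives in the shared mixin, which reads class attributes such as `config_class` and `layer_class` declared by each architecture. A minimal, self-contained sketch of that pattern follows, using invented toy classes (a plain dict stands in for the config, `ToyBlock` for `SparseMoEBlock`); it assumes only that `torch` is installed.

```python
# Toy version of the "shared mixin builds the model from class attributes" pattern.
import torch.nn as nn

class ToyMixin(nn.Module):
    config_class = None   # subclasses declare these
    layer_class = None

    def __init__(self, config):
        super().__init__()
        # generic construction driven entirely by the subclass's declarations
        self.h = nn.ModuleList(self.layer_class(config) for _ in range(config["n_layer"]))

class ToyBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.fc = nn.Linear(config["hidden_size"], config["hidden_size"])

class ToyDolomiteModel(ToyMixin):
    config_class = dict      # stand-in for MoEDolomiteConfig
    layer_class = ToyBlock   # stand-in for SparseMoEBlock

model = ToyDolomiteModel({"hidden_size": 8, "n_layer": 2})
print(len(model.h))  # 2
```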
@@ -1,14 +1,6 @@
 from ...mixins import CausalLMMoEModelMixin
 from .base import MoEDolomiteModel, MoEDolomitePreTrainedModel
-from .config import MoEDolomiteConfig


 class MoEDolomiteForCausalLM(MoEDolomitePreTrainedModel, CausalLMMoEModelMixin):
     base_model_class = MoEDolomiteModel
-
-    def __init__(self, config: MoEDolomiteConfig, **kwargs) -> None:
-        super().__init__(config, **kwargs)
-
-        self.router_aux_loss_coef = config.router_aux_loss_coef
-        self.num_experts = config.num_experts
-        self.num_experts_per_tok = config.num_experts_per_tok
@@ -0,0 +1,3 @@
+from .base import MoEDolomiteModel_TP
+from .main import MoEDolomiteForCausalLM_TP
+from .weights import fix_moe_dolomite_unsharded_state_dict, unshard_moe_dolomite_tensor_parallel_state_dicts
@@ -0,0 +1,12 @@
+from ...mixins import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP
+from ..moe_dolomite import MoEDolomiteConfig
+from .layer import SparseMoEBlock_TP
+
+
+class MoEDolomitePreTrainedModel_TP(PreTrainedMoEModelMixin_TP):
+    config_class = MoEDolomiteConfig
+    layer_class = SparseMoEBlock_TP
+    _no_split_modules = ["SparseMoEBlock_TP"]
+
+
+class MoEDolomiteModel_TP(MoEDolomitePreTrainedModel_TP, BaseMoEModelMixin_TP): ...